diff --git a/vllm_ascend/models/qwen3_next.py b/vllm_ascend/models/qwen3_next.py index b4e71bc..7d1481e 100644 --- a/vllm_ascend/models/qwen3_next.py +++ b/vllm_ascend/models/qwen3_next.py @@ -6,7 +6,6 @@ from collections.abc import Iterable from typing import Optional import torch -import torch.nn.functional as F from einops import rearrange from torch import nn from transformers.activations import ACT2FN @@ -19,6 +18,10 @@ from vllm.distributed import (divide, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.forward_context import ForwardContext, get_forward_context +from vllm.model_executor.layers.fla.ops import RMSNormGated +from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule +from vllm.model_executor.layers.fla.ops.fused_recurrent import \ + fused_recurrent_gated_delta_rule from vllm.model_executor.layers.fused_moe import FusedMoE # yapf conflicts with isort for this block # yapf: disable @@ -34,6 +37,8 @@ from vllm.model_executor.layers.mamba.mamba_mixer2 import \ mamba_v2_sharded_weight_loader from vllm.model_executor.layers.mamba.mamba_utils import ( MambaStateDtypeCalculator, MambaStateShapeCalculator) +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) @@ -45,7 +50,8 @@ from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid, from vllm.model_executor.models.mamba_cache import MambaCacheParams from vllm.model_executor.models.qwen2_moe import Qwen2MoeMLP as Qwen3NextMLP from vllm.model_executor.models.qwen3_next import (Qwen3NextAttention, - Qwen3NextSparseMoeBlock) + Qwen3NextSparseMoeBlock, + fused_gdn_gating) from vllm.model_executor.models.utils import ( AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, @@ -57,108 +63,6 @@ from vllm.transformers_utils.configs import Qwen3NextConfig from vllm.utils import direct_register_custom_op from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata -from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, - causal_conv1d_update_npu) -from vllm_ascend.ops.fla import RMSNormGated, fused_gdn_gating -from vllm_ascend.ops.sigmoid_gating import fused_recurrent_gated_delta_rule - - -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - chunk_size=64, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, -): - initial_dtype = query.dtype - if use_qk_l2norm_in_kernel: - query = F.normalize(query, p=2, dim=-1) - key = F.normalize(key, p=2, dim=-1) - query, key, value, beta, g = [ - x.transpose(1, 2).contiguous().to(torch.float32) - for x in (query, key, value, beta, g) - ] - - batch_size, sequence_length, num_heads, k_head_dim = key.shape - v_head_dim = value.shape[-1] - pad_size = (chunk_size - num_heads % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) - tot_heads = num_heads + pad_size - scale = 1 / (query.shape[-1]**0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = [ - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) - for x in (query, key, value, k_beta, v_beta) - ] - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=0) - - # chunk decay - g = g.cumsum(dim=-1) - decay_mask = ((g.unsqueeze(-1) - - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -( - (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - - last_recurrent_state = (torch.zeros(batch_size, sequence_length, - k_head_dim, v_head_dim).to(value) if - initial_state is None else initial_state.to(value)) - - core_attn_out = torch.zeros_like(value) - mask = torch.triu(torch.ones(chunk_size, - chunk_size, - dtype=torch.bool, - device=query.device), - diagonal=1) - - # for each chunk - for i in range(0, tot_heads // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * - decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() + - (k_i * - (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( - -1, -2) @ v_new) - - if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], - core_attn_out.shape[1], -1, - core_attn_out.shape[-1]) - core_attn_out = core_attn_out[:, :, :num_heads] - core_attn_out = core_attn_out.transpose(1, - 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state - class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): @@ -275,6 +179,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): self.norm = RMSNormGated( self.head_v_dim, eps=self.layer_norm_epsilon, + norm_before_gate=True, + device="npu", ) self.out_proj = RowParallelLinear(self.value_dim, @@ -467,7 +373,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): query_start_loc=non_spec_query_start_loc, ).transpose(0, 1) elif attn_metadata.num_decodes > 0: - mixed_qkv_non_spec = causal_conv1d_update_npu( + mixed_qkv_non_spec = causal_conv1d_update( mixed_qkv_non_spec, conv_state, conv_weights, @@ -551,7 +457,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase): ( cur_core_attn_out_non_spec, cur_last_recurrent_state, - ) = torch_chunk_gated_delta_rule( + ) = chunk_gated_delta_rule( query=cur_q, key=cur_k, value=cur_v, diff --git a/vllm_ascend/ops/casual_conv1d.py b/vllm_ascend/ops/casual_conv1d.py index 68790b5..2d00889 100644 --- a/vllm_ascend/ops/casual_conv1d.py +++ b/vllm_ascend/ops/casual_conv1d.py @@ -1,597 +1,539 @@ -# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py -# SPDX-License-Identifier: Apache-2.0 - -# Copyright (c) 2024, Tri Dao. -# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py -# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py -# mypy: ignore-errors - -from typing import Optional, Union - -import torch -import torch.nn.functional as F -import triton -import triton.language as tl - -PAD_SLOT_ID = -1 - - -def causal_conv1d_ref( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, - return_final_states: bool = False, - final_states_out: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", -): - """ - x: (batch, dim, seqlen) - weight: (dim, width) - bias: (dim,) - initial_states: (batch, dim, width - 1) - final_states_out: (batch, dim, width - 1) - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - x = x.to(weight.dtype) - seqlen = x.shape[-1] - dim, width = weight.shape - - if initial_states is None: - out = F.conv1d(x, - weight.unsqueeze(1), - bias, - padding=width - 1, - groups=dim) - else: - x = torch.cat([initial_states, x], dim=-1) - out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) - out = out[..., :seqlen] - if return_final_states: - final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( - dtype_in) # (batch, dim, width - 1) - if final_states_out is not None: - final_states_out.copy_(final_states) - else: - final_states_out = final_states - out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) - return (out, None) if not return_final_states else (out, final_states_out) - - -def causal_conv1d_fn( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - cache_indices: Optional[torch.Tensor] = None, - has_initial_state: Optional[torch.Tensor] = None, - conv_states: Optional[torch.Tensor] = None, - activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID, -): - """ - x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen - sequences are concatenated from left to right for varlen - weight: (dim, width) - bias: (dim,) - query_start_loc: (batch + 1) int32 - The cumulative sequence lengths of the sequences in - the batch, used to index into sequence. prepended by 0. - for example: query_start_loc = torch.Tensor([0,10,16,17]), - x.shape=(dim,17) - cache_indices: (batch) int32 - indicates the corresponding state index, - like so: conv_state = conv_states[cache_indices[batch_id]] - has_initial_state: (batch) bool - indicates whether should the kernel take the current state as initial - state for the calculations - conv_states: (...,dim,width - 1) itype - updated inplace if provided - activation: either None or "silu" or "swish" - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - - - out: (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(-1) != 1: - x = x.contiguous() - bias = bias.contiguous() if bias is not None else None - - out_ref = [] - out_ref_b = [] - seqlens = query_start_loc[1:] - query_start_loc[:-1] - seqlens = seqlens.tolist() - splits = torch.split(x, seqlens, dim=-1) - - for i in range(len(seqlens)): - x_s = splits[i] - if cache_indices[i] == PAD_SLOT_ID: - continue - out_ref_b.append( - causal_conv1d_ref( - x_s, - weight, - bias, - activation=activation, - return_final_states=True, - final_states_out=conv_states[cache_indices[i]].unsqueeze(0), - initial_states=conv_states[cache_indices[i]] - if has_initial_state[i] else None)) - out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) - out_ref_tensor = torch.cat(out_ref, dim=0) - return out_ref_tensor - - -def causal_conv1d_update_ref(x, - conv_state, - weight, - bias=None, - activation=None, - cache_seqlens=None, - conv_state_indices=None): - """ - x: (batch, dim) or (batch, dim, seqlen) - conv_state: (batch, dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the - conv_state starting at the index - @cache_seqlens % state_len before performing the convolution. - - out: (batch, dim) or (batch, dim, seqlen) - """ - if activation not in [None, "silu", "swish"]: - raise NotImplementedError("activation must be None, silu, or swish") - dtype_in = x.dtype - unsqueeze = x.dim() == 2 - if unsqueeze: - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - width = weight.shape[1] - state_len = conv_state.shape[-1] - assert weight.shape == (dim, width) - if cache_seqlens is None: - x_new = torch.cat([conv_state[conv_state_indices], x], dim=-1).to( - weight.dtype) # (batch, dim, state_len + seqlen) - conv_state[conv_state_indices] = x_new[:, :, -state_len:] - else: - width_idx = torch.arange( - -(width - 1), 0, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - width_idx = (torch.remainder(width_idx, state_len).unsqueeze(1).expand( - -1, dim, -1)) - x_new = torch.cat([conv_state.gather(2, width_idx), x], - dim=-1).to(weight.dtype) - copy_idx = torch.arange( - seqlen, dtype=torch.long, - device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) - copy_idx = torch.remainder(copy_idx, - state_len).unsqueeze(1).expand(-1, dim, -1) - conv_state.scatter_(2, copy_idx, x) - out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, - groups=dim)[:, :, -seqlen:] - if unsqueeze: - out = out.squeeze(-1) - return (out if activation is None else F.silu(out)).to(dtype=dtype_in) - - -@triton.jit() -def _causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - cache_seqlens_ptr, # circular buffer - conv_state_indices_ptr, - num_accepted_tokens_ptr, - intermediate_conv_window_ptr, - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_inter_seq: tl.constexpr, - stride_inter_step: tl.constexpr, - stride_inter_dim: tl.constexpr, - stride_inter_win: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, - SAVE_INTERMEDIATE: tl.constexpr, -): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return - - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + - idx_seq * stride_state_indices).to( - tl.int64) - else: - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + - idx_seq) - 1 - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) - mask_w = idx_feats < dim - - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok - if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # The conv_state updates works in a sliding window manner, - # at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + - conv_state_token_offset * stride_conv_state_tok + - (idx_feats * stride_conv_state_dim)[None, :] + - ((idx_tokens + 1) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ((conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :]) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim - ) # [BLOCK_N] - - x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] - ) # [BLOCK_M, BLOCK_N] - - mask_x = ((idx_tokens - VAL >= 0)[:, None] - & (idx_tokens - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - conv_state_base = (conv_state_ptr + - (conv_state_batch_coord * stride_conv_state_seq) + - (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] - conv_state_ptrs_target = (conv_state_base + - (idx_tokens * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, - other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) - - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim - - # STEP 5: compute each token - for idx_token in tl.static_range(seqlen): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - # mask_1d = (idx_token < seqlen) & ( - # idx_feats < dim - # ) # token-index # feature-index - maskL = idx_feats < dim - maskR = tl.full(maskL.shape, False, tl.int1) - mask_1d = tl.where(idx_token < seqlen, maskL, maskR) - - o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + - idx_token * stride_o_token + (idx_feats * stride_o_dim)) - - tl.store(o_ptrs, acc, mask=mask_1d) - - if SAVE_INTERMEDIATE: - # Save the window state after consuming this token - # Layout: [seq(cache line), step, dim, win(K-1)] - base_ptr = (intermediate_conv_window_ptr + - conv_state_batch_coord * stride_inter_seq + - idx_token * stride_inter_step + - idx_feats * stride_inter_dim) - if KERNEL_WIDTH >= 2: - tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) - if KERNEL_WIDTH >= 3: - tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) - if KERNEL_WIDTH >= 4: - tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) - - -def causal_conv1d_update_npu( - x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, - cache_seqlens: Optional[torch.Tensor] = None, - conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - intermediate_conv_window: Optional[torch.Tensor] = None, - pad_slot_id: int = PAD_SLOT_ID, - metadata=None, - validate_data=False, -): - """ - x: (batch, dim) or (batch, dim, seqlen) - [shape=2: single token prediction] - [shape=3: single or multiple tokens prediction] - conv_state: (..., dim, state_len), where state_len >= width - 1 - weight: (dim, width) - bias: (dim,) - cache_seqlens: (batch,), dtype int32. - If not None, the conv_state is treated as a circular buffer. - The conv_state will be updated by copying x to the conv_state - starting at the index - @cache_seqlens % state_len. - conv_state_indices: (batch,), dtype int32 - If not None, the conv_state is a larger tensor along the batch dim, - and we are selecting the batch coords specified by conv_state_indices. - Useful for a continuous batching scenario. - pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) - """ - if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM - assert pad_slot_id is not None - assert x.stride(1) == 1 - if isinstance(activation, bool): - activation = "silu" if activation is True else None - elif activation is not None: - assert activation in ["silu", "swish"] - unsqueeze = x.dim() == 2 - if unsqueeze: - # make it (batch, dim, seqlen) with seqlen == 1 - x = x.unsqueeze(-1) - batch, dim, seqlen = x.shape - _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert ( - conv_state.stride(-2) == 1 - ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch, ) == conv_state_indices.shape - - assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer - - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' - out = x - stride_w_dim, stride_w_width = weight.stride() - - stride_x_seq, stride_x_dim, stride_x_token = x.stride( - ) # X (batch, dim, seqlen) - - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( - ) - stride_state_indices = (conv_state_indices.stride(0) - if conv_state_indices is not None else 0) - state_len = width - 1 + (seqlen - 1) # effective state_len needed - np2_statelen = triton.next_power_of_2(state_len) - - def grid(META): - return ( - batch, - triton.cdiv(dim, META["BLOCK_N"]), - ) - - # prepare intermediate buffer strides if provided - if intermediate_conv_window is not None: - stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( - intermediate_conv_window.stride(0), - intermediate_conv_window.stride(1), - intermediate_conv_window.stride(2), - intermediate_conv_window.stride(3), - ) - else: - stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 - - _causal_conv1d_update_kernel[grid]( - # Pointers to matrices - x, - weight, - bias, - conv_state, - cache_seqlens, - conv_state_indices, - num_accepted_tokens, - intermediate_conv_window - if intermediate_conv_window is not None else x, - out, - # Matrix dimensions - batch, - dim, - seqlen, - state_len, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_state_indices, - stride_inter_seq, - stride_inter_step, - stride_inter_dim, - stride_inter_win, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others - pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=128, - SAVE_INTERMEDIATE=intermediate_conv_window is not None, - ) - if unsqueeze: - out = out.squeeze(-1) - return out +# adapted from vllm/model_executor/layers/mamba/ops/casual_conv1d.py +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# mypy: ignore-errors + +from typing import Optional, Union + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + out_ref = [] + out_ref_b = [] + seqlens = query_start_loc[1:] - query_start_loc[:-1] + seqlens = seqlens.tolist() + splits = torch.split(x, seqlens, dim=-1) + + for i in range(len(seqlens)): + x_s = splits[i] + if cache_indices[i] == PAD_SLOT_ID: + continue + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight, + bias, + activation=activation, + return_final_states=True, + final_states_out=conv_states[cache_indices[i]].unsqueeze(0), + initial_states=conv_states[cache_indices[i]] + if has_initial_state[i] else None)) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1)) + out_ref_tensor = torch.cat(out_ref, dim=0) + return out_ref_tensor + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + + idx_seq * stride_state_indices).to( + tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + #col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ((conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :]) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim + ) # [BLOCK_N] + + x_ptrs = (x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ((idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = (conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)) # [BLOCK_N,] + conv_state_ptrs_target = (conv_state_base + + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, + other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N, ), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + # mask_1d = (idx_token < seqlen) & ( + # idx_feats < dim + # ) # token-index # feature-index + maskL = idx_feats < dim + maskR = tl.full(maskL.shape, False, tl.int1) + mask_1d = tl.where(idx_token < seqlen, maskL, maskR) + + o_ptrs = (o_ptr + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + (idx_feats * stride_o_dim)) + + tl.store(o_ptrs, acc, mask=mask_1d) + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = (intermediate_conv_window_ptr + + conv_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + +def causal_conv1d_update_npu( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch, ) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride( + ) # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride( + ) + stride_state_indices = (conv_state_indices.stride(0) + if conv_state_indices is not None else 0) + state_len = width - 1 + (seqlen - 1) # effective state_len needed + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window + if intermediate_conv_window is not None else x, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=128, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/vllm_ascend/ops/fla.py b/vllm_ascend/ops/fla.py index df81dd5..b200c67 100644 --- a/vllm_ascend/ops/fla.py +++ b/vllm_ascend/ops/fla.py @@ -9,109 +9,8 @@ import torch import torch.nn.functional as F import triton -import triton.language as tl -from einops import rearrange - - -def rms_norm_ref( - x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - upcast=True, -): - dtype = x.dtype - #N = x.shape[-1] - weight = weight.float() - bias = bias.float() if bias is not None else None - if upcast: - x = x.float() - z = z.float() if z is not None else z - if z is not None and not norm_before_gate: - x = x * F.silu(z) - if group_size is None: - rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps) - out = (x * rstd * weight) + bias if bias is not None else (x * rstd * - weight) - else: - x_group = rearrange(x, "... (g d) -> ... g d", d=group_size) - rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + - eps) - out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight - if bias is not None: - out = out + bias - if z is not None and norm_before_gate: - out *= F.silu(z) - return out.to(dtype) - - -@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) -@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) -@triton.jit -def _layer_norm_fwd_1pass_kernel( - X, # pointer to the input - Y, # pointer to the output - W, # pointer to the weights - B, # pointer to the biases - Z, # pointer to the other branch - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - stride_x_row, # how much to increase the pointer when moving by 1 row - stride_y_row, - stride_z_row, - M, # number of rows in X - N, # number of columns in X - eps, # epsilon to avoid division by zero - BLOCK_N: tl.constexpr, - HAS_BIAS: tl.constexpr, - HAS_Z: tl.constexpr, - NORM_BEFORE_GATE: tl.constexpr, - IS_RMS_NORM: tl.constexpr, -): - # Map the program id to the row of X and Y it should compute. - row = tl.program_id(0) - group = tl.program_id(1) - X += row * stride_x_row + group * N - Y += row * stride_y_row + group * N - if HAS_Z: - Z += row * stride_z_row + group * N - if not IS_RMS_NORM: - Mean += group * M - Rstd += group * M - W += group * N - if HAS_BIAS: - B += group * N - # Compute mean and variance - cols = tl.arange(0, BLOCK_N) - x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) - if HAS_Z and not NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=cols < N).to(tl.float32) - x *= z * tl.sigmoid(z) - if not IS_RMS_NORM: - mean = tl.sum(x, axis=0) / N - tl.store(Mean + row, mean) - xbar = tl.where(cols < N, x - mean, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - else: - xbar = tl.where(cols < N, x, 0.0) - var = tl.sum(xbar * xbar, axis=0) / N - rstd = 1 / tl.sqrt(var + eps) - tl.store(Rstd + row, rstd) - # Normalize and apply linear transformation - mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) - if HAS_BIAS: - b = tl.load(B + cols, mask=mask).to(tl.float32) - x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd - y = x_hat * w + b if HAS_BIAS else x_hat * w - if HAS_Z and NORM_BEFORE_GATE: - z = tl.load(Z + cols, mask=mask).to(tl.float32) - y *= z * tl.sigmoid(z) - # Write output - tl.store(Y + cols, y, mask=mask) +from vllm.model_executor.layers.fla.ops.layernorm_guard import \ + layer_norm_fwd_kernel def _layer_norm_fwd( @@ -158,7 +57,7 @@ def _layer_norm_fwd( num_warps = min(max(BLOCK_N // 256, 1), 8) grid = (M, ngroups) with torch.npu.device(x.device.index): - _layer_norm_fwd_1pass_kernel[grid]( + layer_norm_fwd_kernel[grid]( x, out, weight, @@ -222,160 +121,98 @@ class LayerNormFn(torch.autograd.Function): return y.reshape(x_shape_og) -def layernorm_fn( - x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True, - is_rms_norm=False, -): - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, - norm_before_gate, is_rms_norm) - - -def rmsnorm_fn(x, - weight, - bias, - z=None, - eps=1e-6, - group_size=None, - norm_before_gate=True): - return LayerNormFn.apply(x, weight, bias, z, eps, group_size, - norm_before_gate, True) - - -class LayerNorm(torch.nn.Module): - - def __init__( - self, - hidden_size, - eps=1e-5, - group_size=None, - norm_before_gate=True, - device=None, - dtype=None, - ): - """If group_size is not None, we do GroupNorm with each group having group_size elements. - group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). - """ - - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter( - torch.empty(hidden_size, **factory_kwargs)) - self.bias = torch.nn.Parameter( - torch.empty(hidden_size, **factory_kwargs)) - self.group_size = group_size - self.norm_before_gate = norm_before_gate - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) - torch.nn.init.zeros_(self.bias) - - def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return layernorm_fn( - x, - self.weight, - self.bias, - z=z, - group_size=self.group_size, - eps=self.eps, - norm_before_gate=self.norm_before_gate, - ) - - -class RMSNormGated(torch.nn.Module): - - def __init__( - self, - hidden_size, - eps=1e-5, - group_size=None, - norm_before_gate=True, - device=None, - dtype=None, - ): - """If group_size is not None, we do GroupNorm with each group having group_size elements. - group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group). - """ - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter( - torch.empty(hidden_size, **factory_kwargs)) - self.register_parameter("bias", None) - self.group_size = group_size - self.norm_before_gate = norm_before_gate - self.reset_parameters() - - def reset_parameters(self): - torch.nn.init.ones_(self.weight) - - def forward(self, x, z=None): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return rmsnorm_fn( - x, - self.weight, - self.bias, - z=z, - eps=self.eps, - group_size=self.group_size, - norm_before_gate=self.norm_before_gate, - ) - - -@triton.jit -def fused_gdn_gating_kernel( +def torch_chunk_gated_delta_rule( + query, + key, + value, g, - A_log, - a, - dt_bias, - seq_len, - NUM_HEADS: tl.constexpr, - beta: tl.constexpr, - threshold: tl.constexpr, - BLK_HEADS: tl.constexpr, + beta, + chunk_size=64, + initial_state=None, + output_final_state=False, + use_qk_l2norm_in_kernel=False, ): - i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) - head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) - off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off - mask = head_off < NUM_HEADS - blk_A_log = tl.load(A_log + head_off, mask=mask) - blk_a = tl.load(a + off, mask=mask) - blk_bias = tl.load(dt_bias + head_off, mask=mask) - # If the model is loaded in fp16, without the .float() here, A might be -inf - x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) - softplus_x = tl.where(beta * x <= threshold, - (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) - blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x - tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + initial_dtype = query.dtype + if use_qk_l2norm_in_kernel: + query = F.normalize(query, p=2, dim=-1) + key = F.normalize(key, p=2, dim=-1) + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + batch_size, sequence_length, num_heads, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - num_heads % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + tot_heads = num_heads + pad_size + scale = 1 / (query.shape[-1]**0.5) + query = query * scale -def fused_gdn_gating( - A_log: torch.Tensor, - a: torch.Tensor, - dt_bias: torch.Tensor, - beta: float = 1.0, - threshold: float = 20.0, -) -> torch.Tensor: - batch, num_heads = a.shape - seq_len = 1 - grid = (batch, seq_len, triton.cdiv(num_heads, 8)) - g = torch.empty_like(a, dtype=torch.float32) - fused_gdn_gating_kernel[grid](g, - A_log, - a, - dt_bias, - seq_len, - num_heads, - beta, - threshold, - 8, - num_warps=1) - return g + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + # reshape to chunks + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=0) + + # chunk decay + g = g.cumsum(dim=-1) + decay_mask = ((g.unsqueeze(-1) - + g.unsqueeze(-2)).tril().exp().float()).tril() + attn = -( + (k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_recurrent_state = (torch.zeros(batch_size, sequence_length, + k_head_dim, v_head_dim).to(value) if + initial_state is None else initial_state.to(value)) + + core_attn_out = torch.zeros_like(value) + mask = torch.triu(torch.ones(chunk_size, + chunk_size, + dtype=torch.bool, + device=query.device), + diagonal=1) + + # for each chunk + for i in range(0, tot_heads // chunk_size): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn = (q_i @ k_i.transpose(-1, -2) * + decay_mask[:, :, i]).masked_fill_(mask, 0) + v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state + core_attn_out[:, :, i] = attn_inter + attn @ v_new + last_recurrent_state = ( + last_recurrent_state * g[:, :, i, -1, None, None].exp() + + (k_i * + (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2) @ v_new) + + if not output_final_state: + last_recurrent_state = None + core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], + core_attn_out.shape[1], -1, + core_attn_out.shape[-1]) + core_attn_out = core_attn_out[:, :, :num_heads] + core_attn_out = core_attn_out.transpose(1, + 2).contiguous().to(initial_dtype) + return core_attn_out, last_recurrent_state diff --git a/vllm_ascend/ops/sigmoid_gating.py b/vllm_ascend/ops/sigmoid_gating.py index d599287..c99799c 100644 --- a/vllm_ascend/ops/sigmoid_gating.py +++ b/vllm_ascend/ops/sigmoid_gating.py @@ -97,16 +97,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( o_k = i_k * BK + tl.arange(0, BK) o_v = i_v * BV + tl.arange(0, BV) - # p_q = q + (bos * H + i_h) * K + o_k - # p_k = k + (bos * H + i_h) * K + o_k - # p_v = v + (bos * HV + i_hv) * V + o_v - # if IS_BETA_HEADWISE: - # p_beta = beta + (bos * HV + i_hv) * V + o_v - # else: - # p_beta = beta + bos * HV + i_hv - # p_g = g + bos * HV + i_hv - # p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v - mask_k = o_k < K mask_v = o_v < V mask_h = mask_k[:, None] & mask_v[None, :] @@ -170,13 +160,6 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - # p_q += H * K - # p_k += H * K - # p_o += HV * V - # p_v += HV * V - # p_g += HV - # p_beta += HV * (V if IS_BETA_HEADWISE else 1) - def fused_recurrent_gated_delta_rule_fwd( q: torch.Tensor, @@ -342,13 +325,11 @@ def fused_recurrent_gated_delta_rule( Indices to map the input sequences to the initial/final states. num_accepted_tokens (Optional[torch.Tensor]): Number of accepted tokens for each sequence during decoding. - Returns: o (torch.Tensor): Outputs of shape `[B, T, HV, V]`. final_state (torch.Tensor): Final state of shape `[N, HV, K, V]`. - Examples:: >>> import torch >>> import torch.nn.functional as F @@ -400,4 +381,4 @@ def fused_recurrent_gated_delta_rule( num_accepted_tokens, use_qk_l2norm_in_kernel, ) - return o, final_state + return o, final_state \ No newline at end of file diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 3f04b9b..37407b4 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -15,6 +15,11 @@ # limitations under the License. # +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + import vllm_ascend.patch.worker.patch_common.patch_triton + import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_triton.py b/vllm_ascend/patch/worker/patch_common/patch_triton.py new file mode 100644 index 0000000..8904054 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_triton.py @@ -0,0 +1,16 @@ +import vllm.model_executor.layers.fla.ops.chunk +import vllm.model_executor.layers.fla.ops.fused_recurrent +import vllm.model_executor.layers.fla.ops.layernorm_guard +import vllm.model_executor.layers.mamba.ops.causal_conv1d + +from vllm_ascend.ops.casual_conv1d import (causal_conv1d_fn, + causal_conv1d_update_npu) +from vllm_ascend.ops.fla import LayerNormFn, torch_chunk_gated_delta_rule +from vllm_ascend.ops.sigmoid_gating import \ + fused_recurrent_gated_delta_rule_fwd_kernel + +vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal_conv1d_update_npu +vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn +vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel +vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn +vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule \ No newline at end of file