# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import os

from vllm.triton_utils import tl, tldevice, triton

# Math primitives used inside the kernels. With FLA_USE_FAST_OPS=1 the
# libdevice `fast_*` variants are used (presumably lower-precision
# approximations -- confirm accuracy requirements before enabling);
# otherwise the standard `tl` ops are used.
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:

    @triton.jit
    def div_normal(x, y):
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2


# Fused recurrent forward pass of the gated delta rule.
#
# Grid: axis 0 -> K block (i_k), axis 1 -> V block (i_v),
#       axis 2 -> (sequence i_n, value head i_hv) flattened as i_n * HV + i_hv.
#
# For every token t of a sequence, with state h of shape [K, V] (held here as
# a [BK, BV] fp32 tile):
#   (optional) q_t, k_t <- L2-normalized q_t, k_t
#   h   <- h * exp(g_t)             (scalar gate, or per-K-row gate if IS_KDA)
#   v'  <- beta_t * (v_t - h^T k_t) (delta rule correction)
#   h   <- h + k_t v'^T             (rank-1 update)
#   o_t <- h^T (scale * q_t)
# The state after EVERY token is written to `ht` (one slot per token), so
# speculative / multi-query tokens can resume from any accepted position.
#
# Memory layouts, as implied by the address arithmetic below:
#   q, k: [tokens, H, K]     v: [tokens, HV, V]     o: [NK, tokens, HV, V]
#   g: [tokens, HV] (not IS_KDA) or [tokens, HV, K] (IS_KDA)
#   beta: [tokens, HV, V] if IS_BETA_HEADWISE else [tokens, HV]
#   h0/ht slots: [HV, K, V] each.
@triton.heuristics({
    'USE_INITIAL_STATE':
    lambda args: args['h0'] is not None,
    'IS_VARLEN':
    lambda args: args['cu_seqlens'] is not None,
    "IS_CONTINUOUS_BATCHING":
    lambda args: args['ssm_state_indices'] is not None,
    "IS_SPEC_DECODING":
    lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,  # per-key-channel gate g of shape [tokens, HV, K]
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Each q/k head is shared by HV // H consecutive value heads.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        # Total packed token count; the stride of one i_k slab of `o`.
        T_all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        T_all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # Resume from the state slot of the last accepted token (spec
            # decoding) or of token 0.
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t * stride_indices_tok).to(
                                    tl.int64) * stride_init_state_token
        else:
            # NOTE(review): indexed by the token offset `bos`, not by i_n --
            # assumes h0 uses the same per-token slot layout as ht; confirm.
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t
        if not IS_KDA:
            # one scalar gate per (token, value head)
            p_g = g + bos * HV + i_hv + HV * i_t
        else:
            # K gate values per (token, value head), one per state row
            p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k

        p_o = o + ((i_k * T_all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale

        # Decay the state: [BK, BV]. The gate is loaded only inside the branch
        # that defines its pointer (p_g does not exist when IS_KDA).
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            # Masked: rows o_k >= K would otherwise read out of bounds and a
            # garbage exp() could turn the zero padding rows into NaN.
            b_gk = tl.load(p_gk, mask=mask_k, other=0).to(tl.float32)
            b_h *= exp(b_gk[:, None])

        # Delta rule: subtract the state's current prediction h^T k. [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # Rank-1 state update: [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # Read-out: [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t * stride_indices_tok).to(
                                    tl.int64) * stride_final_state_token
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


# Same recurrence as fused_recurrent_gated_delta_rule_fwd_kernel, but without
# the IS_KDA option: `g` is always one scalar gate per (token, value head).
# Presumably retained for callers using the 0.11.0-era signature -- confirm.
@triton.heuristics({
    'USE_INITIAL_STATE':
    lambda args: args['h0'] is not None,
    'IS_VARLEN':
    lambda args: args['cu_seqlens'] is not None,
    "IS_CONTINUOUS_BATCHING":
    lambda args: args['ssm_state_indices'] is not None,
    "IS_SPEC_DECODING":
    lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    # Each q/k head is shared by HV // H consecutive value heads.
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        # Total packed token count; the stride of one i_k slab of `o`.
        T_all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        T_all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            # Resume from the state slot of the last accepted token (spec
            # decoding) or of token 0.
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t * stride_indices_tok).to(
                                    tl.int64) * stride_init_state_token
        else:
            # NOTE(review): indexed by the token offset `bos`, not by i_n --
            # assumes h0 uses the same per-token slot layout as ht; confirm.
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t
        # one scalar gate per (token, value head)
        p_g = g + bos * HV + i_hv + HV * i_t
        p_o = o + ((i_k * T_all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale

        # Decay the state: [BK, BV]
        b_g = tl.load(p_g).to(tl.float32)
        b_h *= exp(b_g)
        # Delta rule: subtract the state's current prediction h^T k. [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # Rank-1 state update: [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # Read-out: [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t * stride_indices_tok).to(
                                    tl.int64) * stride_final_state_token
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
        p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)