### What this PR does / why we need it?
Create a `triton` package directory and move the existing Triton kernel files into it.
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
Signed-off-by: shiyuan680 <917935075@qq.com>
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors

import os

from vllm.triton_utils import tl, tldevice, triton

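# Setting FLA_USE_FAST_OPS=1 swaps the math helpers below for Triton's fast
# device intrinsics (fast_expf, fast_dividef, ...), trading some numerical
# precision for speed; by default the standard tl.* implementations are used.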
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
    div = tldevice.fast_dividef
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:

    @triton.jit
    def div_normal(x, y):
        return x / y

    div = div_normal
    exp = tl.exp
    log = tl.log
    log2 = tl.log2


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'IS_CONTINUOUS_BATCHING': lambda args: args['ssm_state_indices'] is not None,
    'IS_SPEC_DECODING': lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
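    # Each program instance owns one (key block i_k, value block i_v,
    # sequence x value-head i_nh) tile and scans its sequence token by token,
    # carrying the [BK, BV] recurrent state b_h through the gated delta rule:
    #   b_h <- b_h * exp(g_t)                 (gated decay)
    #   u_t <- beta_t * (v_t - b_h^T k_t)     (delta-rule correction)
    #   b_h <- b_h + k_t u_t^T                (rank-1 state update)
    #   o_t <- b_h^T q_t                      (output readout)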
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
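    # Initial state: with continuous batching, h0 is looked up indirectly via
    # ssm_state_indices (one slot per request); in speculative-decoding mode
    # the slot of the last accepted draft token is selected. Otherwise h0 is
    # indexed densely by the sequence's start offset (bos).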
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

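    # Pointer arithmetic below assumes flat [num_tokens, H, K] layouts for
    # q/k, [num_tokens, HV, V] for v, per-token per-head scalars for g and
    # (non-headwise) beta, and an output o stacked as [NK, num_tokens, HV, V]
    # so partial results from different key blocks can be summed on the host.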
    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t

        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t

        if not IS_KDA:
            p_g = g + bos * HV + i_hv + HV * i_t
        else:
            p_gk = g + (bos * HV + i_hv + HV * i_t) * K + o_k

        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
        # b_h *= tl.exp(b_g)
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[:, None])
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # keep the states for multi-query tokens
    if INPLACE_FINAL_STATE:
        p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                            i_t).to(tl.int64) * stride_final_state_token
    else:
        p_ht = ht + (bos + i_t) * stride_final_state_token
    p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


@triton.heuristics({
    'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'IS_CONTINUOUS_BATCHING': lambda args: args['ssm_state_indices'] is not None,
    'IS_SPEC_DECODING': lambda args: args['num_accepted_tokens'] is not None,
})
@triton.jit(do_not_specialize=['N', 'T'])
def fused_recurrent_gated_delta_rule_fwd_kernel_0_11_0(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.constexpr,  # num of sequences
    T: tl.constexpr,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is headwise vector or scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
):
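    # Same recurrence as fused_recurrent_gated_delta_rule_fwd_kernel above,
    # minus the IS_KDA branch: the gate g is always a per-value-head scalar.
    # The _0_11_0 suffix presumably keeps this variant aligned with the
    # vLLM v0.11.0 interface noted in the PR description.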
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_k[:, None] & mask_v[None, :]

    b_h = tl.zeros([BK, BV], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                                i_t).to(tl.int64) * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * K * V
        p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        p_q = q + (bos * H + i_h) * K + o_k + H * K * i_t
        p_k = k + (bos * H + i_h) * K + o_k + H * K * i_t
        p_v = v + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        if IS_BETA_HEADWISE:
            p_beta = beta + (bos * HV + i_hv) * V + o_v + HV * V * i_t
        else:
            p_beta = beta + bos * HV + i_hv + HV * i_t
        p_g = g + bos * HV + i_hv + HV * i_t
        p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + HV * V * i_t

        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
        b_g = tl.load(p_g).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BK, BV]
        # b_h *= tl.exp(b_g)
        b_h *= exp(b_g)
        # [BV]
        b_v -= tl.sum(b_h * b_k[:, None], 0)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BK, BV]
        b_h += b_k[:, None] * b_v[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[:, None], 0)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

    # keep the states for multi-query tokens
    if INPLACE_FINAL_STATE:
        p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
                            i_t).to(tl.int64) * stride_final_state_token
    else:
        p_ht = ht + (bos + i_t) * stride_final_state_token
    p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
    tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)


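# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). A minimal host-side launch of
# fused_recurrent_gated_delta_rule_fwd_kernel for a fixed-length batch with
# no initial state, no varlen metadata, and a scalar (non-headwise) beta.
# All tensor layouts, block sizes, and strides below are assumptions inferred
# from the pointer arithmetic above, not the project's actual wrapper code.
def _example_fused_recurrent_launch():
    import torch  # local import to keep this sketch self-contained

    # Hypothetical problem sizes.
    B, T, H, HV, K, V = 2, 4, 2, 4, 64, 64
    BK, BV = triton.next_power_of_2(K), triton.next_power_of_2(V)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    device = 'cuda'  # assumption; use whatever device your Triton backend targets

    q = torch.randn(B * T, H, K, device=device)
    k = torch.randn(B * T, H, K, device=device)
    v = torch.randn(B * T, HV, V, device=device)
    g = -torch.rand(B * T, HV, device=device)  # log-space decay gate, exp(g) < 1
    beta = torch.rand(B * T, HV, device=device)
    # Partial outputs are stacked per key block and reduced on the host;
    # final states are written per token position.
    o = torch.empty(NK, B * T, HV, V, device=device)
    ht = torch.empty(B * T, HV, K, V, device=device)

    # One program per (key block, value block, sequence x value head).
    grid = (NK, NV, B * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q, k, v, g, beta, o,
        h0=None, ht=ht,
        cu_seqlens=None, ssm_state_indices=None, num_accepted_tokens=None,
        scale=K**-0.5,
        N=B, T=T, B=B, H=H, HV=HV, K=K, V=V, BK=BK, BV=BV,
        stride_init_state_token=HV * K * V,
        stride_final_state_token=HV * K * V,
        stride_indices_seq=1,
        stride_indices_tok=1,
        INPLACE_FINAL_STATE=False,
        IS_BETA_HEADWISE=False,
        USE_QK_L2NORM_IN_KERNEL=True,
        IS_KDA=False,
    )
    return o.sum(0), ht  # [B*T, HV, V] outputs and per-token final states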