144 lines
4.6 KiB
Python
144 lines
4.6 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from .index import prepare_chunk_indices
|
|
from .op import exp
|
|
|
|
|
|
|
|
|
|
# Triton kernel: for one chunk of BT tokens and one head, compute the
# strictly lower-triangular BT x BT block
#     A[i, j] = beta[i] * <k[i], k[j]>                       (i > j)
# optionally multiplied by exp(g_cumsum[i] - g_cumsum[j]) when USE_G.
# Grid (set by the launcher): axis 0 = chunk index, axis 1 = batch * H + head.
# Memory layouts implied by the pointer arithmetic below:
#   k:         token-major, row stride Hg * K   -> [..., T, Hg, K]
#   beta, g:   token-major, row stride H        -> [..., T, H]
#   A:         token-major, row stride H * BT   -> [..., T, H, BT]
# IS_VARLEN / USE_G are derived by the heuristics from whether cu_seqlens /
# g_cumsum were passed as None.
@triton.heuristics({
    'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
    'USE_G': lambda args: args['g_cumsum'] is not None
})
# Autotuning over BK is disabled; the launcher passes a fixed BK instead.
# @triton.autotune(
#     configs=[
#         triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
#         for BK in [32, 64, 128] for num_warps in [2, 4, 8]
#         for num_stages in [2, 3, 4]
#     ],
#     key=['H', 'K', 'BT', 'IS_VARLEN'],
# )
# T is excluded from Triton's value specialization (it changes per call).
@triton.jit(do_not_specialize=['T'])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g_cumsum,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices is read as (sequence index, chunk-within-sequence)
        # pairs; cu_seqlens then gives that sequence's token range [bos, eos).
        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
            tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
        bos, eos = tl.load(cu_seqlens + i_n).to(
            tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
        T = eos - bos
    else:
        # Fixed-length batch: sequence i_b occupies tokens [i_b * T, i_b * T + T).
        bos, eos = i_b * T, i_b * T + T
    # Within-sequence positions of the BT rows handled by this program.
    o_t = i_t * BT + tl.arange(0, BT)

    # Per-row beta for this head; rows past T are guarded by boundary_check.
    p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
                               (i_t * BT, ), (BT, ), (0, ))
    b_beta = tl.load(p_beta, boundary_check=(0, ))

    # Accumulate (beta * K) @ K^T over the key dimension in BK-wide tiles.
    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        # Head i_h reads key head i_h // (H // Hg): each group of H // Hg
        # consecutive heads shares one of the Hg key heads.
        p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
                                (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
                                (1, 0))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        # Cast the scaled keys back to k's dtype so both tl.dot operands match.
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    if USE_G:
        # Pairwise gate factor exp(g[i] - g[j]) from the cumulative gate sums.
        p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
                                (i_t * BT, ), (BT, ), (0, ))
        b_g = tl.load(p_g, boundary_check=(0, ))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * tl.exp(b_g_diff)  # uses Triton's tl.exp directly instead of vLLM's `exp` helper imported from .op

    # Keep only the strictly lower-triangular part (row position > column
    # position). No separate `o_t < T` validity mask is needed: rows with
    # o_t >= T fall outside the store's boundary_check and are never written,
    # and for any stored row i, every column j with o_t[j] >= T has
    # o_t[j] > o_t[i], so it is already zeroed by this triangular mask.
    b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
    # Write this head's [BT rows, BT cols] tile; row stride is BT * H.
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
                            (i_t * BT, 0), (BT, BT), (1, 0))
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
def chunk_scaled_dot_kkt_fwd(
        k: torch.Tensor,
        beta: torch.Tensor,
        g_cumsum: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.LongTensor] = None,
        chunk_size: int = 64,
        output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
    r"""
    Compute the chunk-local, strictly lower-triangular beta * K * K^T.

    Within each chunk of `chunk_size` tokens and for each head, entry
    `[i, j]` with `i > j` is `beta_i * <k_i, k_j>`, further multiplied by
    `exp(g_cumsum_i - g_cumsum_j)` when `g_cumsum` is given. Entries on or
    above the diagonal are zero.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, Hg, K]`. `Hg` may be smaller than
            the number of heads `H` of `beta`; each group of `H // Hg`
            consecutive heads shares one key head. `Hg` must divide `H`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g_cumsum (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`.
            Default: None
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.

    Raises:
        ValueError: If the number of heads in `beta` is not a multiple of the
            number of key heads in `k`.
    """
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    # The kernel maps head i_h to key head i_h // (H // Hg). That mapping only
    # stays inside k's Hg heads (and H // Hg is only non-zero) when Hg divides
    # H; otherwise it silently reads other tokens' keys or divides by zero.
    if Hg <= 0 or H % Hg != 0:
        raise ValueError(
            f"Number of heads in beta ({H}) must be a multiple of the number "
            f"of key heads in k ({Hg}).")

    BT = chunk_size
    # NOTE(review): in varlen mode B is presumably 1 (packed sequences);
    # this is not checked here.
    # The kernel reads chunk_indices as (sequence, chunk) pairs.
    chunk_indices = prepare_chunk_indices(
        cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    # One program per (chunk, batch * head). BK is passed explicitly because
    # the kernel's autotuner is commented out.
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        beta=beta,
        g_cumsum=g_cumsum,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=64,
    )
    return A