251 lines
9.5 KiB
Python
251 lines
9.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
|
from .op import exp
|
|
from .utils import is_nvidia_hopper, use_cuda_graph
|
|
|
|
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
|
|
|
|
|
@triton.heuristics(
|
|
{
|
|
"USE_G": lambda args: args["g"] is not None,
|
|
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
|
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
|
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
|
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
|
}
|
|
)
|
|
@triton.jit(do_not_specialize=["T"])
|
|
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
|
k,
|
|
v,
|
|
w,
|
|
v_new,
|
|
g,
|
|
h,
|
|
h0,
|
|
ht,
|
|
cu_seqlens,
|
|
chunk_offsets,
|
|
T,
|
|
H: tl.constexpr,
|
|
Hg: tl.constexpr,
|
|
K: tl.constexpr,
|
|
V: tl.constexpr,
|
|
BT: tl.constexpr,
|
|
BV: tl.constexpr,
|
|
USE_G: tl.constexpr,
|
|
USE_INITIAL_STATE: tl.constexpr,
|
|
STORE_FINAL_STATE: tl.constexpr,
|
|
SAVE_NEW_VALUE: tl.constexpr,
|
|
IS_VARLEN: tl.constexpr,
|
|
):
|
|
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
|
i_n, i_h = i_nh // H, i_nh % H
|
|
|
|
if IS_VARLEN:
|
|
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
|
T = eos - bos
|
|
NT = tl.cdiv(T, BT)
|
|
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
|
else:
|
|
bos, eos = i_n * T, i_n * T + T
|
|
NT = tl.cdiv(T, BT)
|
|
boh = i_n * NT
|
|
|
|
# [BK, BV]
|
|
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
|
if K > 64:
|
|
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
|
if K > 128:
|
|
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
|
if K > 192:
|
|
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
|
|
|
# calculate offset
|
|
h += (boh * H + i_h) * K * V
|
|
v += (bos * H + i_h) * V
|
|
k += (bos * Hg + i_h // (H // Hg)) * K
|
|
w += (bos * H + i_h) * K
|
|
if SAVE_NEW_VALUE:
|
|
v_new += (bos * H + i_h) * V
|
|
stride_v = H * V
|
|
stride_h = H * K * V
|
|
stride_k = Hg * K
|
|
stride_w = H * K
|
|
if USE_INITIAL_STATE:
|
|
h0 = h0 + i_nh * K * V
|
|
if STORE_FINAL_STATE:
|
|
ht = ht + i_nh * K * V
|
|
|
|
# load initial state
|
|
if USE_INITIAL_STATE:
|
|
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
|
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
|
if K > 64:
|
|
p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
|
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
|
if K > 128:
|
|
p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
|
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
|
if K > 192:
|
|
p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
|
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
|
|
|
# main recurrence
|
|
for i_t in range(NT):
|
|
p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
|
if K > 64:
|
|
p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
|
if K > 128:
|
|
p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
|
if K > 192:
|
|
p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
|
p_v_new = (
|
|
tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
|
if SAVE_NEW_VALUE
|
|
else None
|
|
)
|
|
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
|
|
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
|
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
|
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
|
|
if K > 64:
|
|
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
|
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
|
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
|
|
if K > 128:
|
|
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
|
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
|
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
|
|
if K > 192:
|
|
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
|
|
b_w = tl.load(p_w, boundary_check=(0, 1))
|
|
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
|
|
|
|
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
|
|
|
|
if SAVE_NEW_VALUE:
|
|
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
|
tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
if USE_G:
|
|
m_t = (i_t * BT + tl.arange(0, BT)) < T
|
|
last_idx = min((i_t + 1) * BT, T) - 1
|
|
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
|
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
|
b_g = tl.load(p_g, boundary_check=(0,))
|
|
b_v_new = b_v_new * tl.where(m_t, tl.exp(b_g_last - b_g), 0)[:, None]
|
|
b_g_last = tl.exp(b_g_last)
|
|
b_h1 = b_h1 * b_g_last
|
|
if K > 64:
|
|
b_h2 = b_h2 * b_g_last
|
|
if K > 128:
|
|
b_h3 = b_h3 * b_g_last
|
|
if K > 192:
|
|
b_h4 = b_h4 * b_g_last
|
|
b_v_new = b_v_new.to(k.dtype.element_ty)
|
|
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
|
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
|
b_h1 += tl.dot(b_k, b_v_new)
|
|
if K > 64:
|
|
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
|
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
|
b_h2 += tl.dot(b_k, b_v_new)
|
|
if K > 128:
|
|
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
|
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
|
b_h3 += tl.dot(b_k, b_v_new)
|
|
if K > 192:
|
|
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
|
|
b_k = tl.load(p_k, boundary_check=(0, 1))
|
|
b_h4 += tl.dot(b_k, b_v_new)
|
|
|
|
# epilogue
|
|
if STORE_FINAL_STATE:
|
|
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
|
if K > 64:
|
|
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
|
if K > 128:
|
|
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
|
if K > 192:
|
|
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
|
|
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
def chunk_gated_delta_rule_fwd_h(
|
|
k: torch.Tensor,
|
|
w: torch.Tensor,
|
|
u: torch.Tensor,
|
|
g: Optional[torch.Tensor] = None,
|
|
initial_state: Optional[torch.Tensor] = None,
|
|
output_final_state: bool = False,
|
|
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
|
save_new_value: bool = True,
|
|
cu_seqlens: Optional[torch.LongTensor] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
|
H = u.shape[-2]
|
|
BT = chunk_size
|
|
|
|
chunk_indices = prepare_chunk_indices(
|
|
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
|
# N: the actual number of sequences in the batch with either equal or variable lengths
|
|
if cu_seqlens is None:
|
|
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
|
else:
|
|
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
|
|
chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
|
|
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
|
|
|
h = k.new_empty(B, NT, H, K, V)
|
|
final_state = k.new_empty(
|
|
N, H, K, V, dtype=torch.float32) if output_final_state else None
|
|
|
|
v_new = torch.empty_like(u) if save_new_value else None
|
|
|
|
def grid(meta):
|
|
return (triton.cdiv(V, meta['BV']), N * H)
|
|
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
|
k=k,
|
|
v=u,
|
|
w=w,
|
|
v_new=v_new,
|
|
g=g,
|
|
h=h,
|
|
h0=initial_state,
|
|
ht=final_state,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_offsets=chunk_offsets,
|
|
T=T,
|
|
H=H,
|
|
Hg=Hg,
|
|
K=K,
|
|
V=V,
|
|
BT=BT,
|
|
BV=64,
|
|
)
|
|
return h, v_new, final_state |