### What this PR does / why we need it?
Some parameters of the Triton kernels are unnecessarily declared with the
"constexpr" modifier. Every distinct value of a "constexpr" parameter
triggers a recompilation of the kernel, which significantly degrades model
performance. This PR therefore turns these parameters into ordinary runtime
arguments.
- vLLM version: v0.17.0
- vLLM main:
8b6325758c
Signed-off-by: HarpSealCC [844291270@qq.com](mailto:844291270@qq.com)
Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com>
Co-authored-by: l30072083 <liuchengzhuo1@h-partners.com>
245 lines
8.0 KiB
Python
245 lines
8.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
# mypy: ignore-errors
|
|
|
|
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp
|
|
|
|
# NOTE(review): this tuple is not referenced anywhere in the visible part of this
# module; its purpose is not evident from this file — confirm whether it is still needed.
_CONDITIONS = ("seq7168",)
|
|
|
|
|
|
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    h_update,
    T,
    H,
    Hg,
    K,
    V,
    BT: tl.constexpr,
    USE_G: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Chunk-wise forward recurrence of the hidden state for the gated delta rule.
    #
    # Each program along grid axis 1 handles one (sequence, head) pair and walks over
    # that sequence in chunks of BT time steps. For every chunk it
    #   1. stores the state at the start of the chunk into `h`,
    #   2. computes v_new = v - w @ state (written to `v_new` if SAVE_NEW_VALUE),
    #   3. if USE_G, scales v_new by safe_exp(g_last - g) and the state by exp(g_last),
    #   4. accumulates state += k_block @ v_new, where k_block is read as a (K, BT) tile.
    #
    # The state lives in two fp32 [128, 64] tiles covering value columns [0, 64) and
    # [64, 128). Rows >= K and columns >= V are masked out, so this kernel covers at
    # most K <= 128 and V <= 128; anything beyond that is never read or written.
    #
    # Memory layouts implied by the index arithmetic below:
    #   k      : time stride Hg * K, head stride K (head i_h reads k-head i_h // (H // Hg))
    #   v/v_new: time stride H * V,  head stride V
    #   w      : time stride H * K,  head stride K
    #   g      : head stride T (the launch-time T), time stride 1
    #   h      : [chunk, H, K, V] contiguous
    #   h0/ht  : [N * H, K, V] contiguous
    # `h_update` is accepted for signature compatibility but is not accessed here.
    #
    # T, H, Hg, K and V are runtime (non-specialised) arguments so that a change in
    # any of them does not trigger a recompilation; only BT and the heuristic flags
    # are compile-time constants.
    i_nh = tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    # Snapshot of the launch-time T: T is overwritten with the per-sequence length
    # in the IS_VARLEN branch, while `g` is still indexed with the original value.
    T_max = 1 * T
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
        # first chunk slot of this sequence inside `h`
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    stride_v = H * V
    stride_k = Hg * K
    stride_w = H * K

    b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32)
    b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32)
    # NOTE(review): no h_update tiles are created or written in this kernel.

    # column offsets of the two value tiles
    v_start1 = 0
    v_start2 = 64

    offs_k = tl.arange(0, 128)[:, None]
    offs_v1 = v_start1 + tl.arange(0, 64)[None, :]
    offs_v2 = v_start2 + tl.arange(0, 64)[None, :]
    mask_kv1 = (offs_k < K) & (offs_v1 < V)
    mask_kv2 = (offs_k < K) & (offs_v2 < V)

    # load initial state
    if USE_INITIAL_STATE:
        h0_ptr = h0 + i_nh * K * V
        ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
        b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1, other=0.0).to(tl.float32)

        ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
        b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2, other=0.0).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        # store the state as it is at the beginning of chunk i_t
        h_base = h + (boh + i_t) * H * K * V + i_h * K * V

        p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
        tl.store(p_h1_bv1, b_h1_bv1.to(p_h1_bv1.dtype.element_ty), boundary_check=(0, 1))

        p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
        tl.store(p_h1_bv2, b_h1_bv2.to(p_h1_bv2.dtype.element_ty), boundary_check=(0, 1))

        # w tile of this chunk: [BT, 128]
        offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
        offs_k_wv = tl.arange(0, 128)[None, :]
        mask_w = (offs_t_wv < T) & (offs_k_wv < K)

        w_base = w + bos * H * K + i_h * K
        ptr_w = w_base + offs_t_wv * stride_w + offs_k_wv * 1
        b_w = tl.load(ptr_w, mask=mask_w, other=0.0)

        # k tile of this chunk, read transposed as [128, BT]
        k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
        p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT), (128, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))

        v_new_base = v_new + bos * H * V + i_h * V

        # Only touch `g` when it was actually provided: USE_G is False exactly when
        # g is None, and dereferencing it unconditionally would break that case.
        if USE_G:
            # gate value at the last valid time step of this chunk
            last_idx = min((i_t + 1) * BT, T) - 1
            b_g_last = tl.load(g + bos + i_h * T_max + last_idx)

            offs_t = i_t * BT + tl.arange(0, BT)
            mask_t = offs_t < T
            g_ptr = g + bos + i_h * T_max
            b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0)

            b_g = safe_exp(b_g_last - b_g)
            b_g_last = tl.exp(b_g_last)

        offs_t_v = (i_t * BT + tl.arange(0, BT))[:, None]
        mask_v1 = (offs_t_v < T) & (offs_v1 < V)

        # ---- value tile 1: columns [0, 64) ----
        v_base = v + bos * H * V + i_h * V
        ptr_v1 = v_base + offs_t_v * stride_v + offs_v1 * 1
        b_v1 = tl.load(ptr_v1, mask=mask_v1, other=0.0)
        b_v_new1 = b_v1.to(tl.float32)
        b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))

        if SAVE_NEW_VALUE:
            p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start1), (BT, 64), (1, 0))
            tl.store(p_v_new1, b_v_new1.to(p_v_new1.dtype.element_ty), boundary_check=(0, 1))

        if USE_G:
            b_v_new1 = b_v_new1 * b_g[:, None]
            b_h1_bv1 = b_h1_bv1 * b_g_last

        b_v_new1 = b_v_new1.to(k.dtype.element_ty)
        b_h1_bv1 += tl.dot(b_k, b_v_new1)

        # ---- value tile 2: columns [64, 128) ----
        mask_v2 = (offs_t_v < T) & (offs_v2 < V)
        ptr_v2 = v_base + offs_t_v * stride_v + offs_v2 * 1
        b_v2 = tl.load(ptr_v2, mask=mask_v2, other=0.0)
        b_v_new2 = b_v2.to(tl.float32)
        b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))

        if SAVE_NEW_VALUE:
            p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1), (i_t * BT, v_start2), (BT, 64), (1, 0))
            tl.store(p_v_new2, b_v_new2.to(p_v_new2.dtype.element_ty), boundary_check=(0, 1))

        if USE_G:
            b_v_new2 = b_v_new2 * b_g[:, None]
            b_h1_bv2 = b_h1_bv2 * b_g_last

        b_v_new2 = b_v_new2.to(k.dtype.element_ty)
        b_h1_bv2 += tl.dot(b_k, b_v_new2)

    # epilogue
    if STORE_FINAL_STATE:
        ht_ptr = ht + i_nh * K * V

        p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1), (128, 64), (1, 0))
        tl.store(p_ht1_bv1, b_h1_bv1.to(p_ht1_bv1.dtype.element_ty), boundary_check=(0, 1))

        p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2), (128, 64), (1, 0))
        tl.store(p_ht1_bv2, b_h1_bv2.to(p_ht1_bv2.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.Tensor | None = None,
    chunk_offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Launch the chunked hidden-state recurrence of the gated delta rule.

    Args:
        k: 4-D key tensor; its shape is unpacked as ``(B, T, Hg, K)``.
        w: Tensor passed to the kernel as ``w``. NOTE(review): the kernel indexes
            it with a time stride of ``H * K`` — presumably ``(B, T, H, K)``; confirm.
        u: Value tensor; its last two dimensions are read as ``(H, V)``.
        g: Optional gate tensor. Dimensions 1 and 2 are swapped and the result made
            contiguous before launch. NOTE(review): presumably ``(B, T, H)`` on
            entry — confirm against the caller. When ``None`` no decay is applied.
        initial_state: Optional initial state, passed to the kernel as ``h0``.
        output_final_state: When true, a float32 ``(N, H, K, V)`` final state is
            allocated and filled by the kernel.
        chunk_size: Number of time steps per chunk (kernel ``BT``).
        save_new_value: When true, a tensor shaped like ``u`` is allocated and
            filled with the updated values.
        cu_seqlens: Optional cumulative sequence lengths; enables the
            variable-length path.
        chunk_indices: Optional precomputed chunk indices for ``cu_seqlens``;
            computed with ``prepare_chunk_indices`` when missing. Only its length
            is used here (as the total number of chunks).
        chunk_offsets: Optional precomputed per-sequence chunk offsets; computed
            with ``prepare_chunk_offsets`` when missing. Ignored when
            ``cu_seqlens`` is ``None``.

    Returns:
        A tuple ``(h, v_new, final_state)``: ``h`` has shape ``(B, NT, H, K, V)``
        and the dtype of ``k``; ``v_new`` is ``None`` unless ``save_new_value``;
        ``final_state`` is ``None`` unless ``output_final_state``.

    Raises:
        AssertionError: If ``K`` or ``V`` exceeds what the kernel's fixed tiles cover.
    """
    # This kernel is slightly different from fla to support Q/K with different head numbers.
    # In fla, Q/K always have the same head number, so Hg is always equal to H.
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size

    if cu_seqlens is not None and chunk_indices is None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        if chunk_offsets is None:
            chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
        N, NT = len(cu_seqlens) - 1, len(chunk_indices)

    # The kernel keeps the state in fixed 128-row tiles and two 64-column value
    # tiles; larger dimensions would be silently truncated and leave parts of `h`
    # uninitialised, so reject them up front.
    assert K <= 128, "current kernel does not support head dimension larger than 128."
    assert V <= 128, "current kernel does not support value head dimension larger than 128."

    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u) if save_new_value else None
    # The kernel reads g with the time axis innermost; g is optional, so only
    # re-layout it when it was actually provided.
    if g is not None:
        g = g.transpose(1, 2).contiguous()

    def grid(meta):
        # one program per (sequence, head) pair on axis 1
        return (1, N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        # The kernel never reads or writes h_update, so no (B, NT, H, K, K)
        # buffer is allocated for it.
        h_update=None,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        num_warps=4,
        num_stages=2,
    )
    return h, v_new, final_state
|