# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py

# ruff: noqa: E501,SIM102

import torch

from vllm.triton_utils import tl, triton


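# Candidate tile configurations; Triton's autotuner benchmarks these per
# (chunk_size, K, IS_CAUSAL) key and caches the fastest one.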
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    out_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    seqlen,
    chunk_size: tl.constexpr,
    K: tl.constexpr,
    ngroups: tl.constexpr,
    stride_a_seqlen: tl.int64,
    stride_a_head: tl.int64,
    stride_ak: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_bk: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_outm: tl.int64,
    stride_outn: tl.constexpr,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
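    # Grid axis 1 enumerates (chunk, group) pairs; axis 0 enumerates the
    # (M, N) tiles of one chunk_size x chunk_size output matrix.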
    pid_ch = tl.program_id(axis=1).to(tl.int64)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
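    # In the causal case only the lower-triangular part of the output is
    # needed, so skip tiles that lie strictly above the diagonal.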
    if IS_CAUSAL:
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return

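    # Chunks have variable length; cu_chunk_seqlens holds their cumulative
    # start offsets into the packed sequence dimension.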
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)

    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

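    # Accumulate in fp32 regardless of the input dtype for numerical accuracy.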
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # compute a * b.T
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit)
            & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        ).to(dot_dtype)
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
            & (offs_n[None, :] < chunk_size_limit),
            other=0.0,
        ).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

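    # Rows/columns past the actual chunk length were computed from zero-padded
    # loads, so the full chunk_size x chunk_size tile can be stored safely.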
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    out = acc.to(out_ptr.dtype.element_ty)
    out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
    tl.store(
        out_ptrs,
        out,
        mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
    )


def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
    """
    Argument:
        a: (seqlen, ngroups, k)
        b: (seqlen, ngroups, k)
        chunk_size: int
        cu_chunk_seqlens: (nchunks + 1,) cumulative chunk start offsets
        causal: if True, then out[i, j] for i < j will be arbitrary; only out[i, j]
            for i >= j is guaranteed to be correct.
        output_dtype: dtype of the output tensor; defaults to a.dtype.
    Return:
        out: (nchunks, ngroups, chunk_size, chunk_size)
    """
    seqlen, ngroups, k = a.shape
    assert b.shape == a.shape
    # The kernel assumes each input is contiguous in at least one of its
    # seqlen/k dimensions.
    if a.stride(-1) != 1 and a.stride(0) != 1:
        a = a.contiguous()
    if b.stride(-1) != 1 and b.stride(0) != 1:
        b = b.contiguous()

    nchunks = len(cu_chunk_seqlens) - 1
    # Allocates output.
    out_dtype = a.dtype if output_dtype is None else output_dtype
    out = torch.empty(
        (nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype
    )
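    # Pick a common compute dtype for tl.dot: bf16/fp16 if either input uses
    # it, fp32 otherwise.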
    dot_dtype = (
        tl.bfloat16
        if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
        else (
            tl.float16
            if a.dtype == torch.float16 or b.dtype == torch.float16
            else tl.float32
        )
    )
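    # One program per (M, N) tile per (chunk, group) pair; tile sizes come
    # from the autotuner, so the grid must be a function of META.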
    grid = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
        nchunks * ngroups,
    )
    with torch.cuda.device(a.device.index):
        _bmm_chunk_fwd_kernel[grid](
            a_ptr=a,
            b_ptr=b,
            out_ptr=out,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            seqlen=seqlen,
            chunk_size=chunk_size,
            K=k,
            ngroups=ngroups,
            stride_a_seqlen=a.stride(0),
            stride_a_head=a.stride(1),
            stride_ak=a.stride(2),
            stride_b_seqlen=b.stride(0),
            stride_b_head=b.stride(1),
            stride_bk=b.stride(2),
            stride_out_chunk=out.stride(0),
            stride_out_head=out.stride(1),
            stride_outm=out.stride(-2),
            stride_outn=out.stride(-1),
            IS_CAUSAL=causal,
            dot_dtype=dot_dtype,
        )
    return out
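

# Minimal usage sketch (illustrative, not part of the upstream module): shapes,
# dtypes, and chunk boundaries below are assumed values. Two chunks of lengths
# 5 and 3 are packed into a seqlen of 8; out[c, g] then holds the chunk-local
# a @ b.T, padded out to (chunk_size, chunk_size).
#
#     seqlen, ngroups, k, chunk_size = 8, 1, 16, 8
#     a = torch.randn(seqlen, ngroups, k, device="cuda", dtype=torch.float16)
#     b = torch.randn(seqlen, ngroups, k, device="cuda", dtype=torch.float16)
#     cu_chunk_seqlens = torch.tensor([0, 5, 8], device="cuda", dtype=torch.int32)
#     out = _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=True)
#     # out.shape == (2, 1, 8, 8); with causal=True only the lower-triangular
#     # part of each (chunk_size, chunk_size) tile is guaranteed correct.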