kernel: support slightly faster merge_state_v2 cuda kernel (#5381)
This commit is contained in:
@@ -170,6 +170,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
set(SOURCES
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/attention/cascade.cu"
|
||||
"csrc/attention/merge_attn_states.cu"
|
||||
"csrc/attention/cutlass_mla_kernel.cu"
|
||||
"csrc/attention/lightning_attention_decode_kernel.cu"
|
||||
"csrc/elementwise/activation.cu"
|
||||
|
||||
201
sgl-kernel/csrc/attention/merge_attn_states.cu
Normal file
201
sgl-kernel/csrc/attention/merge_attn_states.cu
Normal file
@@ -0,0 +1,201 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <optional>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
// Helper functions to convert between different data types
|
||||
// (float, half, bfloat16) for the merge attention states kernel.
|
||||
inline __device__ float to_float(float u) {
|
||||
return u;
|
||||
}
|
||||
inline __device__ float to_float(half u) {
|
||||
return __half2float(u);
|
||||
}
|
||||
inline __device__ float to_float(__nv_bfloat16 u) {
|
||||
return __bfloat162float(u);
|
||||
}
|
||||
inline __device__ void from_float(float& d, float s) {
|
||||
d = s;
|
||||
}
|
||||
inline __device__ void from_float(half& d, float s) {
|
||||
d = __float2half(s);
|
||||
}
|
||||
inline __device__ void from_float(__nv_bfloat16& d, float s) {
|
||||
d = __float2bfloat16(s);
|
||||
}
|
||||
|
||||
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
|
||||
template <typename scalar_t, const uint NUM_THREADS>
|
||||
__global__ void merge_attn_states_kernel(
|
||||
scalar_t* output,
|
||||
float* output_lse,
|
||||
const scalar_t* prefix_output,
|
||||
const float* prefix_lse,
|
||||
const scalar_t* suffix_output,
|
||||
const float* suffix_lse,
|
||||
const uint num_tokens,
|
||||
const uint num_heads,
|
||||
const uint head_size) {
|
||||
using pack_128b_t = uint4;
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
|
||||
const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
|
||||
const uint token_head_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
if (global_idx >= token_head_threads) return;
|
||||
|
||||
// global_idx -> token_idx + head_idx + pack_idx
|
||||
const uint token_head_idx = global_idx / threads_per_head;
|
||||
const uint pack_idx = global_idx % threads_per_head;
|
||||
|
||||
const uint token_idx = token_head_idx / num_heads;
|
||||
const uint head_idx = token_head_idx % num_heads;
|
||||
|
||||
const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
|
||||
const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
|
||||
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
|
||||
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
|
||||
scalar_t* output_head_ptr = output + head_offset;
|
||||
|
||||
// float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
|
||||
// float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
|
||||
float p_lse = prefix_lse[token_idx * num_heads + head_idx];
|
||||
float s_lse = suffix_lse[token_idx * num_heads + head_idx];
|
||||
p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
|
||||
s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
|
||||
|
||||
const float max_lse = fmaxf(p_lse, s_lse);
|
||||
p_lse = p_lse - max_lse;
|
||||
s_lse = s_lse - max_lse;
|
||||
const float p_se = expf(p_lse);
|
||||
const float s_se = expf(s_lse);
|
||||
const float out_se = p_se + s_se;
|
||||
const float p_scale = p_se / out_se;
|
||||
const float s_scale = s_se / out_se;
|
||||
|
||||
if (pack_offset < head_size) {
|
||||
// Pack 128b load
|
||||
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
|
||||
pack_128b_t o_out_pack;
|
||||
|
||||
#pragma unroll
|
||||
for (uint i = 0; i < pack_size; ++i) {
|
||||
// Always use float for FMA to keep high precision.
|
||||
// half(uint16_t), bfloat16, float -> float.
|
||||
const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
|
||||
const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
|
||||
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
|
||||
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
|
||||
// float -> half(uint16_t), bfloat16, float.
|
||||
from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
|
||||
}
|
||||
|
||||
// Pack 128b storage
|
||||
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
|
||||
}
|
||||
// We only need to write to output_lse once per head.
|
||||
if (output_lse != nullptr && pack_idx == 0) {
|
||||
float out_lse = logf(out_se) + max_lse;
|
||||
output_lse[token_idx * num_heads + head_idx] = out_lse;
|
||||
}
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
// the output data type. The FN is a macro that calls a function with
|
||||
// template<typename scalar_t>.
|
||||
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
|
||||
{ \
|
||||
if (scalar_dtype == at::ScalarType::Float) { \
|
||||
fn(float); \
|
||||
} else if (scalar_dtype == at::ScalarType::Half) { \
|
||||
fn(half); \
|
||||
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
|
||||
fn(__nv_bfloat16); \
|
||||
} else { \
|
||||
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
|
||||
{ \
|
||||
merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
|
||||
reinterpret_cast<scalar_t*>(output.data_ptr()), \
|
||||
reinterpret_cast<float*>(output_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(prefix_lse.data_ptr()), \
|
||||
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()), \
|
||||
reinterpret_cast<float*>(suffix_lse.data_ptr()), \
|
||||
num_tokens, \
|
||||
num_heads, \
|
||||
head_size); \
|
||||
}
|
||||
|
||||
/*@brief Merges the attention states from prefix and suffix
|
||||
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
|
||||
*
|
||||
* @param output [n,h,d] The output tensor to store the merged attention states.
|
||||
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
|
||||
* @param prefix_output [n,h,d] The prefix attention states.
|
||||
* @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
|
||||
* states.
|
||||
* @param suffix_output [n,h,d] The suffix attention states.
|
||||
* @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
|
||||
* states.
|
||||
*/
|
||||
template <typename scalar_t>
|
||||
void merge_attn_states_launcher(
|
||||
const at::Tensor& prefix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
const at::Tensor& prefix_lse, // [NUM_TOKENS, NUM_HEADS]
|
||||
const at::Tensor& suffix_output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
const at::Tensor& suffix_lse, // [NUM_TOKENS, NUM_HEADS]
|
||||
at::Tensor& output, // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
at::Tensor& output_lse // [NUM_TOKENS, NUM_HEADS]
|
||||
) {
|
||||
constexpr uint NUM_THREADS = 128;
|
||||
const uint num_tokens = output.size(0);
|
||||
const uint num_heads = output.size(1);
|
||||
const uint head_size = output.size(2);
|
||||
const uint pack_size = 16 / sizeof(scalar_t);
|
||||
TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size);
|
||||
// Process one pack elements per thread. for float, the
|
||||
// pack_size is 4 for half/bf16, the pack_size is 8.
|
||||
const uint threads_per_head = head_size / pack_size;
|
||||
const uint total_threads = num_tokens * num_heads * threads_per_head;
|
||||
|
||||
dim3 block(NUM_THREADS);
|
||||
dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
|
||||
|
||||
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
|
||||
}
|
||||
|
||||
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
|
||||
{ merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
|
||||
|
||||
void merge_state_v2(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
|
||||
// Input tensors must be contiguous
|
||||
CHECK_INPUT(v_a); // v_a prefix_output (seq_len, num_heads, head_dim)
|
||||
CHECK_INPUT(s_a); // s_a prefix_lse (seq_len, num_heads)
|
||||
CHECK_INPUT(v_b); // v_b suffix_output (seq_len, num_heads, head_dim)
|
||||
CHECK_INPUT(s_b); // s_b suffix_lse (seq_len, num_heads)
|
||||
// v_merged output (seq_len, num_heads, head_dim)
|
||||
// s_merged output_lse (seq_len, num_heads)
|
||||
auto device = v_a.device();
|
||||
CHECK_EQ(s_a.device(), device);
|
||||
CHECK_EQ(v_b.device(), device);
|
||||
CHECK_EQ(s_b.device(), device);
|
||||
CHECK_DIM(3, v_a);
|
||||
CHECK_DIM(2, s_a);
|
||||
CHECK_DIM(3, v_b);
|
||||
CHECK_DIM(2, s_b);
|
||||
CHECK_SHAPE(v_a, v_b);
|
||||
CHECK_SHAPE(s_a, s_b);
|
||||
CHECK_EQ(v_a.size(0), s_a.size(0));
|
||||
CHECK_EQ(v_a.size(1), s_b.size(1));
|
||||
DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
|
||||
}
|
||||
@@ -47,6 +47,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
||||
m.impl("merge_state", torch::kCUDA, &merge_state);
|
||||
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
||||
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
|
||||
m.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
||||
"page_table, Tensor workspace) -> ()");
|
||||
|
||||
@@ -89,6 +89,8 @@ void lightning_attention_decode(
|
||||
torch::Tensor new_kv);
|
||||
void merge_state(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
||||
void merge_state_v2(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
||||
void cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope_and_q_pe,
|
||||
|
||||
@@ -16,6 +16,7 @@ from sgl_kernel.attention import (
|
||||
cutlass_mla_get_workspace_size,
|
||||
lightning_attention_decode,
|
||||
merge_state,
|
||||
merge_state_v2,
|
||||
)
|
||||
from sgl_kernel.elementwise import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
|
||||
|
||||
def merge_state(
|
||||
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
|
||||
v_a: torch.Tensor,
|
||||
s_a: torch.Tensor,
|
||||
v_b: torch.Tensor,
|
||||
s_b: torch.Tensor,
|
||||
v_merged: Optional[torch.Tensor] = None,
|
||||
s_merged: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
s_a = s_a.to(torch.float32)
|
||||
s_b = s_b.to(torch.float32)
|
||||
v_merged = torch.empty_like(v_a)
|
||||
s_merged = torch.empty_like(s_a)
|
||||
# Avoid creating new tensors if they are already provided
|
||||
if v_merged is None:
|
||||
v_merged = torch.empty_like(v_a)
|
||||
if s_merged is None:
|
||||
s_merged = torch.empty_like(s_a)
|
||||
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
|
||||
return v_merged, s_merged
|
||||
|
||||
|
||||
def merge_state_v2(
|
||||
v_a: torch.Tensor,
|
||||
s_a: torch.Tensor,
|
||||
v_b: torch.Tensor,
|
||||
s_b: torch.Tensor,
|
||||
v_merged: Optional[torch.Tensor] = None,
|
||||
s_merged: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
s_a = s_a.to(torch.float32)
|
||||
s_b = s_b.to(torch.float32)
|
||||
# TODO(DefTruth): Currently, the custom merge_attn_states kernel
|
||||
# does not support the FP8 data type and non - CUDA devices.
|
||||
# It may be necessary to fall back to using the Triton kernel.
|
||||
|
||||
# Avoid creating new tensors if they are already provided
|
||||
if v_merged is None:
|
||||
v_merged = torch.empty_like(v_a)
|
||||
if s_merged is None:
|
||||
s_merged = torch.empty_like(s_a)
|
||||
torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
|
||||
return v_merged, s_merged
|
||||
|
||||
|
||||
def cutlass_mla_decode(
|
||||
q_nope_and_q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
396
sgl-kernel/tests/test_merge_state_v2.py
Normal file
396
sgl-kernel/tests/test_merge_state_v2.py
Normal file
@@ -0,0 +1,396 @@
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import merge_state, merge_state_v2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def merge_state_kernel(
|
||||
output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
|
||||
output_lse, # [NUM_TOKENS, NUM_HEADS] s_merged
|
||||
prefix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
|
||||
prefix_lse, # [NUM_TOKENS, NUM_HEADS] s_a
|
||||
suffix_output, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
|
||||
suffix_lse, # [NUM_TOKENS, NUM_HEADS] s_b
|
||||
HEAD_SIZE: tl.constexpr,
|
||||
PADDED_HEAD_SIZE: tl.constexpr,
|
||||
OUTPUT_LSE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
num_tokens = tl.num_programs(0)
|
||||
head_idx = tl.program_id(1)
|
||||
num_heads = tl.num_programs(1)
|
||||
|
||||
p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
|
||||
s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
|
||||
p_lse = float("-inf") if p_lse == float("inf") else p_lse
|
||||
s_lse = float("-inf") if s_lse == float("inf") else s_lse
|
||||
|
||||
max_lse = tl.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
out_se = tl.exp(p_lse) + tl.exp(s_lse)
|
||||
|
||||
if OUTPUT_LSE:
|
||||
out_lse = tl.log(out_se) + max_lse
|
||||
tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
|
||||
|
||||
head_arange = tl.arange(0, PADDED_HEAD_SIZE)
|
||||
head_mask = head_arange < HEAD_SIZE
|
||||
p_out = tl.load(
|
||||
prefix_output
|
||||
+ token_idx * num_heads * HEAD_SIZE
|
||||
+ head_idx * HEAD_SIZE
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
s_out = tl.load(
|
||||
suffix_output
|
||||
+ token_idx * num_heads * HEAD_SIZE
|
||||
+ head_idx * HEAD_SIZE
|
||||
+ head_arange,
|
||||
mask=head_mask,
|
||||
)
|
||||
|
||||
p_scale = tl.exp(p_lse) / out_se
|
||||
s_scale = tl.exp(s_lse) / out_se
|
||||
out = p_out * p_scale + s_out * s_scale
|
||||
tl.store(
|
||||
output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
|
||||
out,
|
||||
mask=head_mask,
|
||||
)
|
||||
|
||||
|
||||
def merge_state_triton(
|
||||
prefix_output: torch.Tensor,
|
||||
prefix_lse: torch.Tensor,
|
||||
suffix_output: torch.Tensor,
|
||||
suffix_lse: torch.Tensor,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_lse: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
num_tokens = output.shape[0]
|
||||
num_query_heads = output.shape[1]
|
||||
head_size = output.shape[2]
|
||||
padded_head_size = triton.next_power_of_2(head_size)
|
||||
# Avoid creating new tensors if they are already provided
|
||||
if output is None:
|
||||
output = torch.empty_like(prefix_output)
|
||||
if output_lse is None:
|
||||
output_lse = torch.empty_like(prefix_lse)
|
||||
|
||||
merge_state_kernel[(num_tokens, num_query_heads)](
|
||||
output,
|
||||
output_lse,
|
||||
prefix_output,
|
||||
prefix_lse,
|
||||
suffix_output,
|
||||
suffix_lse,
|
||||
head_size,
|
||||
padded_head_size,
|
||||
output_lse is not None,
|
||||
)
|
||||
return output, output_lse
|
||||
|
||||
|
||||
# Naive PyTorch Implements of Merge Attn States
|
||||
def merge_state_torch(
|
||||
prefix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
prefix_lse: torch.Tensor, # [NUM_TOKENS, NUM_HEADS]
|
||||
suffix_output: torch.Tensor, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
suffix_lse: torch.Tensor, # [NUM_TOKENS, NUM_HEADS]
|
||||
output: Optional[torch.Tensor] = None, # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
|
||||
output_lse: Optional[torch.Tensor] = None, # [NUM_TOKENS, NUM_HEADS]
|
||||
):
|
||||
# Avoid creating new tensors if they are already provided
|
||||
if output is None:
|
||||
output = torch.empty_like(prefix_output)
|
||||
if output_lse is None:
|
||||
output_lse = torch.empty_like(prefix_lse)
|
||||
p_lse = prefix_lse
|
||||
s_lse = suffix_lse
|
||||
# inf -> -inf
|
||||
p_lse[p_lse == torch.inf] = -torch.inf
|
||||
s_lse[s_lse == torch.inf] = -torch.inf
|
||||
# max_lse [NUM_HEADS, NUM_TOKENS]
|
||||
max_lse = torch.maximum(p_lse, s_lse)
|
||||
p_lse = p_lse - max_lse
|
||||
s_lse = s_lse - max_lse
|
||||
p_lse_exp = torch.exp(p_lse)
|
||||
s_lse_exp = torch.exp(s_lse)
|
||||
out_se = p_lse_exp + s_lse_exp
|
||||
if output_lse is not None:
|
||||
output_lse = torch.log(out_se) + max_lse
|
||||
p_scale = p_lse_exp / out_se
|
||||
s_scale = s_lse_exp / out_se
|
||||
p_scale = p_scale.unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
s_scale = s_scale.unsqueeze(2) # [NUM_TOKENS, NUM_HEADS, 1]
|
||||
output = prefix_output * p_scale + suffix_output * s_scale
|
||||
return output, output_lse
|
||||
|
||||
|
||||
NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536]
|
||||
NUM_QUERY_HEADS = [8, 16, 32]
|
||||
HEAD_SIZES = [32, 48, 64, 128, 256]
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
|
||||
all_case_info: list[tuple] = []
|
||||
|
||||
|
||||
def generate_markdown_table():
|
||||
global all_case_info
|
||||
table_header = (
|
||||
"| tokens | heads | headsize | dtype "
|
||||
"| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|"
|
||||
)
|
||||
table_separator = (
|
||||
"| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |"
|
||||
)
|
||||
|
||||
def shortly_dtype(dtype: torch.dtype) -> str:
|
||||
return str(dtype).removeprefix("torch.")
|
||||
|
||||
def shortly_device(device: str) -> str:
|
||||
return device.removeprefix("NVIDIA").strip()
|
||||
|
||||
print(table_header)
|
||||
print(table_separator)
|
||||
for info in all_case_info:
|
||||
(
|
||||
num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype,
|
||||
device,
|
||||
time_torch,
|
||||
time_triton,
|
||||
time_v1,
|
||||
time_v2,
|
||||
) = info
|
||||
dtype = shortly_dtype(dtype)
|
||||
device = shortly_device(device)
|
||||
improved_triton = time_triton / time_v2
|
||||
improved_v1 = time_v1 / time_v2
|
||||
print(
|
||||
f"| {num_tokens} | {num_heads} | {head_size} "
|
||||
f"| {dtype} | {device} | {time_torch:.4f}ms "
|
||||
f"| {time_triton:.4f}ms "
|
||||
f"| {time_v1:.4f}ms "
|
||||
f"| {time_v2:.4f}ms "
|
||||
f"| {improved_triton:.4f}x "
|
||||
f"| {improved_v1:.4f}x |"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
|
||||
@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("output_dtype", DTYPES)
|
||||
@torch.inference_mode()
|
||||
def test_merge_attn_states(
|
||||
num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
|
||||
):
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip(
|
||||
"Currently only support compare triton merge_attn_states "
|
||||
"with custom cuda merge_attn_states kernel"
|
||||
)
|
||||
|
||||
NUM_TOKENS = num_tokens
|
||||
NUM_HEADS = num_query_heads
|
||||
HEAD_SIZE = head_size
|
||||
|
||||
print(
|
||||
f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
|
||||
f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
|
||||
f"Device: {torch.cuda.get_device_name()}"
|
||||
)
|
||||
|
||||
# prefix_lse and suffix_lse contain inf and normal values
|
||||
prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
|
||||
suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
|
||||
|
||||
# Generate boolean masks
|
||||
mask_prefix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
|
||||
mask_suffix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
|
||||
# Ensure that the same position is not True at the same time
|
||||
combined_mask = torch.logical_and(mask_prefix, mask_suffix)
|
||||
mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
|
||||
mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
|
||||
|
||||
prefix_lse[mask_prefix] = float("inf")
|
||||
suffix_lse[mask_suffix] = float("inf")
|
||||
|
||||
# Other input tensors (need to be initialized but
|
||||
# no actual calculation needed)
|
||||
output = torch.zeros(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
output_lse = torch.zeros(
|
||||
(NUM_TOKENS, NUM_HEADS), dtype=torch.float32, device="cuda"
|
||||
)
|
||||
prefix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
suffix_output = torch.randn(
|
||||
(NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
|
||||
)
|
||||
|
||||
warmup_times = 2
|
||||
repeat_times = 20
|
||||
|
||||
def perf_kernel_fn(
|
||||
output_fn: torch.Tensor,
|
||||
output_lse_fn: torch.Tensor,
|
||||
kernel_fn: callable,
|
||||
fn_type: str = "torch",
|
||||
):
|
||||
# Avoid inplace inf -> -inf, we have to use prefix_lse
|
||||
# and suffix_lse for other kernel.
|
||||
if fn_type == "torch":
|
||||
prefix_lse_ = prefix_lse.clone()
|
||||
suffix_lse_ = suffix_lse.clone()
|
||||
else:
|
||||
prefix_lse_ = prefix_lse
|
||||
suffix_lse_ = suffix_lse
|
||||
|
||||
if fn_type == "cuda_v1":
|
||||
# merge_state v1 kernel not support float32
|
||||
if output_dtype not in (torch.half, torch.bfloat16):
|
||||
return 0, output_fn, output_lse_fn
|
||||
|
||||
total_time = 0
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
try:
|
||||
for _ in range(warmup_times):
|
||||
output_fn, output_lse_fn = kernel_fn(
|
||||
prefix_output,
|
||||
prefix_lse_,
|
||||
suffix_output,
|
||||
suffix_lse_,
|
||||
output_fn,
|
||||
output_lse_fn,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for _ in range(repeat_times):
|
||||
start.record()
|
||||
output_fn, output_lse_fn = kernel_fn(
|
||||
prefix_output,
|
||||
prefix_lse_,
|
||||
suffix_output,
|
||||
suffix_lse_,
|
||||
output_fn,
|
||||
output_lse_fn,
|
||||
)
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
total_time += start.elapsed_time(end)
|
||||
|
||||
avg_time = total_time / repeat_times
|
||||
return avg_time, output_fn, output_lse_fn
|
||||
except Exception as e:
|
||||
return 0, output_fn, output_lse_fn
|
||||
|
||||
# 0. Run the Torch kernel
|
||||
output_torch = output.clone()
|
||||
output_lse_torch = output_lse.clone()
|
||||
time_torch, output_torch, output_lse_torch = perf_kernel_fn(
|
||||
output_torch, output_lse_torch, merge_state_torch, fn_type="torch"
|
||||
)
|
||||
|
||||
# 1. Run the Triton kernel
|
||||
output_ref_triton = output.clone()
|
||||
output_lse_ref_triton = output_lse.clone()
|
||||
time_triton, output_ref_triton, output_lse_ref_triton = perf_kernel_fn(
|
||||
output_ref_triton,
|
||||
output_lse_ref_triton,
|
||||
merge_state_triton,
|
||||
fn_type="triton",
|
||||
)
|
||||
|
||||
# 2. Run the merge_state V1 kernel
|
||||
output_v1 = output.clone()
|
||||
output_lse_v1 = output_lse.clone()
|
||||
time_v1, output_v1, output_lse_v1 = perf_kernel_fn(
|
||||
output_v1, output_lse_v1, merge_state, fn_type="cuda_v1"
|
||||
)
|
||||
|
||||
# 3. Run the merge_state V2 kernel
|
||||
output_v2 = output.clone()
|
||||
output_lse_v2 = output_lse.clone()
|
||||
time_v2, output_v2, output_lse_v2 = perf_kernel_fn(
|
||||
output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2"
|
||||
)
|
||||
|
||||
# 4. Performance compare
|
||||
improved = time_triton / time_v2
|
||||
print(f" Torch time: {time_torch:.6f}ms")
|
||||
print(f" Triton time: {time_triton:.6f}ms")
|
||||
print(f"CUDA v1 time: {time_v1:.6f}ms")
|
||||
print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x")
|
||||
print("-" * 100)
|
||||
|
||||
# 5. Correctness compare
|
||||
# Liger Kernel: Efficient Triton Kernels for LLM Training
|
||||
# https://arxiv.org/pdf/2410.10989, 3.3 Correctness
|
||||
# use rtol = 1e-2 for bfloat16.
|
||||
rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
|
||||
|
||||
def diff(a: torch.Tensor, b: torch.Tensor):
|
||||
max_diff = torch.max(torch.abs(a.float() - b.float()))
|
||||
return max_diff
|
||||
|
||||
# Use Triton output as reference because we want to replace
|
||||
# the Triton kernel with custom CUDA kernel for merge attn
|
||||
# states operation.
|
||||
output_ref = output_ref_triton
|
||||
output_lse_ref = output_lse_ref_triton
|
||||
torch.testing.assert_close(
|
||||
output_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol
|
||||
)
|
||||
print("Output all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
|
||||
print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_v2)}")
|
||||
print(f"(CUDA v2 vs Triton): {diff(output_ref, output_v2)}")
|
||||
print("-" * 100)
|
||||
|
||||
torch.testing.assert_close(
|
||||
output_lse_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
|
||||
)
|
||||
print("Output LSE all match, max abs diff:")
|
||||
print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
|
||||
print(f"(CUDA v2 vs Torch) : {diff(output_lse_torch, output_lse_v2)}")
|
||||
print(f"(CUDA v2 vs Triton): {diff(output_lse_ref, output_lse_v2)}")
|
||||
print("-" * 100)
|
||||
|
||||
print(
|
||||
"All output values test passed! All inf values "
|
||||
"are correctly replaced with -inf."
|
||||
)
|
||||
print("-" * 100)
|
||||
|
||||
device = torch.cuda.get_device_name()
|
||||
all_case_info.append(
|
||||
(
|
||||
NUM_TOKENS,
|
||||
NUM_HEADS,
|
||||
HEAD_SIZE,
|
||||
output_dtype,
|
||||
device,
|
||||
time_torch,
|
||||
time_triton,
|
||||
time_v1,
|
||||
time_v2,
|
||||
)
|
||||
)
|
||||
if len(all_case_info) == (
|
||||
len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
|
||||
):
|
||||
generate_markdown_table()
|
||||
Reference in New Issue
Block a user