diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 583136269..8b056d4cc 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -170,6 +170,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
   "csrc/attention/cascade.cu"
+  "csrc/attention/merge_attn_states.cu"
   "csrc/attention/cutlass_mla_kernel.cu"
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/elementwise/activation.cu"
diff --git a/sgl-kernel/csrc/attention/merge_attn_states.cu b/sgl-kernel/csrc/attention/merge_attn_states.cu
new file mode 100644
index 000000000..a3b405340
--- /dev/null
+++ b/sgl-kernel/csrc/attention/merge_attn_states.cu
@@ -0,0 +1,201 @@
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include <cmath>
+#include <limits>
+
+#include "pytorch_extension_utils.h"
+
+// Helper functions to convert between different data types
+// (float, half, bfloat16) for the merge attention states kernel.
+inline __device__ float to_float(float u) {
+  return u;
+}
+inline __device__ float to_float(half u) {
+  return __half2float(u);
+}
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+inline __device__ void from_float(float& d, float s) {
+  d = s;
+}
+inline __device__ void from_float(half& d, float s) {
+  d = __float2half(s);
+}
+inline __device__ void from_float(__nv_bfloat16& d, float s) {
+  d = __float2bfloat16(s);
+}
+
+// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+template <typename scalar_t, uint NUM_THREADS>
+__global__ void merge_attn_states_kernel(
+    scalar_t* output,
+    float* output_lse,
+    const scalar_t* prefix_output,
+    const float* prefix_lse,
+    const scalar_t* suffix_output,
+    const float* suffix_lse,
+    const uint num_tokens,
+    const uint num_heads,
+    const uint head_size) {
+  using pack_128b_t = uint4;
+  const uint pack_size = 16 / sizeof(scalar_t);
+  const uint threads_per_head = head_size / pack_size;
+
+  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
+  const uint token_head_threads = num_tokens * num_heads * threads_per_head;
+
+  if (global_idx >= token_head_threads) return;
+
+  // global_idx -> token_idx + head_idx + pack_idx
+  const uint token_head_idx = global_idx / threads_per_head;
+  const uint pack_idx = global_idx % threads_per_head;
+
+  const uint token_idx = token_head_idx / num_heads;
+  const uint head_idx = token_head_idx % num_heads;
+
+  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
+  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
+  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
+  scalar_t* output_head_ptr = output + head_offset;
+
+  // float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
+  // float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
+  float p_lse = prefix_lse[token_idx * num_heads + head_idx];
+  float s_lse = suffix_lse[token_idx * num_heads + head_idx];
+  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
+  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
+
+  const float max_lse = fmaxf(p_lse, s_lse);
+  p_lse = p_lse - max_lse;
+  s_lse = s_lse - max_lse;
+  const float p_se = expf(p_lse);
+  const float s_se = expf(s_lse);
+  const float out_se = p_se + s_se;
+  const float p_scale = p_se / out_se;
+  const float s_scale = s_se / out_se;
+
+  if (pack_offset < head_size) {
+    // Pack 128b load
+    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t o_out_pack;
+
+#pragma unroll
+    for (uint i = 0; i < pack_size; ++i) {
+      // Always use float for FMA to keep high precision.
+      // half(uint16_t), bfloat16, float -> float.
+      const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+      const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+      // fma: out = p_out_f * p_scale + s_out_f * s_scale
+      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
+      // float -> half(uint16_t), bfloat16, float.
+      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
+    }
+
+    // Pack 128b store
+    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
+  }
+  // We only need to write to output_lse once per head.
+  if (output_lse != nullptr && pack_idx == 0) {
+    float out_lse = logf(out_se) + max_lse;
+    output_lse[token_idx * num_heads + head_idx] = out_lse;
+  }
+}
+
+// The following macro dispatches on the output data type. fn is a macro
+// that expands to a call instantiated with the selected template type.
+#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
+  {                                                                     \
+    if (scalar_dtype == at::ScalarType::Float) {                        \
+      fn(float);                                                        \
+    } else if (scalar_dtype == at::ScalarType::Half) {                  \
+      fn(half);                                                         \
+    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
+      fn(__nv_bfloat16);                                                \
+    } else {                                                            \
+      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
+    }                                                                   \
+  }
+
+#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)               \
+  {                                                                   \
+    merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
+        reinterpret_cast<scalar_t*>(output.data_ptr()),               \
+        reinterpret_cast<float*>(output_lse.data_ptr()),              \
+        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),        \
+        reinterpret_cast<float*>(prefix_lse.data_ptr()),              \
+        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),        \
+        reinterpret_cast<float*>(suffix_lse.data_ptr()),              \
+        num_tokens,                                                   \
+        num_heads,                                                    \
+        head_size);                                                   \
+  }
+
+/* @brief Merges the attention states from prefix and suffix
+ * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
+ *
+ * @param output [n,h,d] The output tensor to store the merged attention states.
+ * @param output_lse [n,h] Optional tensor to store the log-sum-exp values.
+ * @param prefix_output [n,h,d] The prefix attention states.
+ * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
+ * states.
+ * @param suffix_output [n,h,d] The suffix attention states.
+ * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
+ * states.
+ */
+template <typename scalar_t>
+void merge_attn_states_launcher(
+    const at::Tensor& prefix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    const at::Tensor& prefix_lse,     // [NUM_TOKENS, NUM_HEADS]
+    const at::Tensor& suffix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    const at::Tensor& suffix_lse,     // [NUM_TOKENS, NUM_HEADS]
+    at::Tensor& output,               // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    at::Tensor& output_lse            // [NUM_TOKENS, NUM_HEADS]
+) {
+  constexpr uint NUM_THREADS = 128;
+  const uint num_tokens = output.size(0);
+  const uint num_heads = output.size(1);
+  const uint head_size = output.size(2);
+  const uint pack_size = 16 / sizeof(scalar_t);
+  TORCH_CHECK(head_size % pack_size == 0, "head_size must be a multiple of pack_size:", pack_size);
+  // Each thread processes one pack of elements: for float the
+  // pack_size is 4; for half/bf16 the pack_size is 8.
+  const uint threads_per_head = head_size / pack_size;
+  const uint total_threads = num_tokens * num_heads * threads_per_head;
+
+  dim3 block(NUM_THREADS);
+  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
+
+  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
+}
+
+#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
+  { merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
+
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
+  // Input tensors must be contiguous
+  CHECK_INPUT(v_a);  // v_a prefix_output (seq_len, num_heads, head_dim)
+  CHECK_INPUT(s_a);  // s_a prefix_lse (seq_len, num_heads)
+  CHECK_INPUT(v_b);  // v_b suffix_output (seq_len, num_heads, head_dim)
+  CHECK_INPUT(s_b);  // s_b suffix_lse (seq_len, num_heads)
+  // v_merged output (seq_len, num_heads, head_dim)
+  // s_merged output_lse (seq_len, num_heads)
+  auto device = v_a.device();
+  CHECK_EQ(s_a.device(), device);
+  CHECK_EQ(v_b.device(), device);
+  CHECK_EQ(s_b.device(), device);
+  CHECK_DIM(3, v_a);
+  CHECK_DIM(2, s_a);
+  CHECK_DIM(3, v_b);
+  CHECK_DIM(2, s_b);
+  CHECK_SHAPE(v_a, v_b);
+  CHECK_SHAPE(s_a, s_b);
+  CHECK_EQ(v_a.size(0), s_a.size(0));
+  CHECK_EQ(v_a.size(1), s_b.size(1));
+  DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
+}
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index a8370d893..d3e0ffae8 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -47,6 +47,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
   m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state", torch::kCUDA, &merge_state);
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor workspace) -> ()");
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 64e530295..118a8ba05 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -89,6 +89,8 @@ void lightning_attention_decode(
     torch::Tensor new_kv);
 void merge_state(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index f8a5b35e8..a6338ee5a 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -16,6 +16,7 @@ from sgl_kernel.attention import (
     cutlass_mla_get_workspace_size,
     lightning_attention_decode,
     merge_state,
+    merge_state_v2,
 )
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,
diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py
index b8d6bce75..d80a6fbbd 100644
--- a/sgl-kernel/python/sgl_kernel/attention.py
+++ b/sgl-kernel/python/sgl_kernel/attention.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
 
 
 def merge_state(
-    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     s_a = s_a.to(torch.float32)
     s_b = s_b.to(torch.float32)
-    v_merged = torch.empty_like(v_a)
-    s_merged = torch.empty_like(s_a)
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
     torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
     return v_merged, s_merged
 
 
+def merge_state_v2(
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
+    # does not support the FP8 data type and non-CUDA devices.
+    # It may be necessary to fall back to using the Triton kernel.
+
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
+    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
+    return v_merged, s_merged
+
+
 def cutlass_mla_decode(
     q_nope_and_q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
diff --git a/sgl-kernel/tests/test_merge_state_v2.py b/sgl-kernel/tests/test_merge_state_v2.py
new file mode 100644
index 000000000..f5c7a30dd
--- /dev/null
+++ b/sgl-kernel/tests/test_merge_state_v2.py
@@ -0,0 +1,396 @@
+from typing import Optional, Tuple
+
+import pytest
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import merge_state, merge_state_v2
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Avoid creating new tensors if they are already provided; fill in
+    # the defaults before reading shapes so output=None does not crash.
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
+
+
+# Naive PyTorch implementation of merging attention states
+def merge_state_torch(
+    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    prefix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
+    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    suffix_lse: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS]
+    output: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    output_lse: Optional[torch.Tensor] = None,  # [NUM_TOKENS, NUM_HEADS]
+):
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+    p_lse = prefix_lse
+    s_lse = suffix_lse
+    # inf -> -inf
+    p_lse[p_lse == torch.inf] = -torch.inf
+    s_lse[s_lse == torch.inf] = -torch.inf
+    # max_lse [NUM_TOKENS, NUM_HEADS]
+    max_lse = torch.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    p_lse_exp = torch.exp(p_lse)
+    s_lse_exp = torch.exp(s_lse)
+    out_se = p_lse_exp + s_lse_exp
+    if output_lse is not None:
+        output_lse = torch.log(out_se) + max_lse
+    p_scale = p_lse_exp / out_se
+    s_scale = s_lse_exp / out_se
+    p_scale = p_scale.unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
+    s_scale = s_scale.unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
+    output = prefix_output * p_scale + suffix_output * s_scale
+    return output, output_lse
+
+
+NUM_BATCH_TOKENS = [256, 512, 613, 1024, 1536]
+NUM_QUERY_HEADS = [8, 16, 32]
+HEAD_SIZES = [32, 48, 64, 128, 256]
+DTYPES = [torch.half, torch.bfloat16]
+
+all_case_info: list[tuple] = []
+
+
+def generate_markdown_table():
+    global all_case_info
+    table_header = (
+        "| tokens | heads | headsize | dtype "
+        "| device | torch | triton | v1 | v2 | speedup(vs triton) | speedup(vs v1)|"
+    )
+    table_separator = (
+        "| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |"
+    )
+
+    def shortly_dtype(dtype: torch.dtype) -> str:
+        return str(dtype).removeprefix("torch.")
+
+    def shortly_device(device: str) -> str:
+        return device.removeprefix("NVIDIA").strip()
+
+    print(table_header)
+    print(table_separator)
+    for info in all_case_info:
+        (
+            num_tokens,
+            num_heads,
+            head_size,
+            dtype,
+            device,
+            time_torch,
+            time_triton,
+            time_v1,
+            time_v2,
+        ) = info
+        dtype = shortly_dtype(dtype)
+        device = shortly_device(device)
+        improved_triton = time_triton / time_v2
+        improved_v1 = time_v1 / time_v2
+        print(
+            f"| {num_tokens} | {num_heads} | {head_size} "
+            f"| {dtype} | {device} | {time_torch:.4f}ms "
+            f"| {time_triton:.4f}ms "
+            f"| {time_v1:.4f}ms "
+            f"| {time_v2:.4f}ms "
+            f"| {improved_triton:.4f}x "
+            f"| {improved_v1:.4f}x |"
+        )
+
+
+@pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
+@pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
+@pytest.mark.parametrize("head_size", HEAD_SIZES)
+@pytest.mark.parametrize("output_dtype", DTYPES)
+@torch.inference_mode()
+def test_merge_attn_states(
+    num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
+):
+    if not torch.cuda.is_available():
+        pytest.skip(
+            "CUDA is required to compare the Triton merge_attn_states "
+            "kernel with the custom CUDA merge_attn_states kernel"
+        )
+
+    NUM_TOKENS = num_tokens
+    NUM_HEADS = num_query_heads
+    HEAD_SIZE = head_size
+
+    print(
+        f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
+        f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
+        f"Device: {torch.cuda.get_device_name()}"
+    )
+
+    # prefix_lse and suffix_lse contain inf and normal values
+    prefix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
+    suffix_lse = torch.randn(NUM_TOKENS, NUM_HEADS, dtype=torch.float32, device="cuda")
+
+    # Generate boolean masks
+    mask_prefix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
+    mask_suffix = torch.rand(NUM_TOKENS, NUM_HEADS) < 0.1
+    # Ensure the two masks are not True at the same position
+    combined_mask = torch.logical_and(mask_prefix, mask_suffix)
+    mask_prefix = torch.logical_and(mask_prefix, ~combined_mask)
+    mask_suffix = torch.logical_and(mask_suffix, ~combined_mask)
+
+    prefix_lse[mask_prefix] = float("inf")
+    suffix_lse[mask_suffix] = float("inf")
+
+    # Other input tensors (need to be initialized, but
+    # their initial values do not matter)
+    output = torch.zeros(
+        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
+    )
+    output_lse = torch.zeros(
+        (NUM_TOKENS, NUM_HEADS), dtype=torch.float32, device="cuda"
+    )
+    prefix_output = torch.randn(
+        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
+    )
+    suffix_output = torch.randn(
+        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="cuda"
+    )
+
+    warmup_times = 2
+    repeat_times = 20
+
+    def perf_kernel_fn(
+        output_fn: torch.Tensor,
+        output_lse_fn: torch.Tensor,
+        kernel_fn: callable,
+        fn_type: str = "torch",
+    ):
+        # The torch reference replaces inf with -inf in place, so give it
+        # clones; the other kernels can use prefix_lse and suffix_lse directly.
+        if fn_type == "torch":
+            prefix_lse_ = prefix_lse.clone()
+            suffix_lse_ = suffix_lse.clone()
+        else:
+            prefix_lse_ = prefix_lse
+            suffix_lse_ = suffix_lse
+
+        if fn_type == "cuda_v1":
+            # The merge_state v1 kernel does not support float32
+            if output_dtype not in (torch.half, torch.bfloat16):
+                return 0, output_fn, output_lse_fn
+
+        total_time = 0
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+
+        try:
+            for _ in range(warmup_times):
+                output_fn, output_lse_fn = kernel_fn(
+                    prefix_output,
+                    prefix_lse_,
+                    suffix_output,
+                    suffix_lse_,
+                    output_fn,
+                    output_lse_fn,
+                )
+            torch.cuda.synchronize()
+
+            for _ in range(repeat_times):
+                start.record()
+                output_fn, output_lse_fn = kernel_fn(
+                    prefix_output,
+                    prefix_lse_,
+                    suffix_output,
+                    suffix_lse_,
+                    output_fn,
+                    output_lse_fn,
+                )
+                end.record()
+                torch.cuda.synchronize()
+                total_time += start.elapsed_time(end)
+
+            avg_time = total_time / repeat_times
+            return avg_time, output_fn, output_lse_fn
+        except Exception:
+            return 0, output_fn, output_lse_fn
+
+    # 0. Run the Torch kernel
+    output_torch = output.clone()
+    output_lse_torch = output_lse.clone()
+    time_torch, output_torch, output_lse_torch = perf_kernel_fn(
+        output_torch, output_lse_torch, merge_state_torch, fn_type="torch"
+    )
+
+    # 1. Run the Triton kernel
+    output_ref_triton = output.clone()
+    output_lse_ref_triton = output_lse.clone()
+    time_triton, output_ref_triton, output_lse_ref_triton = perf_kernel_fn(
+        output_ref_triton,
+        output_lse_ref_triton,
+        merge_state_triton,
+        fn_type="triton",
+    )
+
+    # 2. Run the merge_state V1 kernel
+    output_v1 = output.clone()
+    output_lse_v1 = output_lse.clone()
+    time_v1, output_v1, output_lse_v1 = perf_kernel_fn(
+        output_v1, output_lse_v1, merge_state, fn_type="cuda_v1"
+    )
+
+    # 3. Run the merge_state V2 kernel
+    output_v2 = output.clone()
+    output_lse_v2 = output_lse.clone()
+    time_v2, output_v2, output_lse_v2 = perf_kernel_fn(
+        output_v2, output_lse_v2, merge_state_v2, fn_type="cuda_v2"
+    )
+
+    # 4. Performance comparison
+    improved = time_triton / time_v2
+    print(f"  Torch time: {time_torch:.6f}ms")
+    print(f" Triton time: {time_triton:.6f}ms")
+    print(f"CUDA v1 time: {time_v1:.6f}ms")
+    print(f"CUDA v2 time: {time_v2:.6f}ms, Performance: {improved:.5f}x")
+    print("-" * 100)
+
+    # 5. Correctness comparison
+    # Liger Kernel: Efficient Triton Kernels for LLM Training
+    # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
+    # use rtol = 1e-2 for bfloat16.
+    rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3
+
+    def diff(a: torch.Tensor, b: torch.Tensor):
+        max_diff = torch.max(torch.abs(a.float() - b.float()))
+        return max_diff
+
+    # Use the Triton output as the reference, since the goal is to replace
+    # the Triton kernel with the custom CUDA kernel for the merge attn
+    # states operation.
+    output_ref = output_ref_triton
+    output_lse_ref = output_lse_ref_triton
+    torch.testing.assert_close(
+        output_v2.float(), output_ref.float(), atol=1e-3, rtol=rtol
+    )
+    print("Output all match, max abs diff:")
+    print(f"(Triton vs Torch) : {diff(output_torch, output_ref)}")
+    print(f"(CUDA v2 vs Torch) : {diff(output_torch, output_v2)}")
+    print(f"(CUDA v2 vs Triton): {diff(output_ref, output_v2)}")
+    print("-" * 100)
+
+    torch.testing.assert_close(
+        output_lse_v2.float(), output_lse_ref.float(), atol=1e-3, rtol=rtol
+    )
+    print("Output LSE all match, max abs diff:")
+    print(f"(Triton vs Torch) : {diff(output_lse_torch, output_lse_ref)}")
+    print(f"(CUDA v2 vs Torch) : {diff(output_lse_torch, output_lse_v2)}")
+    print(f"(CUDA v2 vs Triton): {diff(output_lse_ref, output_lse_v2)}")
+    print("-" * 100)
+
+    print(
+        "All output values test passed! All inf values "
+        "are correctly replaced with -inf."
+    )
+    print("-" * 100)
+
+    device = torch.cuda.get_device_name()
+    all_case_info.append(
+        (
+            NUM_TOKENS,
+            NUM_HEADS,
+            HEAD_SIZE,
+            output_dtype,
+            device,
+            time_torch,
+            time_triton,
+            time_v1,
+            time_v2,
+        )
+    )
+    if len(all_case_info) == (
+        len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
+    ):
+        generate_markdown_table()
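
Usage sketch (reviewer note, not part of the patch): a minimal example of calling the new merge_state_v2 Python API and checking it against a plain PyTorch version of the log-sum-exp merge from section 2.2 of https://www.arxiv.org/pdf/2501.01005. Shapes follow the kernel's contract (outputs [num_tokens, num_heads, head_size], lse [num_tokens, num_heads]), and it assumes a CUDA build of sgl-kernel with this patch applied. Note that head_size must be a multiple of the kernel's pack size (8 for half/bf16, 4 for float), which 128 satisfies.

import torch
from sgl_kernel import merge_state_v2

num_tokens, num_heads, head_size = 4, 8, 128
# Two partial attention results over disjoint KV segments for the same queries.
v_a = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
v_b = torch.randn(num_tokens, num_heads, head_size, dtype=torch.half, device="cuda")
s_a = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
s_b = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")

v_merged, s_merged = merge_state_v2(v_a, s_a, v_b, s_b)

# PyTorch reference: reweight the partial outputs by their softmax masses
# exp(s_a) / (exp(s_a) + exp(s_b)), computed stably by subtracting max_lse.
max_lse = torch.maximum(s_a, s_b)
p, q = torch.exp(s_a - max_lse), torch.exp(s_b - max_lse)
ref_v = v_a * (p / (p + q)).unsqueeze(-1) + v_b * (q / (p + q)).unsqueeze(-1)
ref_s = torch.log(p + q) + max_lse

torch.testing.assert_close(v_merged.float(), ref_v.float(), atol=1e-3, rtol=1e-2)
torch.testing.assert_close(s_merged, ref_s, atol=1e-3, rtol=1e-2)
print("merge_state_v2 matches the PyTorch reference")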