From b62e7e99b80e3df5f34916de95d0324200ee032d Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Sat, 12 Apr 2025 21:14:04 -0700
Subject: [PATCH] feat: adapt merge_state (#5337)

---
 .github/workflows/pr-test-sgl-kernel.yml  |   8 +-
 sgl-kernel/CMakeLists.txt                 |   4 +
 sgl-kernel/csrc/attention/cascade.cu      |  55 +++++++++
 sgl-kernel/csrc/common_extension.cc       |   2 +
 sgl-kernel/include/sgl_kernel_ops.h       |   2 +
 sgl-kernel/python/sgl_kernel/__init__.py  |   1 +
 sgl-kernel/python/sgl_kernel/attention.py |  17 ++-
 sgl-kernel/tests/test_merge_state.py      | 138 ++++++++++++++++++++++
 8 files changed, 224 insertions(+), 3 deletions(-)
 create mode 100644 sgl-kernel/csrc/attention/cascade.cu
 create mode 100644 sgl-kernel/tests/test_merge_state.py

diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml
index d86e34701..51175b724 100644
--- a/.github/workflows/pr-test-sgl-kernel.yml
+++ b/.github/workflows/pr-test-sgl-kernel.yml
@@ -44,6 +44,12 @@ jobs:
         cuda-version: '12.8'
     name: Build Wheel (CUDA ${{ matrix.cuda-version }})
     steps:
+      - name: Skip unnecessary builds on push to main
+        if: github.event_name == 'push' && (matrix.cuda-version == '11.8' || matrix.cuda-version == '12.8')
+        run: |
+          echo "Skipping CUDA ${{ matrix.cuda-version }} build on push to main"
+          exit 0
+
       - name: Cleanup
         run: |
           sudo rm -rf $GITHUB_WORKSPACE/* || true
@@ -87,7 +93,7 @@ jobs:
       - name: Install
         run: |
           bash scripts/ci_install_dependency.sh
-          pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
+          pip3 install torch==2.5.1 && pip3 install pytest
           pip3 uninstall sgl-kernel -y || true
           pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
           pip3 list | grep sgl-kernel

diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index 6283f0798..ab0b4853f 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -25,6 +25,8 @@ find_package(Torch REQUIRED)
 # clean Torch Flag
 clear_cuda_arches(CMAKE_FLAG)
 
+set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
+
 include(FetchContent)
 
 # cutlass
@@ -104,6 +106,7 @@ set(SGL_KERNEL_CUDA_FLAGS
     "--expt-relaxed-constexpr"
     "-Xcompiler=-Wconversion"
     "-Xcompiler=-fno-strict-aliasing"
+    "--threads=16"
 )
 
 option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
@@ -160,6 +163,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
 
 set(SOURCES
     "csrc/allreduce/custom_all_reduce.cu"
+    "csrc/attention/cascade.cu"
     "csrc/attention/cutlass_mla_kernel.cu"
     "csrc/attention/lightning_attention_decode_kernel.cu"
     "csrc/elementwise/activation.cu"

diff --git a/sgl-kernel/csrc/attention/cascade.cu b/sgl-kernel/csrc/attention/cascade.cu
new file mode 100644
index 000000000..9d49360dd
--- /dev/null
+++ b/sgl-kernel/csrc/attention/cascade.cu
@@ -0,0 +1,55 @@
+// Adapted from
+// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <flashinfer/attention/cascade.cuh>
+
+#include "pytorch_extension_utils.h"
+
+using namespace flashinfer;
+
+void merge_state(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
+  CHECK_INPUT(v_a);
+  CHECK_INPUT(s_a);
+  CHECK_INPUT(v_b);
+  CHECK_INPUT(s_b);
+  auto device = v_a.device();
+  CHECK_EQ(s_a.device(), device);
+  CHECK_EQ(v_b.device(), device);
+  CHECK_EQ(s_b.device(), device);
+  CHECK_DIM(3, v_a);
+  CHECK_DIM(2, s_a);
+  CHECK_DIM(3, v_b);
+  CHECK_DIM(2, s_b);
+  CHECK_SHAPE(v_a, v_b);
+  CHECK_SHAPE(s_a, s_b);
+  CHECK_EQ(v_a.size(0), s_a.size(0));
+  CHECK_EQ(v_a.size(1), s_b.size(1));
+  unsigned int seq_len = v_a.size(0);
+  unsigned int num_heads = v_a.size(1);
+  unsigned int head_dim = v_a.size(2);
+
+  const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
+    cudaError_t status = MergeState(
+        static_cast<c_type*>(v_a.data_ptr()),
+        static_cast<float*>(s_a.data_ptr()),
+        static_cast<c_type*>(v_b.data_ptr()),
+        static_cast<float*>(s_b.data_ptr()),
+        static_cast<c_type*>(v_merged.data_ptr()),
+        static_cast<float*>(s_merged.data_ptr()),
+        seq_len,
+        num_heads,
+        head_dim,
+        stream);
+    TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status));
+    return true;
+  });
+
+  TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
+}

diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index c2086aa5b..a8370d893 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -45,6 +45,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
       "new_kv) -> ()");
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+  m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state", torch::kCUDA, &merge_state);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor workspace) -> ()");

diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index 07046800d..64e530295 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -87,6 +87,8 @@ void lightning_attention_decode(
     const torch::Tensor& slope,
     torch::Tensor output,
     torch::Tensor new_kv);
+void merge_state(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,

diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 2d6bc0d56..f8a5b35e8 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -15,6 +15,7 @@ from sgl_kernel.attention import (
     cutlass_mla_decode,
     cutlass_mla_get_workspace_size,
     lightning_attention_decode,
+    merge_state,
 )
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,

diff --git a/sgl-kernel/python/sgl_kernel/attention.py b/sgl-kernel/python/sgl_kernel/attention.py
index b90834194..b8d6bce75 100644
--- a/sgl-kernel/python/sgl_kernel/attention.py
+++ b/sgl-kernel/python/sgl_kernel/attention.py
@@ -1,3 +1,5 @@
+from typing import Tuple
+
 import torch
 
 
@@ -7,6 +9,17 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
     )
 
 
+def merge_state(
+    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    v_merged = torch.empty_like(v_a)
+    s_merged = torch.empty_like(s_a)
+    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
+    return v_merged, s_merged
+
+
 def cutlass_mla_decode(
     q_nope_and_q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
@@ -54,7 +67,7 @@ def cutlass_mla_decode(
         (B_q, H, D_latent),
         device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
     )
-    torch.ops.sgl_kernel.cutlass_mla_decode(
+    torch.ops.sgl_kernel.cutlass_mla_decode.default(
         out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
     )
     return out
@@ -63,6 +76,6 @@ def cutlass_mla_decode(
 def cutlass_mla_get_workspace_size(
     max_seq_len: int, num_batches: int, sm_count: int = 0
 ) -> int:
-    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(
+    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
         max_seq_len, num_batches, sm_count
     )

diff --git a/sgl-kernel/tests/test_merge_state.py b/sgl-kernel/tests/test_merge_state.py
new file mode 100644
index 000000000..2931fa949
--- /dev/null
+++ b/sgl-kernel/tests/test_merge_state.py
@@ -0,0 +1,138 @@
+# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
+
+from typing import List
+
+import pytest
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import merge_state
+
+
+def check_input(x: torch.Tensor):
+    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
+    assert x.is_contiguous(), f"{str(x)} must be contiguous"
+
+
+def check_dim(d, x: torch.Tensor):
+    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"
+
+
+def check_shape(a: torch.Tensor, b: torch.Tensor):
+    assert a.dim() == b.dim(), "tensors should have same dim"
+    for i in range(a.dim()):
+        assert a.size(i) == b.size(
+            i
+        ), f"tensors shape mismatch, {a.size()} and {b.size()}"
+
+
+def check_device(tensors: List[torch.Tensor]):
+    device = tensors[0].device
+    for t in tensors:
+        assert (
+            t.device == device
+        ), f"All tensors should be on the same device, but got {device} and {t.device}"
+
+
+@triton.jit
+def state_merge(o, m, d, other_o, other_m, other_d):
+    m_max = tl.maximum(m, other_m)
+    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
+    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
+    return o, m_max, d
+
+
+@triton.jit
+def state_normalize(o, m, d):
+    o = o / d
+    return o, m, d
+
+
+@triton.jit
+def state_get_lse(o, m, d):
+    return m + tl.log2(d)
+
+
+@triton.jit
+def merge_state_kernel(
+    v_a_ptr,
+    s_a_ptr,
+    v_b_ptr,
+    s_b_ptr,
+    v_merged_ptr,
+    s_merged_ptr,
+    num_heads,
+    head_dim,
+    bdx: tl.constexpr,
+    bdy: tl.constexpr,
+):
+    pos = tl.program_id(axis=0)
+    for tx in tl.range(bdx):
+        for head_idx in tl.range(bdy):
+            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
+            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)
+
+            offsets = (pos * num_heads + head_idx) * head_dim + tx
+            v_a = tl.load(v_a_ptr + offsets)
+            v_b = tl.load(v_b_ptr + offsets)
+
+            v_merged, s_max, d = state_merge(
+                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
+            )
+            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
+            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
+            tl.store(v_merged_ptr + v_merged_offset, v_merged)
+
+            if s_merged_ptr:
+                tl.store(
+                    s_merged_ptr + pos * num_heads + head_idx,
+                    tl.log2(d) + s_max,
+                )
+
+
+def merge_state_triton(
+    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+):
+    check_input(v_a)
+    check_input(s_a)
+    check_input(v_b)
+    check_input(s_b)
+    check_device([v_a, s_a, v_b, s_b])
+    check_dim(3, v_a)
+    check_dim(2, s_a)
+    check_dim(3, v_b)
+    check_dim(2, s_b)
+    check_shape(v_a, v_b)
+    check_shape(s_a, s_b)
+    assert v_a.size(0) == s_a.size(0)
+    assert v_a.size(1) == s_b.size(1)
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    seq_len = v_a.size(0)
+    num_heads = v_a.size(1)
+    head_dim = v_a.size(2)
+    v_merged = torch.empty_like(v_a).to(s_a.device)
+    s_merged = torch.empty((seq_len, num_heads)).to(s_a.device)
+    bdx = head_dim
+    bdy = num_heads
+
+    merge_state_kernel[lambda meta: (seq_len,)](
+        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
+    )
+
+    return v_merged, s_merged
+
+
+@pytest.mark.parametrize("seq_len", [2048])
+@pytest.mark.parametrize("num_heads", [32])
+@pytest.mark.parametrize("head_dim", [128])
+def test_merge_state(seq_len, num_heads, head_dim):
+    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
+    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
+    v_merged, s_merged = merge_state_triton(va, sa, vb, sb)
+    v_merged_std, s_merged_std = merge_state(va, sa, vb, sb)
+
+    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
+    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
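
---

Usage note: merge_state combines two partial attention states over the same
queries, as used in cascade/chunked attention. v_merged is the log-sum-exp
weighted combination of v_a and v_b, and s_merged is the combined score (see
state_merge in the test above for the reference math). Below is a minimal
sketch of calling the new op, assuming a CUDA device and an sgl-kernel build
that includes this patch; shapes and dtypes follow the test:

    import torch
    from sgl_kernel import merge_state

    # v_*: [seq_len, num_heads, head_dim] fp16 partial attention outputs
    # s_*: [seq_len, num_heads] fp32 log-sum-exp scores
    seq_len, num_heads, head_dim = 2048, 32, 128
    v_a = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda")
    s_a = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")
    v_b = torch.randn(seq_len, num_heads, head_dim, dtype=torch.half, device="cuda")
    s_b = torch.randn(seq_len, num_heads, dtype=torch.float32, device="cuda")

    # Outputs have the same shapes as the corresponding inputs.
    v_merged, s_merged = merge_state(v_a, s_a, v_b, s_b)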