feat: adapt merge_state (#5337)
This commit is contained in:
8
.github/workflows/pr-test-sgl-kernel.yml
vendored
8
.github/workflows/pr-test-sgl-kernel.yml
vendored
@@ -44,6 +44,12 @@ jobs:
|
||||
cuda-version: '12.8'
|
||||
name: Build Wheel (CUDA ${{ matrix.cuda-version }})
|
||||
steps:
|
||||
- name: Skip unnecessary builds on push to main
|
||||
if: github.event_name == 'push' && (matrix.cuda-version == '11.8' || matrix.cuda-version == '12.8')
|
||||
run: |
|
||||
echo "Skipping CUDA ${{ matrix.cuda-version }} build on push to main"
|
||||
exit 0
|
||||
|
||||
- name: Cleanup
|
||||
run: |
|
||||
sudo rm -rf $GITHUB_WORKSPACE/* || true
|
||||
@@ -87,7 +93,7 @@ jobs:
|
||||
- name: Install
|
||||
run: |
|
||||
bash scripts/ci_install_dependency.sh
|
||||
pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.7.2
|
||||
pip3 install torch==2.5.1 && pip3 install pytest
|
||||
pip3 uninstall sgl-kernel -y || true
|
||||
pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps
|
||||
pip3 list | grep sgl-kernel
|
||||
|
||||
@@ -25,6 +25,8 @@ find_package(Torch REQUIRED)
|
||||
# clean Torch Flag
|
||||
clear_cuda_arches(CMAKE_FLAG)
|
||||
|
||||
set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
|
||||
|
||||
include(FetchContent)
|
||||
|
||||
# cutlass
|
||||
@@ -104,6 +106,7 @@ set(SGL_KERNEL_CUDA_FLAGS
|
||||
"--expt-relaxed-constexpr"
|
||||
"-Xcompiler=-Wconversion"
|
||||
"-Xcompiler=-fno-strict-aliasing"
|
||||
"--threads=16"
|
||||
)
|
||||
|
||||
option(SGL_KERNEL_ENABLE_SM100A "Enable SM100A" OFF)
|
||||
@@ -160,6 +163,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
|
||||
|
||||
set(SOURCES
|
||||
"csrc/allreduce/custom_all_reduce.cu"
|
||||
"csrc/attention/cascade.cu"
|
||||
"csrc/attention/cutlass_mla_kernel.cu"
|
||||
"csrc/attention/lightning_attention_decode_kernel.cu"
|
||||
"csrc/elementwise/activation.cu"
|
||||
|
||||
55
sgl-kernel/csrc/attention/cascade.cu
Normal file
55
sgl-kernel/csrc/attention/cascade.cu
Normal file
@@ -0,0 +1,55 @@
|
||||
// Adapted from
|
||||
// https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/csrc/cascade.cu
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <flashinfer/attention/cascade.cuh>
|
||||
|
||||
#include "pytorch_extension_utils.h"
|
||||
|
||||
using namespace flashinfer;
|
||||
|
||||
// Merge two attention states into one, in the style of cascade/online-softmax
// attention: v_* are partial attention outputs, s_* their per-(position, head)
// log-sum-exp scores. Results are written into the pre-allocated v_merged /
// s_merged output tensors.
//
// Shapes (established by the CHECK_DIM/CHECK_EQ checks below):
//   v_a, v_b, v_merged : [seq_len, num_heads, head_dim]
//   s_a, s_b, s_merged : [seq_len, num_heads], float32
// All tensors must be contiguous and on the same CUDA device.
void merge_state(
    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
  CHECK_INPUT(v_a);
  CHECK_INPUT(s_a);
  CHECK_INPUT(v_b);
  CHECK_INPUT(s_b);
  // The kernel writes raw pointers into the outputs, so they must satisfy the
  // same contiguity/device requirements as the inputs.
  CHECK_INPUT(v_merged);
  CHECK_INPUT(s_merged);
  auto device = v_a.device();
  CHECK_EQ(s_a.device(), device);
  CHECK_EQ(v_b.device(), device);
  CHECK_EQ(s_b.device(), device);
  CHECK_EQ(v_merged.device(), device);
  CHECK_EQ(s_merged.device(), device);
  CHECK_DIM(3, v_a);  // [seq_len, num_heads, head_dim]
  CHECK_DIM(2, s_a);  // [seq_len, num_heads]
  CHECK_DIM(3, v_b);
  CHECK_DIM(2, s_b);
  CHECK_SHAPE(v_a, v_b);
  CHECK_SHAPE(s_a, s_b);
  CHECK_EQ(v_a.size(0), s_a.size(0));
  // Compare v_a's head count against s_a (was s_b; equivalent because
  // CHECK_SHAPE(s_a, s_b) holds, but this states the intended invariant).
  CHECK_EQ(v_a.size(1), s_a.size(1));
  // Explicit casts: Tensor::size() returns int64_t and the build enables
  // -Wconversion, so avoid implicit narrowing.
  unsigned int seq_len = static_cast<unsigned int>(v_a.size(0));
  unsigned int num_heads = static_cast<unsigned int>(v_a.size(1));
  unsigned int head_dim = static_cast<unsigned int>(v_a.size(2));

  const c10::cuda::OptionalCUDAGuard device_guard(v_a.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  // Dispatch on the value dtype (fp16-family); scores are always float32,
  // matching the float* casts below.
  bool success = DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(v_a.scalar_type(), c_type, [&] {
    cudaError_t status = MergeState(
        static_cast<c_type*>(v_a.data_ptr()),
        static_cast<float*>(s_a.data_ptr()),
        static_cast<c_type*>(v_b.data_ptr()),
        static_cast<float*>(s_b.data_ptr()),
        static_cast<c_type*>(v_merged.data_ptr()),
        static_cast<float*>(s_merged.data_ptr()),
        seq_len,
        num_heads,
        head_dim,
        stream);
    TORCH_CHECK(status == cudaSuccess, "MergeState kernel launch failed: ", cudaGetErrorString(status));
    return true;
  });

  TORCH_CHECK(success, "MergeState kernel launch failed: unsupported data type");
}
|
||||
@@ -45,6 +45,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
|
||||
"new_kv) -> ()");
|
||||
m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
|
||||
m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
|
||||
m.impl("merge_state", torch::kCUDA, &merge_state);
|
||||
m.def(
|
||||
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
|
||||
"page_table, Tensor workspace) -> ()");
|
||||
|
||||
@@ -87,6 +87,8 @@ void lightning_attention_decode(
|
||||
const torch::Tensor& slope,
|
||||
torch::Tensor output,
|
||||
torch::Tensor new_kv);
|
||||
void merge_state(
|
||||
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
|
||||
void cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope_and_q_pe,
|
||||
|
||||
@@ -15,6 +15,7 @@ from sgl_kernel.attention import (
|
||||
cutlass_mla_decode,
|
||||
cutlass_mla_get_workspace_size,
|
||||
lightning_attention_decode,
|
||||
merge_state,
|
||||
)
|
||||
from sgl_kernel.elementwise import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -7,6 +9,17 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
)
|
||||
|
||||
|
||||
def merge_state(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two attention states via the sgl_kernel CUDA op.

    Args:
        v_a: first value tensor, 3-D (seq_len, num_heads, head_dim).
        s_a: first score tensor, 2-D (seq_len, num_heads).
        v_b: second value tensor, same shape as ``v_a``.
        s_b: second score tensor, same shape as ``s_a``.

    Returns:
        Tuple of freshly allocated ``(v_merged, s_merged)`` tensors with the
        shapes of ``v_a`` and ``s_a`` respectively.
    """
    # The CUDA kernel reads scores through float* pointers, so promote both
    # score tensors to float32 before the call.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # Outputs are allocated here and filled in-place by the op.
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged
|
||||
|
||||
|
||||
def cutlass_mla_decode(
|
||||
q_nope_and_q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
@@ -54,7 +67,7 @@ def cutlass_mla_decode(
|
||||
(B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode(
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
||||
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
|
||||
)
|
||||
return out
|
||||
@@ -63,6 +76,6 @@ def cutlass_mla_decode(
|
||||
def cutlass_mla_get_workspace_size(
    max_seq_len: int, num_batches: int, sm_count: int = 0
) -> int:
    """Query the workspace size required by ``cutlass_mla_decode``.

    NOTE(review): presumably the returned int is a size in bytes and
    ``sm_count == 0`` means "use the device's SM count" — confirm against the
    kernel implementation.
    """
    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
        max_seq_len, num_batches, sm_count
    )
|
||||
|
||||
138
sgl-kernel/tests/test_merge_state.py
Normal file
138
sgl-kernel/tests/test_merge_state.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/55576c626421b5ee7e7ebe74afd26465c8ae863f/flashinfer/triton/kernels/cascade.py
|
||||
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from sgl_kernel import merge_state
|
||||
|
||||
|
||||
def check_input(x: torch.Tensor):
    # Kernel inputs must be contiguous CUDA tensors (raw-pointer indexing).
    assert x.is_cuda, f"{str(x)} must be a CUDA Tensor"
    assert x.is_contiguous(), f"{str(x)} must be contiguous"
|
||||
|
||||
|
||||
def check_dim(d, x: torch.Tensor):
    # Assert that tensor ``x`` has exactly ``d`` dimensions.
    assert x.dim() == d, f"{str(x)} must be a {d}D tensor"
|
||||
|
||||
|
||||
def check_shape(a: torch.Tensor, b: torch.Tensor):
    """Assert that two tensors have identical rank and per-dimension sizes."""
    assert a.dim() == b.dim(), "tensors should have same dim"
    # Walk the two shapes in lockstep; rank equality was checked above.
    for size_a, size_b in zip(a.size(), b.size()):
        assert size_a == size_b, f"tensors shape mismatch, {a.size()} and {b.size()}"
|
||||
|
||||
|
||||
def check_device(tensors: List[torch.Tensor]):
    """Assert every tensor in the list lives on the same device as the first."""
    reference = tensors[0].device
    for tensor in tensors:
        assert (
            tensor.device == reference
        ), f"All tensors should be on the same device, but got {reference} and {tensor.device}"
|
||||
|
||||
|
||||
@triton.jit
def state_merge(o, m, d, other_o, other_m, other_d):
    # Online-softmax style merge of two partial states (o = weighted output,
    # m = running max in log2 space, d = denominator): rescale both sides to
    # the shared maximum before summing, for numerical stability.
    m_max = tl.maximum(m, other_m)
    d = d * tl.exp2(m - m_max) + other_d * tl.exp2(other_m - m_max)
    o = o * tl.exp2(m - m_max) + other_o * tl.exp2(other_m - m_max)
    return o, m_max, d
|
||||
|
||||
|
||||
@triton.jit
def state_normalize(o, m, d):
    # Finalize a state: divide the accumulated output by its denominator.
    o = o / d
    return o, m, d
|
||||
|
||||
|
||||
@triton.jit
def state_get_lse(o, m, d):
    # Log-sum-exp (base 2) of a state: running max plus log2 of the denominator.
    return m + tl.log2(d)
|
||||
|
||||
|
||||
@triton.jit
def merge_state_kernel(
    v_a_ptr,
    s_a_ptr,
    v_b_ptr,
    s_b_ptr,
    v_merged_ptr,
    s_merged_ptr,
    num_heads,
    head_dim,
    bdx: tl.constexpr,  # loop bound over head_dim elements
    bdy: tl.constexpr,  # loop bound over heads
):
    # One Triton program per sequence position; each program serially walks
    # every (head, head-dim element) pair and merges the two states.
    pos = tl.program_id(axis=0)
    for tx in tl.range(bdx):
        for head_idx in tl.range(bdy):
            # Per-(pos, head) scalar scores; layout is row-major (pos, head).
            s_a_val = tl.load(s_a_ptr + pos * num_heads + head_idx)
            s_b_val = tl.load(s_b_ptr + pos * num_heads + head_idx)

            # Flat offset into the (pos, head, dim) value tensors.
            offsets = (pos * num_heads + head_idx) * head_dim + tx
            v_a = tl.load(v_a_ptr + offsets)
            v_b = tl.load(v_b_ptr + offsets)

            # d=1 on both sides: each input state is already normalized, so its
            # implicit denominator is 1 at its own max.
            v_merged, s_max, d = state_merge(
                o=v_a, m=s_a_val, d=1, other_o=v_b, other_m=s_b_val, other_d=1
            )
            v_merged, s_max, d = state_normalize(v_merged, s_max, d)
            v_merged_offset = (pos * num_heads + head_idx) * head_dim + tx
            tl.store(v_merged_ptr + v_merged_offset, v_merged)

            # Writing the merged score is optional (null pointer skips it).
            if s_merged_ptr:
                tl.store(
                    s_merged_ptr + pos * num_heads + head_idx,
                    tl.log2(d) + s_max,
                )
|
||||
|
||||
|
||||
def merge_state_triton(
    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
):
    """Triton reference implementation of merge_state.

    Args:
        v_a: first value tensor, 3-D (seq_len, num_heads, head_dim), CUDA + contiguous.
        s_a: first score tensor, 2-D (seq_len, num_heads).
        v_b: second value tensor, same shape as ``v_a``.
        s_b: second score tensor, same shape as ``s_a``.

    Returns:
        ``(v_merged, s_merged)`` — merged values (same shape/dtype as ``v_a``)
        and merged float32 scores (seq_len, num_heads).
    """
    check_input(v_a)
    check_input(s_a)
    check_input(v_b)
    check_input(s_b)
    check_device([v_a, s_a, v_b, s_b])
    check_dim(3, v_a)
    check_dim(2, s_a)
    check_dim(3, v_b)
    check_dim(2, s_b)
    check_shape(v_a, v_b)
    check_shape(s_a, s_b)
    assert v_a.size(0) == s_a.size(0)
    assert v_a.size(1) == s_b.size(1)
    # Kernel math runs on float32 scores.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    seq_len = v_a.size(0)
    num_heads = v_a.size(1)
    head_dim = v_a.size(2)
    # Allocate outputs directly on the input device with an explicit dtype.
    # (Previously s_merged was created on the CPU with the default dtype and
    # then copied to the GPU — an unnecessary host allocation + transfer —
    # and v_merged carried a redundant no-op .to() call.)
    v_merged = torch.empty_like(v_a)
    s_merged = torch.empty((seq_len, num_heads), dtype=torch.float32, device=s_a.device)
    bdx = head_dim
    bdy = num_heads

    # One program per sequence position.
    merge_state_kernel[lambda meta: (seq_len,)](
        v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim, bdx=bdx, bdy=bdy
    )

    return v_merged, s_merged
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seq_len", [2048])
@pytest.mark.parametrize("num_heads", [32])
@pytest.mark.parametrize("head_dim", [128])
def test_merge_state(seq_len, num_heads, head_dim):
    # Cross-check the CUDA merge_state op against the Triton reference on
    # random fp16 values with fp32 scores.
    va = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sa = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    vb = torch.randn(seq_len, num_heads, head_dim).half().to("cuda:0")
    sb = torch.randn(seq_len, num_heads, dtype=torch.float32).to("cuda:0")
    v_merged, s_merged = merge_state_triton(va, sa, vb, sb)
    v_merged_std, s_merged_std = merge_state(va, sa, vb, sb)

    # Loose tolerance: fp16 values and differing accumulation order between
    # the two implementations.
    assert torch.allclose(v_merged, v_merged_std, atol=1e-2)
    assert torch.allclose(s_merged, s_merged_std, atol=1e-2)
|
||||
Reference in New Issue
Block a user