[1/2] Speed up trtllm_mla attention backend (>10% e2e) (#10473)
This commit is contained in:
@@ -104,6 +104,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("concat_mla_k(Tensor! k, Tensor k_nope, Tensor k_rope) -> ()");
|
||||
m.impl("concat_mla_k", torch::kCUDA, &concat_mla_k);
|
||||
|
||||
m.def("concat_mla_absorb_q(Tensor a, Tensor b, Tensor! out) -> ()");
|
||||
m.impl("concat_mla_absorb_q", torch::kCUDA, &concat_mla_absorb_q);
|
||||
|
||||
/*
|
||||
* From csrc/gemm
|
||||
*/
|
||||
|
||||
@@ -115,3 +115,105 @@ void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
|
||||
cudaError_t err = cudaGetLastError();
|
||||
TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
// ============================== concat_mla_absorb_q ==============================
|
||||
|
||||
// TODO give a name prefix, also maybe refactor code above
|
||||
constexpr int A_LAST_DIM = 512;
|
||||
constexpr int B_LAST_DIM = 64;
|
||||
|
||||
__global__ void concat_mla_absorb_q_kernel(
|
||||
nv_bfloat16* a,
|
||||
nv_bfloat16* b,
|
||||
nv_bfloat16* out,
|
||||
const int num_items,
|
||||
const int dim_1,
|
||||
const int a_stride_0,
|
||||
const int a_stride_1,
|
||||
const int b_stride_0,
|
||||
const int b_stride_1,
|
||||
const int out_stride_0,
|
||||
const int out_stride_1) {
|
||||
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
|
||||
const int lane_id = get_lane_id();
|
||||
|
||||
const int idx_0 = flat_warp_id / dim_1;
|
||||
const int idx_1 = flat_warp_id % dim_1;
|
||||
|
||||
if (flat_warp_id >= num_items) {
|
||||
return;
|
||||
}
|
||||
|
||||
using ABufType = int4;
|
||||
constexpr int A_NUM_UNROLL = 2;
|
||||
static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32);
|
||||
ABufType a_buf[A_NUM_UNROLL];
|
||||
|
||||
using BBufType = int;
|
||||
constexpr int B_NUM_UNROLL = 1;
|
||||
static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32);
|
||||
BBufType b_buf;
|
||||
|
||||
{
|
||||
const BBufType* base_addr = reinterpret_cast<BBufType*>(b + idx_0 * b_stride_0 + idx_1 * b_stride_1);
|
||||
b_buf = *(base_addr + lane_id);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < A_NUM_UNROLL; ++i) {
|
||||
const ABufType* base_addr = reinterpret_cast<ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
|
||||
a_buf[i] = *(base_addr + i * 32 + lane_id);
|
||||
}
|
||||
|
||||
{
|
||||
BBufType* base_addr = reinterpret_cast<BBufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM);
|
||||
*(base_addr + lane_id) = b_buf;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < A_NUM_UNROLL; ++i) {
|
||||
ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
|
||||
*(base_addr + i * 32 + lane_id) = a_buf[i];
|
||||
}
|
||||
}
|
||||
|
||||
inline void check_tensor_concat_mla_absorb_q(const at::Tensor& t, int64_t shape2) {
|
||||
TORCH_CHECK_EQ(t.dim(), 3);
|
||||
TORCH_CHECK_EQ(t.size(2), shape2);
|
||||
TORCH_CHECK_EQ(t.stride(2), 1);
|
||||
TORCH_CHECK_EQ(t.dtype(), at::kBFloat16);
|
||||
TORCH_CHECK(t.device().is_cuda());
|
||||
TORCH_CHECK_EQ(((int64_t)t.data_ptr()) % 16, 0); // alignment
|
||||
}
|
||||
|
||||
// TODO further optimize it later
|
||||
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out) {
|
||||
check_tensor_concat_mla_absorb_q(a, A_LAST_DIM);
|
||||
check_tensor_concat_mla_absorb_q(b, B_LAST_DIM);
|
||||
check_tensor_concat_mla_absorb_q(out, A_LAST_DIM + B_LAST_DIM);
|
||||
|
||||
const auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK_EQ(a.size(0) * a.size(1), b.size(0) * b.size(1));
|
||||
TORCH_CHECK_EQ(a.size(1), b.size(1));
|
||||
const int num_items = a.size(0) * a.size(1);
|
||||
|
||||
constexpr int num_warps_per_block = 32;
|
||||
const int grid_size = ceil_div(num_items, num_warps_per_block);
|
||||
const int block_size = num_warps_per_block * 32;
|
||||
|
||||
concat_mla_absorb_q_kernel<<<grid_size, block_size, 0, stream>>>(
|
||||
reinterpret_cast<nv_bfloat16*>(a.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16*>(b.data_ptr()),
|
||||
reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
|
||||
num_items,
|
||||
a.size(1),
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
out.stride(0),
|
||||
out.stride(1));
|
||||
cudaError_t err = cudaGetLastError();
|
||||
TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
|
||||
}
|
||||
|
||||
@@ -172,6 +172,7 @@ void downcast_fp8(
|
||||
|
||||
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
||||
void concat_mla_k(torch::Tensor k, torch::Tensor k_nope, torch::Tensor k_rope);
|
||||
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
void gelu_quick(at::Tensor& out, const at::Tensor& input);
|
||||
|
||||
@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
|
||||
from sgl_kernel.elementwise import (
|
||||
FusedSetKVBufferArg,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
concat_mla_absorb_q,
|
||||
concat_mla_k,
|
||||
copy_to_gpu_no_ce,
|
||||
downcast_fp8,
|
||||
|
||||
@@ -379,3 +379,15 @@ def concat_mla_k(
|
||||
k_rope: torch.Tensor,
|
||||
):
|
||||
torch.ops.sgl_kernel.concat_mla_k(k, k_nope, k_rope)
|
||||
|
||||
|
||||
def concat_mla_absorb_q(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
):
|
||||
*batch_dims, _ = a.shape
|
||||
out = torch.empty(
|
||||
(*batch_dims, a.shape[-1] + b.shape[-1]), device=a.device, dtype=a.dtype
|
||||
)
|
||||
torch.ops.sgl_kernel.concat_mla_absorb_q(a, b, out)
|
||||
return out
|
||||
|
||||
@@ -67,11 +67,8 @@ ALL_MODELS = [
|
||||
ModelCase("openai-community/gpt2"),
|
||||
ModelCase("microsoft/phi-1_5", trust_remote_code=True),
|
||||
ModelCase("adept/persimmon-8b-chat"),
|
||||
|
||||
ModelCase("upstage/SOLAR-10.7B-Instruct-v1.0"),
|
||||
|
||||
ModelCase("inclusionAI/Ling-lite", trust_remote_code=True),
|
||||
|
||||
ModelCase("microsoft/Phi-3-small-8k-instruct", trust_remote_code=True),
|
||||
ModelCase("allenai/OLMo-2-1124-7B-Instruct", skip_long_prompt=True),
|
||||
ModelCase("ibm-granite/granite-3.0-2b-instruct", skip_long_prompt=True),
|
||||
|
||||
Reference in New Issue
Block a user