[1/2] Speed up trtllm_mla attention backend (>10% e2e) (#10473)

This commit is contained in:
fzyzcjy
2025-09-16 02:53:21 +08:00
committed by GitHub
parent 5c08d7d21d
commit 3b25dc127a
6 changed files with 119 additions and 3 deletions

View File

@@ -115,3 +115,105 @@ void concat_mla_k(at::Tensor k, at::Tensor k_nope, at::Tensor k_rope) {
cudaError_t err = cudaGetLastError();
TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}
// ============================== concat_mla_absorb_q ==============================
// TODO give a name prefix, also maybe refactor code above
// Last-dim widths (in bf16 elements) of the two inputs to concat_mla_absorb_q:
// `a` rows hold A_LAST_DIM values, `b` rows hold B_LAST_DIM, and the output row
// is their concatenation (A_LAST_DIM + B_LAST_DIM elements).
constexpr int A_LAST_DIM = 512;
constexpr int B_LAST_DIM = 64;
// Concatenates the last dims of `a` and `b` into `out`:
//   out[i0, i1, :] = cat(a[i0, i1, :A_LAST_DIM], b[i0, i1, :B_LAST_DIM])
// One warp handles one (i0, i1) row; lanes cooperate on vectorized copies
// (16B int4 loads for `a`, 4B int loads for `b`).
// Preconditions (validated by the host wrapper): bf16 dtype, last dim
// contiguous, base pointers 16-byte aligned. Launch: 1-D grid, blockDim.x a
// multiple of 32, at least num_items warps in total.
__global__ void concat_mla_absorb_q_kernel(
    const nv_bfloat16* __restrict__ a,
    const nv_bfloat16* __restrict__ b,
    nv_bfloat16* __restrict__ out,
    const int num_items,
    const int dim_1,
    const int a_stride_0,
    const int a_stride_1,
    const int b_stride_0,
    const int b_stride_1,
    const int out_stride_0,
    const int out_stride_1) {
  // flat_warp_id enumerates (idx_0, idx_1) row pairs; one warp per row.
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
  const int lane_id = get_lane_id();
  const int idx_0 = flat_warp_id / dim_1;
  const int idx_1 = flat_warp_id % dim_1;
  if (flat_warp_id >= num_items) {
    return;
  }

  // Per-lane share of `a`: 2 x 16B = 32 bf16 values (512 / 32 lanes).
  using ABufType = int4;
  constexpr int A_NUM_UNROLL = 2;
  static_assert(sizeof(ABufType) * A_NUM_UNROLL == A_LAST_DIM * sizeof(a[0]) / 32);
  ABufType a_buf[A_NUM_UNROLL];

  // Per-lane share of `b`: 1 x 4B = 2 bf16 values (64 / 32 lanes).
  using BBufType = int;
  constexpr int B_NUM_UNROLL = 1;
  static_assert(sizeof(BBufType) * B_NUM_UNROLL == B_LAST_DIM * sizeof(b[0]) / 32);
  BBufType b_buf;

  // Stage both inputs into registers before storing, so the loads can issue
  // back-to-back without waiting on the stores.
  {
    const BBufType* base_addr = reinterpret_cast<const BBufType*>(b + idx_0 * b_stride_0 + idx_1 * b_stride_1);
    b_buf = *(base_addr + lane_id);
  }
#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    const ABufType* base_addr = reinterpret_cast<const ABufType*>(a + idx_0 * a_stride_0 + idx_1 * a_stride_1);
    a_buf[i] = *(base_addr + i * 32 + lane_id);
  }

  // `b` lands in the tail of the output row, after the A_LAST_DIM `a` values.
  {
    BBufType* base_addr = reinterpret_cast<BBufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1 + A_LAST_DIM);
    *(base_addr + lane_id) = b_buf;
  }
#pragma unroll
  for (int i = 0; i < A_NUM_UNROLL; ++i) {
    ABufType* base_addr = reinterpret_cast<ABufType*>(out + idx_0 * out_stride_0 + idx_1 * out_stride_1);
    *(base_addr + i * 32 + lane_id) = a_buf[i];
  }
}
// Validates one operand of concat_mla_absorb_q: a 3-D bf16 CUDA tensor whose
// last dim has exactly `shape2` contiguous elements and whose base pointer is
// 16-byte aligned (required by the kernel's int4 vectorized accesses).
inline void check_tensor_concat_mla_absorb_q(const at::Tensor& t, int64_t shape2) {
  TORCH_CHECK_EQ(t.dim(), 3);
  TORCH_CHECK_EQ(t.size(2), shape2);
  // The last dim must be densely packed for the warp-wide vector copies.
  TORCH_CHECK_EQ(t.stride(2), 1);
  TORCH_CHECK_EQ(t.dtype(), at::kBFloat16);
  TORCH_CHECK(t.device().is_cuda());
  const int64_t base_addr = reinterpret_cast<int64_t>(t.data_ptr());
  TORCH_CHECK_EQ(base_addr % 16, 0);  // 16B alignment for int4 loads/stores
}
// TODO further optimize it later
// Fused equivalent of torch.cat((a, b), dim=-1) for MLA absorbed-q:
//   a:   [s0, s1, 512] bf16
//   b:   [s0, s1, 64]  bf16
//   out: [s0, s1, 576] bf16 (written in-place)
// All three tensors must satisfy check_tensor_concat_mla_absorb_q (bf16,
// CUDA, contiguous last dim, 16B-aligned base pointer). Launches on the
// current CUDA stream; asynchronous with respect to the host.
void concat_mla_absorb_q(at::Tensor a, at::Tensor b, at::Tensor out) {
  check_tensor_concat_mla_absorb_q(a, A_LAST_DIM);
  check_tensor_concat_mla_absorb_q(b, B_LAST_DIM);
  check_tensor_concat_mla_absorb_q(out, A_LAST_DIM + B_LAST_DIM);

  // All three tensors must agree on the two leading dims; a mis-shaped `out`
  // would otherwise pass the per-tensor checks and be written out of bounds.
  TORCH_CHECK_EQ(a.size(0), b.size(0));
  TORCH_CHECK_EQ(a.size(1), b.size(1));
  TORCH_CHECK_EQ(out.size(0), a.size(0));
  TORCH_CHECK_EQ(out.size(1), a.size(1));

  const auto stream = at::cuda::getCurrentCUDAStream().stream();

  // The kernel indexes with 32-bit ints; reject inputs whose row count would
  // silently overflow the narrowing conversion below.
  const int64_t num_items_64 = a.size(0) * a.size(1);
  TORCH_CHECK(
      num_items_64 == (int64_t)(int)num_items_64, "concat_mla_absorb_q: too many rows for int32 indexing: ", num_items_64);
  const int num_items = (int)num_items_64;

  // One warp per row; 32 warps (1024 threads) per block.
  constexpr int num_warps_per_block = 32;
  const int grid_size = ceil_div(num_items, num_warps_per_block);
  const int block_size = num_warps_per_block * 32;

  concat_mla_absorb_q_kernel<<<grid_size, block_size, 0, stream>>>(
      reinterpret_cast<nv_bfloat16*>(a.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(b.data_ptr()),
      reinterpret_cast<nv_bfloat16*>(out.data_ptr()),
      num_items,
      a.size(1),
      a.stride(0),
      a.stride(1),
      b.stride(0),
      b.stride(1),
      out.stride(0),
      out.stride(1));

  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel launch failed: ", cudaGetErrorString(err));
}