Remove 200us slow concat kernel (part 1: kernel) (#7145)

This commit is contained in:
fzyzcjy
2025-06-13 16:58:29 +08:00
committed by GitHub
parent 2f4ec752bc
commit aa46ed34d2
6 changed files with 79 additions and 48 deletions

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <iostream>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
@@ -30,7 +31,8 @@ limitations under the License.
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
@@ -91,16 +93,17 @@ struct MlaSm100 {
template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out,
at::Tensor const& q_nope_and_q_pe,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope_and_q_pe.device().index();
hw_info.device_id = q_nope.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
int batches = q_nope_and_q_pe.sizes()[0];
int batches = q_nope.sizes()[0];
int page_count_per_seq = page_table.sizes()[1];
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
int page_size = kv_c_and_k_pe_cache.sizes()[1];
@@ -122,8 +125,11 @@ typename T::Fmha::Arguments args_from_options(
using StrideO = typename T::StrideO;
using StrideLSE = typename T::StrideLSE;
StrideQ stride_Q = cute::make_tuple(
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(H * (0 + D_latent + D_rope)));
StrideQ stride_Q_nope = cute::make_tuple(
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
StrideQ stride_Q_pe = cute::make_tuple(
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
StrideK stride_C = cute::make_tuple(
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
@@ -133,15 +139,16 @@ typename T::Fmha::Arguments args_from_options(
using Element = typename T::Element;
using ElementOut = typename T::ElementOut;
using ElementAcc = typename T::ElementAcc;
auto Q_ptr = static_cast<Element*>(q_nope_and_q_pe.data_ptr());
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
typename T::Fmha::Arguments arguments{
problem_shape,
{scale,
Q_ptr,
stride_Q,
Q_ptr + D_latent,
stride_Q,
Q_nope_ptr,
stride_Q_nope,
Q_pe_ptr,
stride_Q_pe,
C_ptr,
stride_C,
C_ptr + D_latent,
@@ -170,7 +177,8 @@ typename T::Fmha::Arguments args_from_options(
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope_and_q_pe,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
@@ -179,7 +187,7 @@ void runMla(
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -201,15 +209,16 @@ void runMla(
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
int64_t num_kv_splits) {
auto in_dtype = q_nope_and_q_pe.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
@@ -219,13 +228,13 @@ void cutlass_mla_decode(
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}

View File

@@ -59,7 +59,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);