[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)

This commit is contained in:
JieXin Liang
2025-06-09 10:37:34 +08:00
committed by GitHub
parent de1350ea20
commit 18efb5e8e0
10 changed files with 2959 additions and 37 deletions

View File

@@ -22,8 +22,9 @@ limitations under the License.
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <device/sm100_mla.hpp>
#include <kernel/sm100_mla_tile_scheduler.hpp>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
@@ -55,7 +56,7 @@ struct IsPersistent {
static const bool value = v;
};
template <typename T, typename PersistenceOption = IsPersistent<true>>
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
@@ -83,7 +84,7 @@ struct MlaSm100 {
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/true>;
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
@@ -93,7 +94,8 @@ typename T::Fmha::Arguments args_from_options(
at::Tensor const& q_nope_and_q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table) {
at::Tensor const& page_table,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope_and_q_pe.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
@@ -154,8 +156,8 @@ typename T::Fmha::Arguments args_from_options(
// NOTE: split_kv is taken from the caller-supplied num_kv_splits.
// TODO(trevor-m): Prefer split_kv=-1 once
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
1, // split_kv
nullptr, // is_var_split_kv
num_kv_splits, // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
@@ -165,7 +167,7 @@ typename T::Fmha::Arguments args_from_options(
return arguments;
}
template <typename Element>
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
at::Tensor const& out,
at::Tensor const& q_nope_and_q_pe,
@@ -173,10 +175,11 @@ void runMla(
at::Tensor const& seq_lens,
at::Tensor const& page_table,
at::Tensor const& workspace,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element>;
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table);
auto arguments = args_from_options<MlaSm100Type>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -185,31 +188,57 @@ void runMla(
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
}
// DISPATCH_BOOL(cond, NAME, fn): lifts the runtime boolean `cond` into a
// compile-time constant. `fn` (a callable taking no arguments) is invoked in a
// scope where `constexpr bool NAME` holds the truth value of `cond`, so NAME
// can be used as a template argument inside `fn`. The expansion is an
// immediately-invoked lambda declared `-> bool`, so `fn` must return a value
// convertible to bool.
#define DISPATCH_BOOL(cond, NAME, ...) \
  [&]() -> bool {                      \
    if (cond) {                        \
      constexpr bool NAME = true;      \
      return __VA_ARGS__();            \
    }                                  \
    constexpr bool NAME = false;       \
    return __VA_ARGS__();              \
  }()
// Host entry point for MLA decode. Constructs a CUDAGuard for the device of
// q_nope_and_q_pe, fetches that device's current CUDA stream, and forwards the
// tensors to the runMla instantiation matching the dtype of q_nope_and_q_pe
// (Half, BFloat16 or Float8_e4m3fn). Any other dtype fails a TORCH_CHECK with
// "Unsupported input data type of MLA".
//
// NOTE(review): this span shows what look like removed and added diff lines
// side by side (two endings of the parameter list and two dtype dispatch
// chains) -- confirm which lines are live before building.
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace) {
// NOTE(review): the two lines below repeat the end of the parameter list with
// an extra num_kv_splits argument -- presumably the post-change signature.
torch::Tensor const& workspace,
int64_t num_kv_splits) {
// The dtype of the query tensor alone selects the kernel element type below.
auto in_dtype = q_nope_and_q_pe.dtype();
// NOTE(review): the device index is narrowed to char before constructing the
// guard -- presumably to match the guard's index type; confirm.
at::cuda::CUDAGuard device_guard{(char)q_nope_and_q_pe.get_device()};
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope_and_q_pe.get_device());
// NOTE(review): this dtype chain calls runMla with a single template argument
// and without num_kv_splits -- presumably the pre-change dispatch; verify.
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t>(out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
// Dimension 1 of the KV cache tensor is read as the page size; the int64_t
// size is narrowed to int here.
// assumes kv_c_and_k_pe_cache has at least two dimensions -- TODO confirm
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
// NOTE(alcanderian): IsPersistent has a bug with manual split_kv: the kernel
// hangs when the batch is too large and num_kv_splits is large (for example
// bs=8, num_kv_splits=8). Per-batch split kv might fix this.
// IsPaged128 becomes a compile-time true only when page_size == 128.
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
// NotManualSplitKV is true when num_kv_splits <= 1; it is passed straight to
// IsPersistent<>, so IsPersistent<true> is only used without a manual split.
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}
// DISPATCH_BOOL wraps its callable in a lambda declared `-> bool`, so each
// callable has to return a value; the result is not used here.
return true;
});
return true;
});
}
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t>;
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
@@ -220,6 +249,7 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
arguments.split_kv = num_kv_splits;
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);