[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

This commit is contained in:
JieXin Liang
2025-06-15 03:45:41 +08:00
committed by GitHub
parent ed54bf9d19
commit ab1a4fa5cb
7 changed files with 29 additions and 17 deletions

View File

@@ -36,7 +36,8 @@ void cutlass_mla_decode(
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace) {
torch::Tensor const& workspace,
int64_t num_kv_splits) {
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
@@ -98,6 +99,7 @@ typename T::Fmha::Arguments args_from_options(
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
double sm_scale,
int64_t num_kv_splits) {
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope.device().index();
@@ -115,10 +117,7 @@ typename T::Fmha::Arguments args_from_options(
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
// the scale is based on the non-absorbed sizes, change as appropriate
// we can't determine this parameter from the info we have, it's an input
int D_non_latent = 128;
float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));
float scale = float(sm_scale);
using StrideQ = typename T::StrideQ;
using StrideK = typename T::StrideK;
@@ -183,11 +182,12 @@ void runMla(
at::Tensor const& seq_lens,
at::Tensor const& page_table,
at::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits,
cudaStream_t stream) {
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
typename MlaSm100Type::Fmha fmha;
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, num_kv_splits);
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
CUTLASS_CHECK(fmha.can_implement(arguments));
@@ -215,6 +215,7 @@ void cutlass_mla_decode(
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
double sm_scale,
int64_t num_kv_splits) {
auto in_dtype = q_nope.dtype();
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
@@ -228,13 +229,13 @@ void cutlass_mla_decode(
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
if (in_dtype == at::ScalarType::Half) {
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::BFloat16) {
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, num_kv_splits, stream);
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
} else {
TORCH_CHECK(false, "Unsupported input data type of MLA");
}

View File

@@ -60,7 +60,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
m.def(
"cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
"page_table, Tensor! workspace, int num_kv_splits) -> ()");
"page_table, Tensor! workspace, float sm_scale, int num_kv_splits) -> ()");
m.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size);