[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

This commit is contained in:
JieXin Liang
2025-06-15 03:45:41 +08:00
committed by GitHub
parent ed54bf9d19
commit ab1a4fa5cb
7 changed files with 29 additions and 17 deletions

8
sgl-kernel/include/sgl_kernel_ops.h Executable file → Normal file
View File

@@ -111,9 +111,13 @@ void cutlass_mla_decode(
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
int64_t num_kv_splits = -1);
double sm_scale,
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
int64_t cutlass_mla_get_workspace_size(
int64_t max_seq_len, int64_t num_batches, int64_t sm_count = 0, int64_t num_kv_splits = -1);
int64_t max_seq_len,
int64_t num_batches,
int64_t sm_count = 0,
int64_t num_kv_splits = 1 /* Set to 1 to avoid cuda_graph issue by default. */);
/*
* From csrc/elementwise
*/