[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

This commit is contained in:
JieXin Liang
2025-06-15 03:45:41 +08:00
committed by GitHub
parent ed54bf9d19
commit ab1a4fa5cb
7 changed files with 29 additions and 17 deletions

View File

@@ -108,7 +108,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
PAGE_SIZE,
)
workspace_size = cutlass_mla_get_workspace_size(
-    max_seqlen_pad * PAGE_SIZE, bs
+    max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
)
workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
@@ -138,7 +138,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
cuda_graph_kv_indices = block_kv_indices
workspace_size = cutlass_mla_get_workspace_size(
-    cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+    cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
)
self.cuda_graph_mla_workspace = torch.empty(
workspace_size, device="cuda", dtype=torch.uint8
@@ -280,6 +280,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens=forward_batch.seq_lens.to(torch.int32),
page_table=self.forward_metadata.block_kv_indices,
workspace=self.forward_metadata.workspace,
+    num_kv_splits=1,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)