[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)
@@ -108,7 +108,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 PAGE_SIZE,
             )
             workspace_size = cutlass_mla_get_workspace_size(
-                max_seqlen_pad * PAGE_SIZE, bs
+                max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
             )
             workspace = torch.empty(
                 workspace_size, device="cuda", dtype=torch.uint8
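The scratch-buffer requirement generally depends on how many KV splits the decode kernel uses, so the sizing call has to pass the same num_kv_splits that the kernel launch will use later. A minimal sketch of this allocation pattern, assuming the cutlass_mla_get_workspace_size import this backend uses from sgl-kernel; the sizes are hypothetical:

import torch
from sgl_kernel import cutlass_mla_get_workspace_size

PAGE_SIZE = 128             # hypothetical page size
bs, max_seqlen_pad = 8, 64  # hypothetical batch size and padded page count

# Size the workspace for a fixed num_kv_splits=1; the decode kernel that
# consumes this buffer must be launched with the same num_kv_splits.
workspace_size = cutlass_mla_get_workspace_size(
    max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
)
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)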
@@ -138,7 +138,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             cuda_graph_kv_indices = block_kv_indices
 
         workspace_size = cutlass_mla_get_workspace_size(
-            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs
+            cuda_graph_kv_indices.shape[1] * PAGE_SIZE, max_bs, num_kv_splits=1
         )
         self.cuda_graph_mla_workspace = torch.empty(
             workspace_size, device="cuda", dtype=torch.uint8
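The same pinning matters in the CUDA-graph path: cuda_graph_mla_workspace is allocated once at capture time, sized for max_bs and the widest page table, and then reused on every replay. Fixing num_kv_splits=1 here keeps that buffer's size consistent with the num_kv_splits the captured kernel launch expects, which is what makes this backend safe to replay under CUDA graphs.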
@@ -280,6 +280,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
             workspace=self.forward_metadata.workspace,
+            num_kv_splits=1,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
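For context, a sketch of the full decode call site after this change, assuming sgl-kernel's cutlass_mla_decode entry point, which this backend wraps. q_all and kv_cache are hypothetical stand-ins for the backend's packed query tensor and paged latent KV cache, and the sm_scale argument mentioned in the commit title lands on the sgl-kernel side, so it is omitted here:

from sgl_kernel import cutlass_mla_decode

# q_all / kv_cache are hypothetical names; the kwargs below match the hunk above.
o = cutlass_mla_decode(
    q_all,
    kv_cache,
    seq_lens=forward_batch.seq_lens.to(torch.int32),
    page_table=self.forward_metadata.block_kv_indices,
    workspace=self.forward_metadata.workspace,
    num_kv_splits=1,  # must match the num_kv_splits used to size the workspace
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)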