[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)

This commit is contained in:
JieXin Liang
2025-06-15 03:45:41 +08:00
committed by GitHub
parent ed54bf9d19
commit ab1a4fa5cb
7 changed files with 29 additions and 17 deletions

View File

@@ -58,7 +58,8 @@ def cutlass_mla_decode(
seq_lens: torch.Tensor,
page_table: torch.Tensor,
workspace: torch.Tensor,
num_kv_splits: int = -1,
sm_scale: float,
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> torch.Tensor:
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
@@ -118,13 +119,17 @@ def cutlass_mla_decode(
seq_lens,
page_table,
workspace,
sm_scale,
num_kv_splits,
)
return out[:, :H].contiguous()
def cutlass_mla_get_workspace_size(
max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
max_seq_len: int,
num_batches: int,
sm_count: int = 0,
num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
) -> int:
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"