[fix] fix cutlass_mla_backend with cuda_graph and add sm_scale for sgl-kernel cutlass_mla (#7184)
This commit is contained in:
@@ -58,7 +58,8 @@ def cutlass_mla_decode(
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
-num_kv_splits: int = -1,
|
||||
+sm_scale: float,
|
||||
+num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> torch.Tensor:
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
@@ -118,13 +119,17 @@ def cutlass_mla_decode(
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
+sm_scale,
|
||||
num_kv_splits,
|
||||
)
|
||||
return out[:, :H].contiguous()
|
||||
|
||||
|
||||
def cutlass_mla_get_workspace_size(
|
||||
-max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
|
||||
+max_seq_len: int,
|
||||
+num_batches: int,
|
||||
+sm_count: int = 0,
|
||||
+num_kv_splits: int = 1, # Set to 1 to avoid cuda_graph issue by default.
|
||||
) -> int:
|
||||
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
|
||||
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
|
||||
|
||||
Reference in New Issue
Block a user