Add Cutlass MLA attention backend (#5390)

This commit is contained in:
Trevor Morris
2025-04-27 20:58:53 -07:00
committed by GitHub
parent 40d9b8acce
commit 84810da4ae
7 changed files with 305 additions and 3 deletions

View File

@@ -78,6 +78,7 @@ def cutlass_mla_decode(
assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape
assert B_block_table == B_q
assert block_num > 0, f"block num must be greater than 0, got {block_num}"
assert block_num % (128 / PAGE_SIZE) == 0
# TODO(kaixih@nvidia): support fp8
@@ -109,6 +110,8 @@ def cutlass_mla_decode(
def cutlass_mla_get_workspace_size(
    max_seq_len: int, num_batches: int, sm_count: int = 0
) -> int:
    """Query the workspace size needed by the CUTLASS MLA decode kernel.

    Args:
        max_seq_len: Maximum sequence length across the batch; must be > 0.
        num_batches: Number of batches; must be > 0.
        sm_count: SM count hint forwarded to the kernel (0 = let the
            kernel decide; semantics defined by the native op).

    Returns:
        The workspace size reported by the ``sgl_kernel`` native op
        (presumably in bytes — defined by the kernel implementation).
    """
    # Validate the sizing inputs up front so a bad value fails here with a
    # clear message instead of inside the native kernel.
    for arg_name, arg_value in (
        ("max_seq_len", max_seq_len),
        ("num_batches", num_batches),
    ):
        assert arg_value > 0, f"{arg_name} must be greater than 0, got {arg_value}"
    return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
        max_seq_len, num_batches, sm_count
    )