[perf][sgl-kernel] extend cutlass_mla_decode to support num_head < 128 (#6929)
This commit is contained in:
@@ -57,6 +57,7 @@ def cutlass_mla_decode(
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
num_kv_splits: int = -1,
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
q_nope_and_q_pe.ndim == 3
|
||||
@@ -73,7 +74,12 @@ def cutlass_mla_decode(
|
||||
f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
|
||||
f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
|
||||
)
|
||||
assert H == 128, f"H must be 128, but got {H}"
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
if H < MAX_HEADS:
|
||||
q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
|
||||
q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
|
||||
q_nope_and_q_pe = q_nope_and_q_pe_padded
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
@@ -97,21 +103,25 @@ def cutlass_mla_decode(
|
||||
page_table.dtype == torch.int32
|
||||
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
|
||||
out = torch.empty(
|
||||
(B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
|
||||
)
|
||||
out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
|
||||
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
||||
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
|
||||
out,
|
||||
q_nope_and_q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
workspace,
|
||||
num_kv_splits,
|
||||
)
|
||||
return out
|
||||
return out[:, :H].contiguous()
|
||||
|
||||
|
||||
def cutlass_mla_get_workspace_size(
|
||||
max_seq_len: int, num_batches: int, sm_count: int = 0
|
||||
max_seq_len: int, num_batches: int, sm_count: int = 0, num_kv_splits: int = -1
|
||||
) -> int:
|
||||
assert max_seq_len > 0, f"max_seq_len must be greater than 0, got {max_seq_len}"
|
||||
assert num_batches > 0, f"num_batches must be greater than 0, got {num_batches}"
|
||||
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
|
||||
max_seq_len, num_batches, sm_count
|
||||
max_seq_len, num_batches, sm_count, num_kv_splits
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user