Remove 200us slow concat kernel (part 1: kernel) (#7145)
This commit is contained in:
@@ -52,34 +52,42 @@ def merge_state_v2(
|
||||
|
||||
|
||||
def cutlass_mla_decode(
|
||||
q_nope_and_q_pe: torch.Tensor,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_table: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
num_kv_splits: int = -1,
|
||||
) -> torch.Tensor:
|
||||
assert (
|
||||
q_nope_and_q_pe.ndim == 3
|
||||
), f"q_nope_and_q_pe must be a 3D tensor, but got {q_nope_and_q_pe.ndim}"
|
||||
assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
|
||||
assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
|
||||
assert (
|
||||
kv_c_and_k_pe_cache.ndim == 3
|
||||
), f"kv_c_and_k_pe_cache must be a 3D tensor, but got {kv_c_and_k_pe_cache.ndim}"
|
||||
B_q, H, D_q = q_nope_and_q_pe.shape
|
||||
|
||||
B_q, H, D_q_nope = q_nope.shape
|
||||
B_q_2, H_2, D_q_pe = q_pe.shape
|
||||
assert (B_q == B_q_2) and (H == H_2)
|
||||
|
||||
_, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape
|
||||
|
||||
D_latent = 512
|
||||
D_rope = 64
|
||||
assert D_q == D_ckv and D_q == D_latent + D_rope, (
|
||||
f"D_q must be equal to D_ckv and D_q must be equal to D_latent + D_rope, "
|
||||
f"but got D_q = {D_q}, D_ckv = {D_ckv}, D_latent = {D_latent}, D_rope = {D_rope}"
|
||||
)
|
||||
assert D_q_nope == D_latent
|
||||
assert D_q_pe == D_rope
|
||||
assert D_ckv == D_latent + D_rope
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
if H < MAX_HEADS:
|
||||
q_nope_and_q_pe_padded = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_q))
|
||||
q_nope_and_q_pe_padded[:, :H] = q_nope_and_q_pe
|
||||
q_nope_and_q_pe = q_nope_and_q_pe_padded
|
||||
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||
q_nope_padded[:, :H] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||
q_pe_padded[:, :H] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
@@ -88,14 +96,11 @@ def cutlass_mla_decode(
|
||||
assert block_num % (128 / PAGE_SIZE) == 0
|
||||
|
||||
# TODO(kaixih@nvidia): support fp8
|
||||
assert q_nope_and_q_pe.dtype in (
|
||||
assert q_nope.dtype in (
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
), f"q_nope_and_q_pe.dtype needs to be fp16 or bf16 but got {q_nope_and_q_pe.dtype}."
|
||||
assert kv_c_and_k_pe_cache.dtype == q_nope_and_q_pe.dtype, (
|
||||
f"kv_c_and_k_pe_cache.dtype needs to be the same as q_nope_and_q_pe.dtype, "
|
||||
f"but got {kv_c_and_k_pe_cache.dtype}."
|
||||
)
|
||||
), f"q_nope.dtype needs to be fp16 or bf16 but got {q_nope.dtype}."
|
||||
assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
|
||||
assert (
|
||||
seq_lens.dtype == torch.int32
|
||||
), f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
|
||||
@@ -103,11 +108,12 @@ def cutlass_mla_decode(
|
||||
page_table.dtype == torch.int32
|
||||
), f"page_table.dtype needs to be int32 but got {page_table.dtype}."
|
||||
|
||||
out = q_nope_and_q_pe.new_empty((B_q, MAX_HEADS, D_latent))
|
||||
out = q_nope.new_empty((B_q, MAX_HEADS, D_latent))
|
||||
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
||||
out,
|
||||
q_nope_and_q_pe,
|
||||
q_nope,
|
||||
q_pe,
|
||||
kv_c_and_k_pe_cache,
|
||||
seq_lens,
|
||||
page_table,
|
||||
|
||||
Reference in New Issue
Block a user