Fix trtllm_mla slow concat kernel in MTP (#10777)
This commit is contained in:
@@ -505,10 +505,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
q_rope_reshaped = q_rope.view(
|
q_rope_reshaped = q_rope.view(
|
||||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
)
|
)
|
||||||
if _is_cuda and q_nope.shape[-1] == 512 and q_rope_reshaped.shape[-1] == 64:
|
query = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
|
||||||
query = concat_mla_absorb_q(q_nope, q_rope_reshaped)
|
|
||||||
else:
|
|
||||||
query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
|
|
||||||
else:
|
else:
|
||||||
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||||
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
@@ -591,7 +588,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
q_rope = q_rope.view(
|
q_rope = q_rope.view(
|
||||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
)
|
)
|
||||||
q = torch.cat([q, q_rope], dim=-1)
|
q = _concat_mla_absorb_q_general(q, q_rope)
|
||||||
|
|
||||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
@@ -716,3 +713,10 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
|||||||
kv_indptr_buf=self.kv_indptr[i],
|
kv_indptr_buf=self.kv_indptr[i],
|
||||||
q_indptr_decode_buf=self.q_indptr_decode,
|
q_indptr_decode_buf=self.q_indptr_decode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_mla_absorb_q_general(q_nope, q_rope):
    """Concatenate ``q_nope`` and ``q_rope`` along their last dimension.

    The dedicated ``concat_mla_absorb_q`` routine is used only when
    ``_is_cuda`` is truthy and the last dimensions are exactly 512
    (``q_nope``) and 64 (``q_rope``). Every other case falls back to
    ``torch.cat`` on the last dimension.

    Args:
        q_nope: First query part; only ``shape[-1]`` is inspected here.
        q_rope: Second query part; only ``shape[-1]`` is inspected here.

    Returns:
        Whatever the selected concat routine returns for the two inputs.
    """
    # NOTE(review): presumably the specialized routine is a faster kernel for
    # the 512 + 64 layout -- confirm against its definition.
    use_specialized_concat = (
        _is_cuda and q_nope.shape[-1] == 512 and q_rope.shape[-1] == 64
    )
    if not use_specialized_concat:
        return torch.cat([q_nope, q_rope], dim=-1)
    return concat_mla_absorb_q(q_nope, q_rope)
|
||||||
|
|||||||
Reference in New Issue
Block a user