Fuse quantize and rope in trtllm_mla MTP (#10779)
This commit is contained in:
@@ -568,12 +568,35 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
save_kv_cache: bool = True,
|
||||
q_rope: Optional[torch.Tensor] = None,
|
||||
k_rope: Optional[torch.Tensor] = None,
|
||||
cos_sin_cache: Optional[torch.Tensor] = None,
|
||||
is_neox: Optional[bool] = False,
|
||||
) -> torch.Tensor:
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
return super().forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||
)
|
||||
|
||||
# TODO refactor to avoid code duplication
|
||||
merge_query = q_rope is not None
|
||||
if (
|
||||
self.data_type == torch.float8_e4m3fn
|
||||
) and forward_batch.forward_mode.is_target_verify():
|
||||
# For FP8 path, we quantize the query and rope parts and merge them into a single tensor
|
||||
# Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
|
||||
assert all(
|
||||
x is not None for x in [q_rope, k_rope, cos_sin_cache]
|
||||
), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
|
||||
q, k, k_rope = self.quantize_and_rope_for_fp8(
|
||||
q,
|
||||
q_rope,
|
||||
k.squeeze(1),
|
||||
k_rope.squeeze(1),
|
||||
forward_batch,
|
||||
cos_sin_cache,
|
||||
is_neox,
|
||||
)
|
||||
merge_query = False
|
||||
|
||||
# Save KV cache if requested
|
||||
if save_kv_cache:
|
||||
assert (
|
||||
@@ -583,12 +606,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
layer, forward_batch.out_cache_loc, k, k_rope
|
||||
)
|
||||
|
||||
if q_rope is not None:
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope = q_rope.view(
|
||||
# TODO refactor to avoid code duplication
|
||||
# Prepare query tensor inline
|
||||
if merge_query:
|
||||
# For FP16 path, we merge the query and rope parts into a single tensor
|
||||
q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||
q_rope_reshaped = q_rope.view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||
)
|
||||
q = _concat_mla_absorb_q_general(q, q_rope)
|
||||
q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
|
||||
else:
|
||||
# For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
|
||||
|
||||
@@ -1399,7 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
"""
|
||||
return (
|
||||
self.current_attention_backend == "trtllm_mla"
|
||||
and forward_batch.forward_mode.is_decode_or_idle()
|
||||
and (
|
||||
forward_batch.forward_mode.is_decode_or_idle()
|
||||
or forward_batch.forward_mode.is_target_verify()
|
||||
)
|
||||
and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user