Fuse quantize and rope in trtllm_mla MTP (#10779)
@@ -568,12 +568,35 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        cos_sin_cache: Optional[torch.Tensor] = None,
+        is_neox: Optional[bool] = False,
     ) -> torch.Tensor:
         if forward_batch.forward_mode.is_draft_extend():
             return super().forward_extend(
                 q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
             )
 
+        # TODO refactor to avoid code duplication
+        merge_query = q_rope is not None
+        if (
+            self.data_type == torch.float8_e4m3fn
+        ) and forward_batch.forward_mode.is_target_verify():
+            # For FP8 path, we quantize the query and rope parts and merge them into a single tensor
+            # Note: rope application in deepseek_v2.py:forward_absorb_prepare is skipped for FP8 decode path of this trtllm_mla backend
+            assert all(
+                x is not None for x in [q_rope, k_rope, cos_sin_cache]
+            ), "For FP8 path and using flashinfer.rope.mla_rope_quantize we need all of q_rope, k_rope and cos_sin_cache to be not None."
+            q, k, k_rope = self.quantize_and_rope_for_fp8(
+                q,
+                q_rope,
+                k.squeeze(1),
+                k_rope.squeeze(1),
+                forward_batch,
+                cos_sin_cache,
+                is_neox,
+            )
+            merge_query = False
+
         # Save KV cache if requested
         if save_kv_cache:
             assert (
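Together with the hunk below, the new cos_sin_cache / is_neox arguments let the backend fold rope application into the FP8 quantization step for MTP target-verify batches, instead of requiring deepseek_v2.py to rotate q/k beforehand. As a reference for what that fused step computes, here is a plain-PyTorch sketch. The backend itself delegates to a fused FlashInfer kernel (see the assert message above); the function name, the explicit positions argument (the real helper receives forward_batch), the [cos | sin] layout of cos_sin_cache, and the omission of FP8 scale handling are all assumptions of this sketch, not the backend's implementation.

```python
import torch


def quantize_and_rope_for_fp8_reference(
    q_nope: torch.Tensor,         # [tokens, heads, nope_dim]
    q_rope: torch.Tensor,         # [tokens, heads, rope_dim]
    k_nope: torch.Tensor,         # [tokens, nope_dim]
    k_rope: torch.Tensor,         # [tokens, rope_dim]
    positions: torch.Tensor,      # [tokens], int64 token positions
    cos_sin_cache: torch.Tensor,  # [max_pos, rope_dim], cos half then sin half
    is_neox: bool = False,
):
    rope_dim = q_rope.shape[-1]
    cos, sin = cos_sin_cache[positions].split(rope_dim // 2, dim=-1)

    def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # Broadcast cos/sin over the head dimension when x has one.
        while cos.dim() < x.dim():
            cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
        if is_neox:
            x1, x2 = x.chunk(2, dim=-1)  # rotate-half (NeoX) layout
            return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
        x1, x2 = x[..., 0::2], x[..., 1::2]  # interleaved (GPT-J) layout
        return torch.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).flatten(-2)

    q_rope = apply_rope(q_rope, cos, sin)
    k_rope = apply_rope(k_rope, cos, sin)

    # Merge the NOPE and rope query parts and cast to FP8 so the TRT-LLM MLA
    # kernel consumes a single quantized query tensor (scale handling omitted).
    q = torch.cat([q_nope, q_rope], dim=-1).to(torch.float8_e4m3fn)
    return q, k_nope.to(torch.float8_e4m3fn), k_rope.to(torch.float8_e4m3fn)
```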
@@ -583,12 +606,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 layer, forward_batch.out_cache_loc, k, k_rope
             )
 
-        if q_rope is not None:
-            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-            q_rope = q_rope.view(
+        # TODO refactor to avoid code duplication
+        # Prepare query tensor inline
+        if merge_query:
+            # For FP16 path, we merge the query and rope parts into a single tensor
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
                 -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
             )
-            q = _concat_mla_absorb_q_general(q, q_rope)
+            q = _concat_mla_absorb_q_general(q_nope, q_rope_reshaped)
+        else:
+            # For FP8 path, we already have the query and rope parts merged because of the quantize_and_rope_for_fp8 function
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
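The two branches above differ only in shape bookkeeping: on the non-fused path the NOPE and rope parts still arrive as separate tensors and must be concatenated per head, while the fused FP8 path already produced one merged tensor and only needs a reshape. A small illustration, using assumed DeepSeek-V2 absorbed-MLA sizes (512 latent dims plus 64 rope dims per head) and treating _concat_mla_absorb_q_general as a per-head concatenation:

```python
import torch

tokens, heads, v_head_dim, head_dim = 4, 16, 512, 576  # example sizes only

# merge_query path: NOPE and rope parts are separate and get concatenated.
q = torch.randn(tokens * heads, v_head_dim)
q_rope = torch.randn(tokens * heads, head_dim - v_head_dim)
q_nope = q.view(-1, heads, v_head_dim)
q_rope_reshaped = q_rope.view(-1, heads, head_dim - v_head_dim)
merged = torch.cat([q_nope, q_rope_reshaped], dim=-1)  # conceptually what _concat_mla_absorb_q_general produces
assert merged.shape == (tokens, heads, head_dim)

# Fused FP8 path: quantize_and_rope_for_fp8 already returned a merged tensor,
# so only a reshape to [tokens, heads, head_dim] remains.
q_fp8 = torch.randn(tokens, heads * head_dim).to(torch.float8_e4m3fn)
assert q_fp8.view(-1, heads, head_dim).shape == (tokens, heads, head_dim)
```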
@@ -1399,7 +1399,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and forward_batch.forward_mode.is_decode_or_idle()
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )
 
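This hunk widens the gate in deepseek_v2.py so that the "skip rope in the model, fuse it in the backend" path now covers MTP target-verify batches as well as decode/idle, whenever trtllm_mla runs with an FP8 KV cache. Below is a simplified sketch of how such a gate is typically consumed; the function and attribute names are illustrative, not the exact deepseek_v2.py control flow.

```python
from typing import Any, Dict, Tuple

import torch


def prepare_rope_args(
    fuse_rope_in_backend: bool,  # result of the predicate shown above
    rotary_emb,                  # rotary module exposing cos_sin_cache / is_neox_style
    positions: torch.Tensor,
    q_pe: torch.Tensor,
    k_pe: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    if fuse_rope_in_backend:
        # Defer rope: the attention backend fuses it with FP8 quantization,
        # so it receives the rotary cache and rope style instead of rotated q/k.
        extra = {
            "cos_sin_cache": rotary_emb.cos_sin_cache,
            "is_neox": rotary_emb.is_neox_style,
        }
        return q_pe, k_pe, extra
    # Default path: apply rope here; the backend gets already-rotated tensors.
    q_pe, k_pe = rotary_emb(positions, q_pe, k_pe)
    return q_pe, k_pe, {}
```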