Enables TRT-LLM backend to be used for target_verify (#10281)
Co-authored-by: Pranav Marathe <pranavm@ipp1-3309.ipp1a1.colossus.nvidia.com> Co-authored-by: fzyzcjy <ch271828n@outlook.com>
This commit is contained in:
@@ -127,6 +127,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
"disable_chunked_prefix_cache"
|
"disable_chunked_prefix_cache"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
|
||||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||||
"""
|
"""
|
||||||
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
|
||||||
@@ -217,7 +219,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
"""Initialize metadata for CUDA graph capture."""
|
"""Initialize metadata for CUDA graph capture."""
|
||||||
|
|
||||||
# Delegate to parent for non-decode modes.
|
# Delegate to parent for non-decode modes.
|
||||||
if not forward_mode.is_decode_or_idle():
|
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
||||||
return super().init_forward_metadata_capture_cuda_graph(
|
return super().init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -228,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
spec_info,
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if forward_mode.is_target_verify():
|
||||||
|
seq_lens = seq_lens + self.num_draft_tokens
|
||||||
|
|
||||||
# Custom fast-path for decode/idle.
|
# Custom fast-path for decode/idle.
|
||||||
# Capture with full width so future longer sequences are safe during replay
|
# Capture with full width so future longer sequences are safe during replay
|
||||||
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
||||||
@@ -270,7 +275,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
):
|
):
|
||||||
"""Replay CUDA graph with new inputs."""
|
"""Replay CUDA graph with new inputs."""
|
||||||
# Delegate to parent for non-decode modes.
|
# Delegate to parent for non-decode modes.
|
||||||
if not forward_mode.is_decode_or_idle():
|
if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
|
||||||
return super().init_forward_metadata_replay_cuda_graph(
|
return super().init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -282,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
seq_lens_cpu,
|
seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if forward_mode.is_target_verify():
|
||||||
|
seq_lens = seq_lens + self.num_draft_tokens
|
||||||
|
del seq_lens_sum # not handle "num_draft_tokens" but we do not need it
|
||||||
|
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
|
||||||
# Update block indices for new sequences.
|
# Update block indices for new sequences.
|
||||||
@@ -332,7 +341,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
cum_seq_lens_q,
|
cum_seq_lens_q,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
)
|
)
|
||||||
elif forward_batch.forward_mode.is_decode_or_idle():
|
elif (
|
||||||
|
forward_batch.forward_mode.is_decode_or_idle()
|
||||||
|
or forward_batch.forward_mode.is_target_verify()
|
||||||
|
):
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
|
|
||||||
# Get maximum sequence length.
|
# Get maximum sequence length.
|
||||||
@@ -341,13 +353,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
else:
|
else:
|
||||||
max_seq = forward_batch.seq_lens.max().item()
|
max_seq = forward_batch.seq_lens.max().item()
|
||||||
|
|
||||||
|
seq_lens = forward_batch.seq_lens
|
||||||
|
|
||||||
|
if forward_batch.forward_mode.is_target_verify():
|
||||||
|
max_seq = max_seq + self.num_draft_tokens
|
||||||
|
seq_lens = seq_lens + self.num_draft_tokens
|
||||||
|
|
||||||
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
max_seqlen_pad = self._calc_padded_blocks(max_seq)
|
||||||
block_kv_indices = self._create_block_kv_indices(
|
block_kv_indices = self._create_block_kv_indices(
|
||||||
bs,
|
bs,
|
||||||
max_seqlen_pad,
|
max_seqlen_pad,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
seq_lens,
|
||||||
forward_batch.seq_lens.device,
|
seq_lens.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
max_seq_len_val = int(max_seq)
|
max_seq_len_val = int(max_seq)
|
||||||
@@ -553,52 +571,86 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
save_kv_cache: bool = True,
|
save_kv_cache: bool = True,
|
||||||
q_rope: Optional[torch.Tensor] = None,
|
q_rope: Optional[torch.Tensor] = None,
|
||||||
k_rope: Optional[torch.Tensor] = None,
|
k_rope: Optional[torch.Tensor] = None,
|
||||||
):
|
) -> torch.Tensor:
|
||||||
if (
|
if forward_batch.forward_mode.is_draft_extend():
|
||||||
forward_batch.forward_mode.is_target_verify()
|
|
||||||
or forward_batch.forward_mode.is_draft_extend()
|
|
||||||
):
|
|
||||||
return super().forward_extend(
|
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
|
||||||
)
|
|
||||||
# chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
|
|
||||||
if forward_batch.attn_attend_prefix_cache is None:
|
|
||||||
return super().forward_extend(
|
return super().forward_extend(
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
||||||
)
|
)
|
||||||
|
|
||||||
if not forward_batch.attn_attend_prefix_cache:
|
# Save KV cache if requested
|
||||||
|
if save_kv_cache:
|
||||||
|
assert (
|
||||||
|
k is not None and k_rope is not None
|
||||||
|
), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
|
||||||
|
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
|
||||||
|
layer, forward_batch.out_cache_loc, k, k_rope
|
||||||
|
)
|
||||||
|
|
||||||
|
if q_rope is not None:
|
||||||
|
q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
|
||||||
|
q_rope = q_rope.view(
|
||||||
|
-1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
|
||||||
|
)
|
||||||
|
q = torch.cat([q, q_rope], dim=-1)
|
||||||
|
|
||||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
|
if k_rope is not None:
|
||||||
|
k = torch.cat([k, k_rope], dim=-1)
|
||||||
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
|
k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
|
||||||
|
|
||||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
|
||||||
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
|
||||||
|
if forward_batch.forward_mode.is_target_verify():
|
||||||
|
metadata = (
|
||||||
|
getattr(forward_batch, "decode_trtllm_mla_metadata", None)
|
||||||
|
or self.forward_decode_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
|
||||||
|
bs = forward_batch.batch_size
|
||||||
|
q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
|
||||||
|
|
||||||
|
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
|
kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
|
||||||
|
|
||||||
|
q_scale = 1.0
|
||||||
|
k_scale = (
|
||||||
|
layer.k_scale_float
|
||||||
|
if getattr(layer, "k_scale_float", None) is not None
|
||||||
|
else 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
bmm1_scale = q_scale * k_scale * layer.scaling
|
||||||
|
|
||||||
|
seq_lens = (
|
||||||
|
forward_batch.seq_lens.to(torch.int32)
|
||||||
|
+ forward_batch.spec_info.draft_token_num
|
||||||
|
)
|
||||||
|
max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
|
||||||
|
|
||||||
|
# TODO may use `mla_rope_quantize_fp8` fusion
|
||||||
|
q = q.to(self.data_type)
|
||||||
|
assert kv_cache.dtype == self.data_type
|
||||||
|
|
||||||
|
raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
|
||||||
query=q,
|
query=q,
|
||||||
key=k,
|
kv_cache=kv_cache,
|
||||||
value=v,
|
|
||||||
workspace_buffer=self.workspace_buffer,
|
workspace_buffer=self.workspace_buffer,
|
||||||
seq_lens=self.forward_prefill_metadata.seq_lens,
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||||
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
kv_lora_rank=self.kv_lora_rank,
|
||||||
max_kv_len=self.forward_prefill_metadata.max_seq_len,
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||||
bmm1_scale=layer.scaling,
|
block_tables=metadata.block_kv_indices,
|
||||||
bmm2_scale=1.0,
|
seq_lens=seq_lens,
|
||||||
o_sf_scale=1.0,
|
max_seq_len=max_seq_len,
|
||||||
batch_size=forward_batch.batch_size,
|
bmm1_scale=bmm1_scale,
|
||||||
window_left=-1,
|
|
||||||
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
|
||||||
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
|
|
||||||
enable_pdl=False,
|
|
||||||
is_causal=True,
|
|
||||||
return_lse=forward_batch.mha_return_lse,
|
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
if not (
|
# Reshape output directly without slicing
|
||||||
forward_batch.attn_attend_prefix_cache is not None
|
output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
and forward_batch.mha_return_lse
|
return output
|
||||||
):
|
|
||||||
output = super().forward_extend(
|
if forward_batch.attn_attend_prefix_cache:
|
||||||
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# MHA for chunked prefix kv cache when running model with MLA
|
# MHA for chunked prefix kv cache when running model with MLA
|
||||||
assert forward_batch.prefix_chunk_idx is not None
|
assert forward_batch.prefix_chunk_idx is not None
|
||||||
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
assert forward_batch.prefix_chunk_cu_seq_lens is not None
|
||||||
@@ -606,11 +658,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
assert k_rope is None
|
assert k_rope is None
|
||||||
chunk_idx = forward_batch.prefix_chunk_idx
|
chunk_idx = forward_batch.prefix_chunk_idx
|
||||||
|
|
||||||
q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
|
|
||||||
k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
|
|
||||||
v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
|
|
||||||
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
|
||||||
output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||||
query=q,
|
query=q,
|
||||||
key=k,
|
key=k,
|
||||||
value=v,
|
value=v,
|
||||||
@@ -630,7 +679,26 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
return_lse=True,
|
return_lse=True,
|
||||||
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
|
||||||
)
|
)
|
||||||
return output
|
|
||||||
|
return flashinfer.prefill.trtllm_ragged_attention_deepseek(
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v,
|
||||||
|
workspace_buffer=self.workspace_buffer,
|
||||||
|
seq_lens=self.forward_prefill_metadata.seq_lens,
|
||||||
|
max_q_len=self.forward_prefill_metadata.max_seq_len,
|
||||||
|
max_kv_len=self.forward_prefill_metadata.max_seq_len,
|
||||||
|
bmm1_scale=layer.scaling,
|
||||||
|
bmm2_scale=1.0,
|
||||||
|
o_sf_scale=1.0,
|
||||||
|
batch_size=forward_batch.batch_size,
|
||||||
|
window_left=-1,
|
||||||
|
cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
|
||||||
|
cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
|
||||||
|
enable_pdl=False,
|
||||||
|
is_causal=True,
|
||||||
|
return_lse=forward_batch.mha_return_lse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
|
||||||
|
|||||||
Reference in New Issue
Block a user