Enables TRT-LLM backend to be used for target_verify (#10281)
Co-authored-by: Pranav Marathe <pranavm@ipp1-3309.ipp1a1.colossus.nvidia.com>
Co-authored-by: fzyzcjy <ch271828n@outlook.com>
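The TRT-LLM MLA attention backend previously delegated all speculative-decoding work to its FlashInfer parent. With this change, TRTLLMMLABackend handles the target_verify forward mode itself: CUDA-graph capture/replay accept verify batches, per-request KV lengths are extended by speculative_num_draft_tokens when building block tables, and forward_extend routes verify batches to flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla. The path is exercised when the server runs with a speculative algorithm enabled and, if I recall the flag correctly, --attention-backend trtllm_mla.

A minimal sketch of the seq_lens adjustment applied throughout the diff (values are illustrative; only num_draft_tokens mirrors server_args.speculative_num_draft_tokens from the change):

    import torch

    # During target_verify the target model scores num_draft_tokens extra
    # positions per request, so every kernel that reads KV lengths must
    # see the extended values.
    num_draft_tokens = 4                    # speculative_num_draft_tokens
    seq_lens = torch.tensor([13, 57, 130])  # committed KV length per request

    verify_seq_lens = seq_lens + num_draft_tokens
    print(verify_seq_lens.tolist())         # [17, 61, 134]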
@@ -127,6 +127,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             "disable_chunked_prefix_cache"
         ]
 
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
         Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
@@ -217,7 +219,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         """Initialize metadata for CUDA graph capture."""
 
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -228,6 +230,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+
         # Custom fast-path for decode/idle.
         # Capture with full width so future longer sequences are safe during replay
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
@@ -270,7 +275,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     ):
         """Replay CUDA graph with new inputs."""
         # Delegate to parent for non-decode modes.
-        if not forward_mode.is_decode_or_idle():
+        if not forward_mode.is_decode_or_idle() and not forward_mode.is_target_verify():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -282,6 +287,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             seq_lens_cpu,
         )
 
+        if forward_mode.is_target_verify():
+            seq_lens = seq_lens + self.num_draft_tokens
+            del seq_lens_sum  # not handle "num_draft_tokens" but we do not need it
+
         metadata = self.decode_cuda_graph_metadata[bs]
 
         # Update block indices for new sequences.
@@ -332,7 +341,10 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 cum_seq_lens_q,
                 seq_lens,
             )
-        elif forward_batch.forward_mode.is_decode_or_idle():
+        elif (
+            forward_batch.forward_mode.is_decode_or_idle()
+            or forward_batch.forward_mode.is_target_verify()
+        ):
             bs = forward_batch.batch_size
 
             # Get maximum sequence length.
@@ -341,13 +353,19 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             else:
                 max_seq = forward_batch.seq_lens.max().item()
 
+            seq_lens = forward_batch.seq_lens
+
+            if forward_batch.forward_mode.is_target_verify():
+                max_seq = max_seq + self.num_draft_tokens
+                seq_lens = seq_lens + self.num_draft_tokens
+
             max_seqlen_pad = self._calc_padded_blocks(max_seq)
             block_kv_indices = self._create_block_kv_indices(
                 bs,
                 max_seqlen_pad,
                 forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
+                seq_lens,
+                seq_lens.device,
             )
 
             max_seq_len_val = int(max_seq)
@@ -553,84 +571,134 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         save_kv_cache: bool = True,
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
-    ):
-        if (
-            forward_batch.forward_mode.is_target_verify()
-            or forward_batch.forward_mode.is_draft_extend()
-        ):
-            return super().forward_extend(
-                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-            )
-        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
-        if forward_batch.attn_attend_prefix_cache is None:
-            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
-            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
-            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+    ) -> torch.Tensor:
+        if forward_batch.forward_mode.is_draft_extend():
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            # Save KV cache if requested
+            if save_kv_cache:
+                assert (
+                    k is not None and k_rope is not None
+                ), "For populating trtllm_mla kv cache, both k_nope and k_rope should be not None."
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
+                )
+
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            q = torch.cat([q, q_rope], dim=-1)
+
+        q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        if k_rope is not None:
+            k = torch.cat([k, k_rope], dim=-1)
+        k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+
+        v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+
+        if forward_batch.forward_mode.is_target_verify():
+            metadata = (
+                getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+                or self.forward_decode_metadata
+            )
+
+            # Ensure query has shape [bs, num_draft_tokens, num_q_heads, head_dim]
+            bs = forward_batch.batch_size
+            q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            kv_cache = k_cache.view(-1, self.page_size, self.kv_cache_dim).unsqueeze(1)
+
+            q_scale = 1.0
+            k_scale = (
+                layer.k_scale_float
+                if getattr(layer, "k_scale_float", None) is not None
+                else 1.0
+            )
+
+            bmm1_scale = q_scale * k_scale * layer.scaling
+
+            seq_lens = (
+                forward_batch.seq_lens.to(torch.int32)
+                + forward_batch.spec_info.draft_token_num
+            )
+            max_seq_len = metadata.max_seq_len + forward_batch.spec_info.draft_token_num
+
+            # TODO may use `mla_rope_quantize_fp8` fusion
+            q = q.to(self.data_type)
+            assert kv_cache.dtype == self.data_type
+
+            raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+                query=q,
+                kv_cache=kv_cache,
+                workspace_buffer=self.workspace_buffer,
+                qk_nope_head_dim=self.qk_nope_head_dim,
+                kv_lora_rank=self.kv_lora_rank,
+                qk_rope_head_dim=self.qk_rope_head_dim,
+                block_tables=metadata.block_kv_indices,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                bmm1_scale=bmm1_scale,
+            )
+
+            # Reshape output directly without slicing
+            output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+            return output
+
+        if forward_batch.attn_attend_prefix_cache:
+            # MHA for chunked prefix kv cache when running model with MLA
+            assert forward_batch.prefix_chunk_idx is not None
+            assert forward_batch.prefix_chunk_cu_seq_lens is not None
+            assert q_rope is None
+            assert k_rope is None
+            chunk_idx = forward_batch.prefix_chunk_idx
+
+            output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
+            return flashinfer.prefill.trtllm_ragged_attention_deepseek(
                 query=q,
                 key=k,
                 value=v,
                 workspace_buffer=self.workspace_buffer,
-                seq_lens=self.forward_prefill_metadata.seq_lens,
+                seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
                 max_q_len=self.forward_prefill_metadata.max_seq_len,
-                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                 bmm1_scale=layer.scaling,
                 bmm2_scale=1.0,
-                o_sf_scale=1.0,
+                o_sf_scale=-1.0,
                 batch_size=forward_batch.batch_size,
                 window_left=-1,
                 cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                 enable_pdl=False,
-                is_causal=True,
-                return_lse=forward_batch.mha_return_lse,
+                is_causal=False,
+                return_lse=True,
+                out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
             )
-        else:
-            if not (
-                forward_batch.attn_attend_prefix_cache is not None
-                and forward_batch.mha_return_lse
-            ):
-                output = super().forward_extend(
-                    q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
-                )
-            else:
-                # MHA for chunked prefix kv cache when running model with MLA
-                assert forward_batch.prefix_chunk_idx is not None
-                assert forward_batch.prefix_chunk_cu_seq_lens is not None
-                assert q_rope is None
-                assert k_rope is None
-                chunk_idx = forward_batch.prefix_chunk_idx
-
-                q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                k = k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype)
-                v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype)
-                output_shape = (q.shape[0], layer.tp_q_head_num, layer.v_head_dim)
-                output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
-                    query=q,
-                    key=k,
-                    value=v,
-                    workspace_buffer=self.workspace_buffer,
-                    seq_lens=forward_batch.prefix_chunk_seq_lens[chunk_idx],
-                    max_q_len=self.forward_prefill_metadata.max_seq_len,
-                    max_kv_len=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
-                    bmm1_scale=layer.scaling,
-                    bmm2_scale=1.0,
-                    o_sf_scale=-1.0,
-                    batch_size=forward_batch.batch_size,
-                    window_left=-1,
-                    cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
-                    cum_seq_lens_kv=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
-                    enable_pdl=False,
-                    is_causal=False,
-                    return_lse=True,
-                    out=torch.zeros(*output_shape, dtype=q.dtype, device=q.device),
-                )
-        return output
+
+        return flashinfer.prefill.trtllm_ragged_attention_deepseek(
+            query=q,
+            key=k,
+            value=v,
+            workspace_buffer=self.workspace_buffer,
+            seq_lens=self.forward_prefill_metadata.seq_lens,
+            max_q_len=self.forward_prefill_metadata.max_seq_len,
+            max_kv_len=self.forward_prefill_metadata.max_seq_len,
+            bmm1_scale=layer.scaling,
+            bmm2_scale=1.0,
+            o_sf_scale=1.0,
+            batch_size=forward_batch.batch_size,
+            window_left=-1,
+            cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+            cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+            enable_pdl=False,
+            is_causal=True,
+            return_lse=forward_batch.mha_return_lse,
+        )
 
 
 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
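For reference, a self-contained sketch of the page-padding rule that _calc_padded_blocks applies to the extended lengths above. The constants and helper below are assumptions for illustration, not the backend's actual values:

    import math

    PAGE_SIZE = 64          # assumed TRT-LLM KV page size
    TRITON_ALIGNMENT = 4    # assumed alignment for the Triton index kernel

    def calc_padded_blocks(max_seq_len: int, num_draft_tokens: int = 0) -> int:
        # Count the draft tokens scored during target_verify before rounding
        # up to whole KV pages, then round the block count itself up to the
        # assumed Triton-friendly multiple.
        blocks = math.ceil((max_seq_len + num_draft_tokens) / PAGE_SIZE)
        return math.ceil(blocks / TRITON_ALIGNMENT) * TRITON_ALIGNMENT

    print(calc_padded_blocks(1000, num_draft_tokens=4))  # 16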