[Fix] avoid stream sync and torch compile in prefill for fa3 backend (#4932)
This commit is contained in:
@@ -79,7 +79,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
# Precompute maximum sequence length
|
||||
metadata.max_seq_len_k = seqlens_in_batch.max().item()
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
# Precompute page table
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
|
||||
@@ -797,7 +797,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
encoder_lens=None,
|
||||
forward_mode=ForwardMode.DECODE,
|
||||
spec_info=forward_batch.spec_info,
|
||||
seq_lens_cpu=forward_batch.decode_seq_lens_cpu,
|
||||
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||
)
|
||||
|
||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||
|
||||
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
max_seqlen_pad = triton.cdiv(
|
||||
forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
|
||||
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
|
||||
)
|
||||
block_kv_indices = torch.full(
|
||||
(bs, max_seqlen_pad),
|
||||
|
||||
Reference in New Issue
Block a user