fix FlashMLA cudagraph config (#4691)
Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
@@ -92,7 +92,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
max_seqlen_pad = triton.cdiv(
|
max_seqlen_pad = triton.cdiv(
|
||||||
forward_batch.seq_lens.max().item(), PAGE_SIZE
|
forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
|
||||||
)
|
)
|
||||||
block_kv_indices = torch.full(
|
block_kv_indices = torch.full(
|
||||||
(bs, max_seqlen_pad),
|
(bs, max_seqlen_pad),
|
||||||
@@ -206,8 +206,10 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
|||||||
):
|
):
|
||||||
|
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
|
assert seq_lens_cpu is not None
|
||||||
seq_lens = seq_lens[:bs]
|
seq_lens = seq_lens[:bs]
|
||||||
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
|
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
|
||||||
create_flashmla_kv_indices_triton[(bs,)](
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
|
|||||||
Reference in New Issue
Block a user