feat: update model_specific_adjustment (#5344)

Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
This commit is contained in:
Yineng Zhang
2025-04-15 14:45:15 -07:00
committed by GitHub
parent e8f62b20ca
commit fa909dc3c4
4 changed files with 51 additions and 17 deletions

View File

@@ -383,7 +383,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
elif forward_batch.forward_mode.is_extend_or_draft_extend():
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
metadata.cu_seqlens_k = torch.nn.functional.pad(