[Fix] fix flashinfer usage for window attention (#1107)

This commit is contained in:
Ying Sheng
2024-08-15 00:53:24 -07:00
committed by GitHub
parent e86b1ccbf0
commit 14cb544d56
3 changed files with 12 additions and 18 deletions

View File

@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
# using two wrappers is unnecessary in the current PR, but they are prepared for future PRs
prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
if self.sliding_window_size != -1:
prefill_wrapper_ragged = prefill_wrapper_ragged[0]
if self.sliding_window_size != -1 or self.reuse:
prefill_wrapper_paged = prefill_wrapper_paged[0]
else:
if isinstance(prefill_wrapper_ragged, list):
prefill_wrapper_ragged = prefill_wrapper_ragged[1]
if isinstance(prefill_wrapper_paged, list):
prefill_wrapper_paged = prefill_wrapper_paged[1]