[Fix] fix flashinfer usage for window attention (#1107)
@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_ragged = prefill_wrapper_ragged[0]
+        if self.sliding_window_size != -1 or self.reuse:
             prefill_wrapper_paged = prefill_wrapper_paged[0]
         else:
-            if isinstance(prefill_wrapper_ragged, list):
-                prefill_wrapper_ragged = prefill_wrapper_ragged[1]
             if isinstance(prefill_wrapper_paged, list):
                 prefill_wrapper_paged = prefill_wrapper_paged[1]

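Note: after this change, `flashinfer_prefill_wrapper_ragged` is a single wrapper object, while the paged wrappers remain a two-element list (index 0 for sliding-window layers, index 1 for full attention). A minimal sketch, not from the commit, of how the selection behaves; plain strings stand in for the flashinfer wrapper objects, and the `reuse` flag is taken from the diff with its wider semantics not shown here:

# Minimal sketch of the wrapper selection above (placeholders, not real wrappers).
def pick_prefill_wrappers(ragged, paged, sliding_window_size, reuse=False):
    # The ragged wrapper is now a single object and needs no indexing.
    if sliding_window_size != -1 or reuse:
        paged = paged[0]  # sliding-window layers use paged wrapper 0
    elif isinstance(paged, list):
        paged = paged[1]  # full-attention layers use paged wrapper 1
    return ragged, paged

ragged = "ragged_wrapper"
paged = ["paged_window", "paged_full"]
print(pick_prefill_wrappers(ragged, paged, sliding_window_size=4096))
# -> ('ragged_wrapper', 'paged_window')
print(pick_prefill_wrappers(ragged, paged, sliding_window_size=-1))
# -> ('ragged_wrapper', 'paged_full')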
@@ -324,9 +324,11 @@ def update_flashinfer_indices(
     else:
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         for wrapper_id in range(2):
-            if flashinfer_use_ragged:
+            if flashinfer_use_ragged and wrapper_id == 1:
+                # full attention use ragged+paged
                 paged_kernel_lens = prefix_lens
             else:
+                # window attention use paged only
                 paged_kernel_lens = seq_lens

             if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
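Note: for the full-attention wrapper (wrapper_id == 1) the ragged kernel already covers the newly extended tokens, so the paged kernel only needs the cached prefix; the sliding-window wrapper (wrapper_id == 0) runs paged-only over the whole sequence. A runnable toy example of the length selection, with made-up values:

import torch

# Two requests: 5 and 3 cached prefix tokens, 7 and 4 tokens in total.
prefix_lens = torch.tensor([5, 3], dtype=torch.int32)
seq_lens = torch.tensor([7, 4], dtype=torch.int32)

flashinfer_use_ragged = True
for wrapper_id in range(2):
    if flashinfer_use_ragged and wrapper_id == 1:
        # full attention: ragged kernel handles new tokens, paged covers the prefix
        paged_kernel_lens = prefix_lens
    else:
        # window attention: paged kernel covers the whole sequence
        paged_kernel_lens = seq_lens
    print(wrapper_id, paged_kernel_lens.tolist())
# -> 0 [7, 4]   (window wrapper)
# -> 1 [5, 3]   (full-attention wrapper)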
@@ -374,13 +376,9 @@ def update_flashinfer_indices(
             )
             qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-            if flashinfer_use_ragged:
-                model_runner.flashinfer_prefill_wrapper_ragged[
-                    wrapper_id
-                ].end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged[
-                    wrapper_id
-                ].begin_forward(
+            if flashinfer_use_ragged and wrapper_id == 1:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                     qo_indptr,
                     qo_indptr,
                     num_qo_heads,
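Note: `qo_indptr` is the cumulative count of newly extended tokens per request, and the diff passes it twice to `begin_forward` because the ragged kernel's queries and keys/values both cover exactly those new tokens. A small worked example of the indptr construction, with made-up values and device handling omitted:

import torch

# Two requests with total lengths 7 and 4, of which 5 and 3 are already cached.
batch_size = 2
seq_lens = torch.tensor([7, 4], dtype=torch.int32)
prefix_lens = torch.tensor([5, 3], dtype=torch.int32)

qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
print(qo_indptr.tolist())  # [0, 2, 3]: request 0 extends 2 tokens, request 1 extends 1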
@@ -342,15 +342,14 @@ class ModelRunner:
                 dtype=torch.uint8,
                 device="cuda",
             )
-            self.flashinfer_prefill_wrapper_ragged = []
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(
+                    self.flashinfer_workspace_buffer, "NHD"
+                )
+            )
             self.flashinfer_prefill_wrapper_paged = []
             self.flashinfer_decode_wrapper = []
             for i in range(2):
-                self.flashinfer_prefill_wrapper_ragged.append(
-                    BatchPrefillWithRaggedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
                 self.flashinfer_prefill_wrapper_paged.append(
                     BatchPrefillWithPagedKVCacheWrapper(
                         self.flashinfer_workspace_buffer, "NHD"
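Note: the net allocation scheme is one shared workspace buffer, a single ragged prefill wrapper for all layers, and two paged prefill (and decode) wrappers, index 0 for sliding-window layers and index 1 for full attention. A sketch of that scheme; the constructors match those in the diff, while the buffer size and the decode-wrapper construction are assumptions:

import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)

# Shared workspace buffer (size here is an arbitrary placeholder).
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# One ragged prefill wrapper, shared by all layers.
prefill_ragged = BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

# Two paged prefill and two decode wrappers:
# index 0 -> sliding-window layers, index 1 -> full-attention layers.
prefill_paged = [
    BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD") for _ in range(2)
]
decode = [BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD") for _ in range(2)]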