[Fix] fix flashinfer usage for window attention (#1107)
@@ -120,12 +120,9 @@ class RadixAttention(nn.Module):
         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
         prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1:
-            prefill_wrapper_ragged = prefill_wrapper_ragged[0]
+        if self.sliding_window_size != -1 or self.reuse:
             prefill_wrapper_paged = prefill_wrapper_paged[0]
         else:
-            if isinstance(prefill_wrapper_ragged, list):
-                prefill_wrapper_ragged = prefill_wrapper_ragged[1]
             if isinstance(prefill_wrapper_paged, list):
                 prefill_wrapper_paged = prefill_wrapper_paged[1]

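Note: after this hunk the ragged wrapper is a single shared object (see the ModelRunner hunk below), so only the paged wrapper still needs list indexing. A minimal sketch of the resulting selection logic; the helper name `select_paged_wrapper` is illustrative, not part of the patch, and the sketch ignores the patch's `self.reuse` condition to model only the sliding-window choice:

```python
# Illustrative: paged-wrapper selection after this commit.
# prefill_wrapper_paged is either a bare wrapper or a two-element list:
#   [0] = sliding-window attention, [1] = full attention.
def select_paged_wrapper(prefill_wrapper_paged, sliding_window_size: int):
    if sliding_window_size != -1:
        # window-attention layer: always the sliding-window wrapper
        return prefill_wrapper_paged[0]
    if isinstance(prefill_wrapper_paged, list):
        # full-attention layer in a model that also has window layers
        return prefill_wrapper_paged[1]
    # model without window attention: a single bare wrapper
    return prefill_wrapper_paged
```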
@@ -324,9 +324,11 @@ def update_flashinfer_indices(
     else:
         kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         for wrapper_id in range(2):
-            if flashinfer_use_ragged:
+            if flashinfer_use_ragged and wrapper_id == 1:
+                # full attention use ragged+paged
                 paged_kernel_lens = prefix_lens
             else:
+                # window attention use paged only
                 paged_kernel_lens = seq_lens

             if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
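This is the core fix: the ragged prefill path is now tied to the full-attention wrapper (`wrapper_id == 1`) only, so the sliding-window wrapper (`wrapper_id == 0`) always indexes the full sequence through the paged kernel. A runnable sketch of the length selection, assuming `prefix_lens`/`seq_lens` are per-request token counts as in the surrounding code:

```python
import torch

# Illustrative: which KV lengths each wrapper indexes after this commit.
def pick_paged_kernel_lens(flashinfer_use_ragged, wrapper_id, prefix_lens, seq_lens):
    if flashinfer_use_ragged and wrapper_id == 1:
        # full attention: the ragged kernel covers the new tokens,
        # so the paged kernel only needs the cached prefix
        return prefix_lens
    # window attention (wrapper_id == 0), or no ragged kernel:
    # the paged kernel covers the whole sequence
    return seq_lens

prefix_lens = torch.tensor([4, 2], dtype=torch.int32)
seq_lens = torch.tensor([10, 7], dtype=torch.int32)
print(pick_paged_kernel_lens(True, 1, prefix_lens, seq_lens))  # [4, 2]
print(pick_paged_kernel_lens(True, 0, prefix_lens, seq_lens))  # [10, 7]
```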
@@ -374,13 +376,9 @@ def update_flashinfer_indices(
             )
             qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)

-            if flashinfer_use_ragged:
-                model_runner.flashinfer_prefill_wrapper_ragged[
-                    wrapper_id
-                ].end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged[
-                    wrapper_id
-                ].begin_forward(
+            if flashinfer_use_ragged and wrapper_id == 1:
+                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
+                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
                     qo_indptr,
                     qo_indptr,
                     num_qo_heads,
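Passing `qo_indptr` as both the query and KV indptr means the ragged kernel attends only over each request's newly extended tokens; the cached prefix is left to the paged wrapper, which is why the earlier hunk sets `paged_kernel_lens = prefix_lens` on this branch. A small sketch of how those offsets are built, mirroring the `torch.cumsum` line in this hunk (example lengths are made up):

```python
import torch

seq_lens = torch.tensor([10, 7], dtype=torch.int32)
prefix_lens = torch.tensor([4, 2], dtype=torch.int32)

# Offsets of each request's new tokens in the packed ragged layout,
# computed exactly as in the qo_indptr line above.
qo_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
print(qo_indptr)  # request boundaries: [0, 6, 11]
```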
@@ -342,15 +342,14 @@ class ModelRunner:
             dtype=torch.uint8,
             device="cuda",
         )
-        self.flashinfer_prefill_wrapper_ragged = []
+        self.flashinfer_prefill_wrapper_ragged = (
+            BatchPrefillWithRaggedKVCacheWrapper(
+                self.flashinfer_workspace_buffer, "NHD"
+            )
+        )
         self.flashinfer_prefill_wrapper_paged = []
         self.flashinfer_decode_wrapper = []
         for i in range(2):
-            self.flashinfer_prefill_wrapper_ragged.append(
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
             self.flashinfer_prefill_wrapper_paged.append(
                 BatchPrefillWithPagedKVCacheWrapper(
                     self.flashinfer_workspace_buffer, "NHD"
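For reference, the resulting wrapper layout: one shared ragged prefill wrapper plus two paged wrappers per phase, index 0 for sliding-window and index 1 for full attention. A standalone sketch assuming flashinfer's public wrapper classes; the workspace size and the decode-wrapper construction are assumptions, since the hunk truncates before both:

```python
import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,  # assumed; decode fill is outside the hunk
    BatchPrefillWithPagedKVCacheWrapper,
    BatchPrefillWithRaggedKVCacheWrapper,
)

# Workspace size is an assumption; the real value is set outside this hunk.
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")

# One ragged wrapper, shared by both attention types.
prefill_ragged = BatchPrefillWithRaggedKVCacheWrapper(workspace, "NHD")

# Two wrappers each for paged prefill and decode:
# [0] = sliding-window attention, [1] = full attention.
prefill_paged = [
    BatchPrefillWithPagedKVCacheWrapper(workspace, "NHD") for _ in range(2)
]
decode = [BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD") for _ in range(2)]
```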