Fix sliding window attention and gemma-2 unit tests in CI (#1746)

Lianmin Zheng
2024-10-21 13:47:12 -07:00
committed by GitHub
parent e68b9e7667
commit 00611286a1
4 changed files with 35 additions and 14 deletions
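
What the diff below fixes, in short: the decode path builds two FlashInfer wrappers, one for sliding window attention and one for full attention. Both used to receive seq_lens_sum when sizing their KV index buffer, but the sliding window wrapper only covers the clamped lengths, so its buffer size must come from the clamped sum. A minimal sketch of the size mismatch, using hypothetical lengths (seq_lens and sliding_window_size stand in for the real tensors in the backend):

    import torch

    seq_lens = torch.tensor([5000, 30, 7000])  # hypothetical per-request KV lengths
    sliding_window_size = 4096

    # Each request contributes at most sliding_window_size + 1 entries to the
    # sliding-window wrapper:
    paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size + 1))

    print(seq_lens.sum().item())           # 12030 -- size used before the fix
    print(paged_kernel_lens.sum().item())  # 8224  -- size the wrapper actually needs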


@@ -342,23 +342,25 @@ class FlashInferIndicesUpdaterDecode:
         for wrapper_id in range(2):
             if wrapper_id == 0:
                 # Sliding window attention
-                paged_kernel_lens = torch.minimum(  # TODO: replace this with clamp
+                paged_kernel_lens_tmp = torch.minimum(  # TODO: replace this with clamp
                     seq_lens,
                     torch.tensor(self.sliding_window_size + 1),
                 )
+                paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
+                kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
             else:
                 # Full attention
-                paged_kernel_lens = seq_lens
-
-            kv_start_idx = seq_lens - paged_kernel_lens
+                paged_kernel_lens_tmp = seq_lens
+                paged_kernel_lens_sum_tmp = seq_lens_sum
+                kv_start_idx_tmp = None
 
             self.call_begin_forward(
                 decode_wrappers[wrapper_id],
                 req_pool_indices,
-                paged_kernel_lens,
-                seq_lens_sum,
+                paged_kernel_lens_tmp,
+                paged_kernel_lens_sum_tmp,
                 self.kv_indptr[wrapper_id],
-                kv_start_idx,
+                kv_start_idx_tmp,
             )
 
     def update_cross_attention(self):
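
The TODO in the hunk above asks for clamp. For reference, a torch.clamp form equivalent to the torch.minimum call (a sketch, not part of this commit) would be:

    # Clamp each sequence length to the window bound without materializing a
    # scalar tensor for it:
    paged_kernel_lens_tmp = torch.clamp(seq_lens, max=self.sliding_window_size + 1)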
@@ -369,14 +371,16 @@ class FlashInferIndicesUpdaterDecode:
         wrapper,
         req_pool_indices,
         paged_kernel_lens,
-        seq_lens_sum,
+        paged_kernel_lens_sum,
         kv_indptr,
         kv_start_idx,
     ):
         bs = len(req_pool_indices)
         kv_indptr = kv_indptr[: bs + 1]
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
-        kv_indices = torch.empty(seq_lens_sum, dtype=torch.int32, device="cuda")
+        kv_indices = torch.empty(
+            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+        )
 
         create_flashinfer_kv_indices_triton[(bs,)](
             self.req_to_token,
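
For context, call_begin_forward lays the indices out CSR-style: kv_indptr holds per-request offsets into the flat kv_indices buffer, which is why that buffer must hold paged_kernel_lens_sum entries rather than seq_lens_sum. A small sketch with hypothetical lengths:

    import torch

    paged_kernel_lens = torch.tensor([3, 1, 2])
    paged_kernel_lens_sum = int(paged_kernel_lens.sum())  # 6

    kv_indptr = torch.zeros(len(paged_kernel_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)  # [0, 3, 4, 6]

    # Request i's KV page indices occupy kv_indices[kv_indptr[i]:kv_indptr[i+1]].
    kv_indices = torch.empty(paged_kernel_lens_sum, dtype=torch.int32)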