Upgrade flashinfer to v0.2.0.post2 (#3288)

Co-authored-by: pankajroark <pankajroark@users.noreply.github.com>
This commit is contained in:
Yineng Zhang
2025-02-04 21:41:40 +08:00
committed by GitHub
parent 70817a7eae
commit d39899e85c
8 changed files with 42 additions and 51 deletions

View File

@@ -149,6 +149,7 @@ class FlashInferAttnBackend(AttentionBackend):
BatchPrefillWithPagedKVCacheWrapper(
self.workspace_buffer,
"NHD",
backend="fa2",
)
)
self.prefill_wrappers_verify.append(
@@ -313,7 +314,7 @@ class FlashInferAttnBackend(AttentionBackend):
paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
custom_mask_buf=self.cuda_graph_custom_mask,
qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
)
)
seq_lens_sum = seq_lens.sum().item()
@@ -1155,41 +1156,24 @@ def fast_decode_plan(
self.last_page_len = torch.ones(32768, dtype=torch.int32)
empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
if self.use_tensor_cores:
if not self.is_cuda_graph_enabled:
# when not using cudagraph, we need to create the indptr buffer, otherwise
# the buffer is already created during initialization
self._qo_indptr_buf = torch.arange(
batch_size + 1, dtype=torch.int32, device=indptr.device
)
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._qo_indptr_buf,
indptr,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
empty_q_data,
)
else:
self._wrapper.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
indptr,
self.last_page_len,
batch_size,
num_qo_heads,
num_kv_heads,
head_dim,
page_size,
PosEncodingMode[pos_encoding_mode].value,
logits_soft_cap,
empty_q_data,
empty_kv_cache,
)
stream = torch.cuda.current_stream()
self._cached_module.plan(
self._float_workspace_buffer,
self._int_workspace_buffer,
self._pin_memory_int_workspace_buffer,
indptr.to("cpu"),
batch_size,
num_qo_heads,
num_kv_heads,
page_size,
self.is_cuda_graph_enabled,
window_left,
logits_soft_cap,
head_dim,
empty_q_data,
empty_kv_cache,
stream.cuda_stream,
)
self._pos_encoding_mode = pos_encoding_mode
self._window_left = window_left
self._logits_soft_cap = logits_soft_cap