[Submodule] Change FlashInfer to import (#156)

This commit is contained in:
Cody Yu
2024-02-06 19:28:29 -08:00
committed by GitHub
parent cb8e1982f8
commit 26c3494152
5 changed files with 17 additions and 24 deletions

View File

@@ -98,12 +98,7 @@ class RadixAttention(nn.Module):
o = input_metadata.prefill_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.qo_indptr,
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
allow_fp16_qk_reduction=True,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)
@@ -114,9 +109,6 @@ class RadixAttention(nn.Module):
o = input_metadata.decode_wrapper.forward(
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
input_metadata.kv_indptr,
input_metadata.kv_indices,
input_metadata.kv_last_page_len,
)
return o.view(-1, self.tp_q_head_num * self.head_dim)

View File

@@ -90,6 +90,11 @@ class InputMetadata:
decode_wrapper = None
def init_flashinfer_args(self, tp_size):
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
self.kv_indptr = torch.zeros(
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
@@ -107,11 +112,7 @@ class InputMetadata:
(self.batch_size,), dtype=torch.int32, device="cuda"
)
from flashinfer.ops import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
)
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda")
if (
self.forward_mode == ForwardMode.PREFILL
or self.forward_mode == ForwardMode.EXTEND
@@ -120,19 +121,21 @@ class InputMetadata:
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
)
self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper()
self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.prefill_wrapper.begin_forward(
self.qo_indptr,
self.batch_size,
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
)
else:
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper()
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
self.decode_wrapper.begin_forward(
self.kv_indptr,
self.kv_indices,
self.kv_last_page_len,
self.batch_size,
self.model_runner.model_config.num_attention_heads // tp_size,
self.model_runner.model_config.num_key_value_heads // tp_size,
self.model_runner.model_config.head_dim,