Dispatch flashinfer wrappers (#1550)
This commit is contained in:
@@ -53,39 +53,44 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
|
|
||||||
if model_runner.sliding_window_size is None:
|
if model_runner.sliding_window_size is not None:
|
||||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
self.num_wrappers = 2
|
||||||
self.workspace_buffer, "NHD"
|
|
||||||
)
|
|
||||||
self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
|
|
||||||
self.workspace_buffer, "NHD"
|
|
||||||
)
|
|
||||||
self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
|
||||||
self.workspace_buffer,
|
|
||||||
"NHD",
|
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Two wrappers: one for sliding window attention and one for full attention.
|
self.num_wrappers = 1
|
||||||
# Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
|
||||||
self.prefill_wrapper_ragged = None
|
# NOTE: we do not use ragged attention when there are multiple wrappers
|
||||||
self.prefill_wrapper_paged = []
|
self.prefill_wrapper_ragged = (
|
||||||
self.decode_wrapper = []
|
BatchPrefillWithRaggedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
for _ in range(2):
|
if self.num_wrappers == 1
|
||||||
self.prefill_wrapper_paged.append(
|
else None
|
||||||
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
)
|
||||||
)
|
|
||||||
self.decode_wrapper.append(
|
# Two wrappers: one for sliding window attention and one for full attention.
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
# Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
|
||||||
self.workspace_buffer,
|
self.prefill_wrappers_paged = []
|
||||||
"NHD",
|
self.decode_wrappers = []
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
for _ in range(self.num_wrappers):
|
||||||
)
|
self.prefill_wrappers_paged.append(
|
||||||
|
BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
|
||||||
|
)
|
||||||
|
self.decode_wrappers.append(
|
||||||
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
|
self.workspace_buffer,
|
||||||
|
"NHD",
|
||||||
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.forward_metadata = None
|
self.forward_metadata = None
|
||||||
self.cuda_graph_metadata = {}
|
self.cuda_graph_metadata = {}
|
||||||
|
|
||||||
|
def _get_wrapper_idx(self, layer: nn.Module):
|
||||||
|
if self.num_wrappers == 1:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# TODO: make sure the idx is related to sliding window size
|
||||||
|
return layer.sliding_window_size == -1
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
if forward_batch.forward_mode.is_decode():
|
if forward_batch.forward_mode.is_decode():
|
||||||
prefix_lens = None
|
prefix_lens = None
|
||||||
@@ -99,7 +104,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
use_ragged = False
|
use_ragged = False
|
||||||
if (
|
if (
|
||||||
torch.sum(forward_batch.seq_lens).item() >= 4096
|
torch.sum(forward_batch.seq_lens).item() >= 4096
|
||||||
and self.model_runner.sliding_window_size is None
|
and self.num_wrappers == 1
|
||||||
):
|
):
|
||||||
use_ragged = True
|
use_ragged = True
|
||||||
|
|
||||||
@@ -119,7 +124,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
use_ragged,
|
use_ragged,
|
||||||
extend_no_prefix,
|
extend_no_prefix,
|
||||||
total_num_tokens,
|
total_num_tokens,
|
||||||
self.decode_wrapper,
|
self.decode_wrappers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
@@ -135,45 +140,30 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
(max_bs,), dtype=torch.int32, device="cuda"
|
(max_bs,), dtype=torch.int32, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.model_runner.sliding_window_size is not None:
|
# NOTE: the buffers are always in the form of list
|
||||||
self.cuda_graph_kv_indptr = [
|
self.cuda_graph_kv_indptr = [self.cuda_graph_kv_indptr] + [
|
||||||
self.cuda_graph_kv_indptr,
|
self.cuda_graph_kv_indptr.clone() for _ in range(self.num_wrappers - 1)
|
||||||
self.cuda_graph_kv_indptr.clone(),
|
]
|
||||||
]
|
self.cuda_graph_kv_indices = [self.cuda_graph_kv_indices] + [
|
||||||
self.cuda_graph_kv_indices = [
|
self.cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
|
||||||
self.cuda_graph_kv_indices,
|
]
|
||||||
self.cuda_graph_kv_indices.clone(),
|
|
||||||
]
|
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
):
|
):
|
||||||
if self.model_runner.sliding_window_size is None:
|
decode_wrappers = []
|
||||||
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
|
for i in range(self.num_wrappers):
|
||||||
self.workspace_buffer,
|
decode_wrappers.append(
|
||||||
"NHD",
|
BatchDecodeWithPagedKVCacheWrapper(
|
||||||
use_cuda_graph=True,
|
self.workspace_buffer,
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
"NHD",
|
||||||
paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1],
|
use_cuda_graph=True,
|
||||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices,
|
use_tensor_cores=self.decode_use_tensor_cores,
|
||||||
paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
|
paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
|
||||||
)
|
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
||||||
else:
|
paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs],
|
||||||
decode_wrapper = []
|
|
||||||
for i in range(2):
|
|
||||||
decode_wrapper.append(
|
|
||||||
BatchDecodeWithPagedKVCacheWrapper(
|
|
||||||
self.workspace_buffer,
|
|
||||||
"NHD",
|
|
||||||
use_cuda_graph=True,
|
|
||||||
use_tensor_cores=self.decode_use_tensor_cores,
|
|
||||||
paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1],
|
|
||||||
paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
|
|
||||||
paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[
|
|
||||||
:bs
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
update_flashinfer_indices(
|
update_flashinfer_indices(
|
||||||
ForwardMode.DECODE,
|
ForwardMode.DECODE,
|
||||||
@@ -181,12 +171,12 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
None,
|
None,
|
||||||
decode_wrapper,
|
decode_wrappers,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.cuda_graph_metadata[bs] = decode_wrapper
|
self.cuda_graph_metadata[bs] = decode_wrappers
|
||||||
|
|
||||||
self.forward_metadata = (False, False, None, decode_wrapper)
|
self.forward_metadata = (False, False, None, decode_wrappers)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self, bs: int, req_pool_indices, seq_lens
|
self, bs: int, req_pool_indices, seq_lens
|
||||||
@@ -204,17 +194,11 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||||
if not isinstance(self.prefill_wrapper_paged, list):
|
prefill_wrapper_paged = self.prefill_wrappers_paged[
|
||||||
prefill_wrapper_paged = self.prefill_wrapper_paged
|
self._get_wrapper_idx(layer)
|
||||||
else:
|
]
|
||||||
if layer.sliding_window_size != -1:
|
|
||||||
prefill_wrapper_paged = self.prefill_wrapper_paged[0]
|
|
||||||
else:
|
|
||||||
prefill_wrapper_paged = self.prefill_wrapper_paged[1]
|
|
||||||
|
|
||||||
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
|
use_ragged, extend_no_prefix, _, _ = self.forward_metadata
|
||||||
self.forward_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
if not use_ragged:
|
if not use_ragged:
|
||||||
if k is not None:
|
if k is not None:
|
||||||
@@ -260,15 +244,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
|
||||||
|
|
||||||
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
|
||||||
use_ragged, extend_no_prefix, total_num_tokens, decode_wrapper = (
|
decode_wrapper = self.forward_metadata[-1][self._get_wrapper_idx(layer)]
|
||||||
self.forward_metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(decode_wrapper, list):
|
|
||||||
if layer.sliding_window_size != -1:
|
|
||||||
decode_wrapper = decode_wrapper[0]
|
|
||||||
else:
|
|
||||||
decode_wrapper = decode_wrapper[1]
|
|
||||||
|
|
||||||
if k is not None:
|
if k is not None:
|
||||||
assert v is not None
|
assert v is not None
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class FlashinferUpdater:
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
decode_wrapper=None,
|
decode_wrappers=None,
|
||||||
use_ragged=False,
|
use_ragged=False,
|
||||||
):
|
):
|
||||||
self.forward_mode = forward_mode
|
self.forward_mode = forward_mode
|
||||||
@@ -66,14 +66,14 @@ class FlashinferUpdater:
|
|||||||
self.head_dim = model_runner.model_config.head_dim
|
self.head_dim = model_runner.model_config.head_dim
|
||||||
self.batch_size = len(req_pool_indices)
|
self.batch_size = len(req_pool_indices)
|
||||||
|
|
||||||
self.decode_wrapper = (
|
self.decode_wrappers = (
|
||||||
decode_wrapper or self.model_runner.attn_backend.decode_wrapper
|
decode_wrappers or self.model_runner.attn_backend.decode_wrappers
|
||||||
)
|
)
|
||||||
self.prefill_wrapper_ragged = (
|
self.prefill_wrapper_ragged = (
|
||||||
self.model_runner.attn_backend.prefill_wrapper_ragged
|
self.model_runner.attn_backend.prefill_wrapper_ragged
|
||||||
)
|
)
|
||||||
self.prefill_wrapper_paged = (
|
self.prefill_wrappers_paged = (
|
||||||
self.model_runner.attn_backend.prefill_wrapper_paged
|
self.model_runner.attn_backend.prefill_wrappers_paged
|
||||||
)
|
)
|
||||||
|
|
||||||
self.kv_last_page_len = torch.ones(
|
self.kv_last_page_len = torch.ones(
|
||||||
@@ -142,6 +142,7 @@ class FlashinferUpdater:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _update_decode_indices(self, decode_wrapper):
|
def _update_decode_indices(self, decode_wrapper):
|
||||||
|
assert not isinstance(decode_wrapper, list)
|
||||||
decode_wrapper.end_forward()
|
decode_wrapper.end_forward()
|
||||||
decode_wrapper.begin_forward(
|
decode_wrapper.begin_forward(
|
||||||
self.kv_indptr,
|
self.kv_indptr,
|
||||||
@@ -156,6 +157,9 @@ class FlashinferUpdater:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
|
def _update_extend_indices(self, ragged_wrapper, paged_wrapper):
|
||||||
|
assert not isinstance(paged_wrapper, list)
|
||||||
|
assert not isinstance(ragged_wrapper, list)
|
||||||
|
|
||||||
# extend part
|
# extend part
|
||||||
qo_indptr = torch.zeros(
|
qo_indptr = torch.zeros(
|
||||||
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
(self.batch_size + 1,), dtype=torch.int32, device="cuda"
|
||||||
@@ -189,11 +193,11 @@ class FlashinferUpdater:
|
|||||||
self._init_indices_no_sliding_window()
|
self._init_indices_no_sliding_window()
|
||||||
|
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self._update_decode_indices(self.decode_wrapper)
|
self._update_decode_indices(self.decode_wrappers[0])
|
||||||
else:
|
else:
|
||||||
self._update_extend_indices(
|
self._update_extend_indices(
|
||||||
self.prefill_wrapper_ragged,
|
self.prefill_wrapper_ragged,
|
||||||
self.prefill_wrapper_paged,
|
self.prefill_wrappers_paged[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_indices_sliding_window(self):
|
def update_indices_sliding_window(self):
|
||||||
@@ -202,11 +206,11 @@ class FlashinferUpdater:
|
|||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
self._init_indices_sliding_window(wrapper_id)
|
self._init_indices_sliding_window(wrapper_id)
|
||||||
if self.forward_mode.is_decode():
|
if self.forward_mode.is_decode():
|
||||||
self._update_decode_indices(self.decode_wrapper[wrapper_id])
|
self._update_decode_indices(self.decode_wrappers[wrapper_id])
|
||||||
else:
|
else:
|
||||||
self._update_extend_indices(
|
self._update_extend_indices(
|
||||||
None,
|
None,
|
||||||
self.prefill_wrapper_paged[wrapper_id],
|
self.prefill_wrappers_paged[wrapper_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -216,7 +220,7 @@ def update_flashinfer_indices(
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
decode_wrapper=None,
|
decode_wrappers=None,
|
||||||
use_ragged=False,
|
use_ragged=False,
|
||||||
):
|
):
|
||||||
updater = FlashinferUpdater(
|
updater = FlashinferUpdater(
|
||||||
@@ -225,7 +229,7 @@ def update_flashinfer_indices(
|
|||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
decode_wrapper,
|
decode_wrappers,
|
||||||
use_ragged,
|
use_ragged,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user