Simplify flashinfer indices update for prefill (#2074)
Co-authored-by: kavioyu <kavioyu@tencent.com> Co-authored-by: kavioyu <kavioyu@gmail.com>
This commit is contained in:
@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
|
||||
"""
|
||||
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
|
||||
# Some heuristics to check whether to use ragged forward
|
||||
use_ragged = False
|
||||
if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
|
||||
use_ragged = True
|
||||
|
||||
extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
|
||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
|
||||
else:
|
||||
use_ragged = False
|
||||
extend_no_prefix = False
|
||||
|
||||
self.indices_updater_prefill.update(
|
||||
forward_batch.req_pool_indices,
|
||||
forward_batch.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
prefix_lens,
|
||||
use_ragged=use_ragged,
|
||||
encoder_lens=forward_batch.encoder_lens,
|
||||
@@ -334,7 +336,12 @@ class FlashInferIndicesUpdaterDecode:
|
||||
self.update = self.update_single_wrapper
|
||||
|
||||
def update(
|
||||
self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -344,8 +351,8 @@ class FlashInferIndicesUpdaterDecode:
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=None,
|
||||
decode_wrappers: List,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
self.call_begin_forward(
|
||||
@@ -362,8 +369,8 @@ class FlashInferIndicesUpdaterDecode:
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=None,
|
||||
decode_wrappers: List,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
|
||||
@@ -393,11 +400,11 @@ class FlashInferIndicesUpdaterDecode:
|
||||
|
||||
def update_cross_attention(
|
||||
self,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
decode_wrappers=None,
|
||||
encoder_lens=None,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
|
||||
@@ -424,11 +431,11 @@ class FlashInferIndicesUpdaterDecode:
|
||||
def call_begin_forward(
|
||||
self,
|
||||
wrapper,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
paged_kernel_lens_sum,
|
||||
kv_indptr,
|
||||
kv_start_idx,
|
||||
req_pool_indices: torch.Tensor,
|
||||
paged_kernel_lens: torch.Tensor,
|
||||
paged_kernel_lens_sum: int,
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
@@ -494,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
assert self.attn_backend.num_wrappers == 1
|
||||
self.update = self.update_single_wrapper
|
||||
|
||||
def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
|
||||
def update(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_single_wrapper(
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
if use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
||||
else:
|
||||
paged_kernel_lens = seq_lens
|
||||
paged_kernel_lens_sum = seq_lens_sum
|
||||
|
||||
self.call_begin_forward(
|
||||
self.wrapper_ragged,
|
||||
self.wrappers_paged[0],
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
paged_kernel_lens_sum,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
None,
|
||||
@@ -520,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
)
|
||||
|
||||
def update_sliding_window(
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -529,9 +559,12 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
seq_lens,
|
||||
torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
|
||||
)
|
||||
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
||||
else:
|
||||
# full attention
|
||||
paged_kernel_lens = seq_lens
|
||||
paged_kernel_lens_sum = seq_lens_sum
|
||||
|
||||
kv_start_idx = seq_lens - paged_kernel_lens
|
||||
|
||||
self.call_begin_forward(
|
||||
@@ -539,6 +572,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self.wrappers_paged[wrapper_id],
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
paged_kernel_lens_sum,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
kv_start_idx,
|
||||
@@ -548,23 +582,32 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
)
|
||||
|
||||
def update_cross_attention(
|
||||
self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
encoder_lens: torch.Tensor,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
# normal attention
|
||||
paged_kernel_lens = seq_lens
|
||||
kv_start_idx = encoder_lens
|
||||
paged_kernel_lens_sum = seq_lens_sum
|
||||
else:
|
||||
# cross attention
|
||||
paged_kernel_lens = encoder_lens
|
||||
kv_start_idx = torch.zeros_like(encoder_lens)
|
||||
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
|
||||
|
||||
self.call_begin_forward(
|
||||
self.wrapper_ragged,
|
||||
self.wrappers_paged[wrapper_id],
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
paged_kernel_lens_sum,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
kv_start_idx,
|
||||
@@ -577,19 +620,22 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
self,
|
||||
wrapper_ragged,
|
||||
wrapper_paged,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
seq_lens,
|
||||
prefix_lens,
|
||||
kv_start_idx,
|
||||
kv_indptr,
|
||||
qo_indptr,
|
||||
use_ragged,
|
||||
req_pool_indices: torch.Tensor,
|
||||
paged_kernel_lens: torch.Tensor,
|
||||
paged_kernel_lens_sum: int,
|
||||
seq_lens: torch.Tensor,
|
||||
prefix_lens: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
|
||||
@@ -64,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
max_extend_len = None
|
||||
else:
|
||||
start_loc = attn_logits = max_seq_len = None
|
||||
prefix_lens = forward_batch.extend_prefix_lens
|
||||
max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
|
||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||
|
||||
self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@ class ForwardBatch:
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens_cpu: Optional[List[int]] = None
|
||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||
|
||||
@@ -250,6 +251,7 @@ class ForwardBatch:
|
||||
ret.positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
||||
)
|
||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
|
||||
@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
|
||||
@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module):
|
||||
|
||||
# Fill in the placeholder for the image
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
pt = 0
|
||||
for i in range(bs):
|
||||
if not need_vision[i]:
|
||||
|
||||
@@ -616,7 +616,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
||||
|
||||
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
|
||||
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
|
||||
for i, image in enumerate(forward_batch.image_inputs):
|
||||
if image is None:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user