Simplify flashinfer indices update for prefill (#2074)

Co-authored-by: kavioyu <kavioyu@tencent.com>
Co-authored-by: kavioyu <kavioyu@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-11-18 00:02:36 -08:00
parent df7fe4521a
commit 4af3f889fc
8 changed files with 87 additions and 40 deletions

View File

@@ -345,7 +345,7 @@ class LlavaBaseForCausalLM(nn.Module):
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:

View File

@@ -169,7 +169,7 @@ class LlavaVidForCausalLM(nn.Module):
# Fill in the placeholder for the image
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
pt = 0
for i in range(bs):
if not need_vision[i]:

View File

@@ -616,7 +616,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
inputs_embeds = self.model.embed_tokens(input_ids)
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
for i, image in enumerate(forward_batch.image_inputs):
if image is None:
continue