Simplify flashinfer indices update for prefill (#2074)

Co-authored-by: kavioyu <kavioyu@tencent.com>
Co-authored-by: kavioyu <kavioyu@gmail.com>
This commit is contained in:
Lianmin Zheng
2024-11-18 00:02:36 -08:00
parent df7fe4521a
commit 4af3f889fc
8 changed files with 87 additions and 40 deletions

View File

@@ -109,6 +109,7 @@ class ForwardBatch:
extend_seq_lens: Optional[torch.Tensor] = None
extend_prefix_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
extend_prefix_lens_cpu: Optional[List[int]] = None
extend_seq_lens_cpu: Optional[List[int]] = None
extend_logprob_start_lens_cpu: Optional[List[int]] = None
@@ -250,6 +251,7 @@ class ForwardBatch:
ret.positions, ret.extend_start_loc = compute_position_triton(
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
)
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens