Simplify flashinfer indices update for prefill (#2074)
Co-authored-by: kavioyu <kavioyu@tencent.com> Co-authored-by: kavioyu <kavioyu@gmail.com>
This commit is contained in:
@@ -109,6 +109,7 @@ class ForwardBatch:
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens_cpu: Optional[List[int]] = None
|
||||
extend_seq_lens_cpu: Optional[List[int]] = None
|
||||
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
||||
|
||||
@@ -250,6 +251,7 @@ class ForwardBatch:
|
||||
ret.positions, ret.extend_start_loc = compute_position_triton(
|
||||
ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
|
||||
)
|
||||
ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
|
||||
Reference in New Issue
Block a user