Maintain seq_lens_sum to make more FlashInfer operations non-blocking (#1741)

This commit is contained in:
Lianmin Zheng
2024-10-21 01:43:16 -07:00
committed by GitHub
parent cf470fea32
commit 09603c6dc9
8 changed files with 98 additions and 43 deletions

View File

@@ -416,7 +416,6 @@ class ScheduleBatch:
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
@@ -424,9 +423,13 @@ class ScheduleBatch:
input_ids: torch.Tensor = None
req_pool_indices: torch.Tensor = None
seq_lens: torch.Tensor = None
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None
output_ids: torch.Tensor = None
# The sum of all sequence lengths
seq_lens_sum: int = None
# For processing logprobs
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -435,7 +438,6 @@ class ScheduleBatch:
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: int = None
running_bs: int = None
decoding_reqs: List[Req] = None
# Stream
@@ -549,10 +551,12 @@ class ScheduleBatch:
self.device, non_blocking=True
)
self.extend_num_tokens = extend_num_tokens
self.out_cache_loc = out_cache_loc
self.seq_lens_sum = sum(seq_lens)
if self.return_logprob:
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
self.extend_lens = [r.extend_input_len for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
@@ -571,12 +575,11 @@ class ScheduleBatch:
input_ids = torch.cat([self.input_ids, running_batch.input_ids])
out_cache_loc = torch.cat([self.out_cache_loc, running_batch.out_cache_loc])
extend_num_tokens = self.extend_num_tokens + running_bs
self.merge_batch(running_batch)
self.input_ids = input_ids
self.out_cache_loc = out_cache_loc
self.extend_num_tokens = extend_num_tokens
self.extend_num_tokens += running_bs
# NOTE: prefix_indices is what has been cached, but we don't cache each decode step
self.prefix_lens.extend(
@@ -775,6 +778,7 @@ class ScheduleBatch:
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
)
self.seq_lens.add_(1)
self.seq_lens_sum += bs
def filter_batch(
self,
@@ -805,6 +809,7 @@ class ScheduleBatch:
self.req_pool_indices = self.req_pool_indices[new_indices]
self.seq_lens = self.seq_lens[new_indices]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
self.output_ids = self.output_ids[new_indices]
self.return_logprob = any(req.return_logprob for req in self.reqs)
if self.return_logprob:
@@ -828,6 +833,7 @@ class ScheduleBatch:
)
self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
self.out_cache_loc = None
self.seq_lens_sum += other.seq_lens_sum
if self.output_ids is not None:
self.output_ids = torch.concat([self.output_ids, other.output_ids])
if self.return_logprob and other.return_logprob:
@@ -873,9 +879,11 @@ class ScheduleBatch:
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
seq_lens_sum=self.seq_lens_sum,
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_num_tokens=self.extend_num_tokens,
extend_seq_lens=extend_seq_lens,
extend_prefix_lens=extend_prefix_lens,
extend_logprob_start_lens=extend_logprob_start_lens,
@@ -917,6 +925,9 @@ class ModelWorkerBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# The sum of all sequence lengths
seq_lens_sum: int
# The memory pool operation records
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
@@ -925,6 +936,7 @@ class ModelWorkerBatch:
top_logprobs_nums: Optional[List[int]]
# For extend
extend_num_tokens: Optional[int]
extend_seq_lens: Optional[List[int]]
extend_prefix_lens: Optional[List[int]]
extend_logprob_start_lens: Optional[List[int]]