Fix the race condition in overlap mode (#1712)

This commit is contained in:
Lianmin Zheng
2024-10-19 06:50:56 -07:00
committed by GitHub
parent 3db43d1b08
commit 769bf11c05
6 changed files with 21 additions and 38 deletions

View File

@@ -405,9 +405,9 @@ class ScheduleBatch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: BasePrefixCache
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
tree_cache: BasePrefixCache = None
forward_mode: ForwardMode = None
sampling_info: SamplingBatchInfo = None
@@ -874,12 +874,9 @@ class ScheduleBatch:
def copy(self):
return ScheduleBatch(
reqs=self.reqs,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
tree_cache=self.tree_cache,
forward_mode=self.forward_mode,
output_ids=self.output_ids,
sampling_info=self.sampling_info,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
)
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
forward_mode=self.forward_mode,
input_ids=self.input_ids.clone(),
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
seq_lens=self.seq_lens.clone(),
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,