Fix the race condition in overlap mode (#1712)
This commit is contained in:
@@ -405,9 +405,9 @@ class ScheduleBatch:
|
||||
|
||||
# Request, memory pool, and cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool
|
||||
token_to_kv_pool: BaseTokenToKVPool
|
||||
tree_cache: BasePrefixCache
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
forward_mode: ForwardMode = None
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
@@ -874,12 +874,9 @@ class ScheduleBatch:
|
||||
def copy(self):
|
||||
return ScheduleBatch(
|
||||
reqs=self.reqs,
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool=self.token_to_kv_pool,
|
||||
tree_cache=self.tree_cache,
|
||||
forward_mode=self.forward_mode,
|
||||
output_ids=self.output_ids,
|
||||
sampling_info=self.sampling_info,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
decoding_reqs=self.decoding_reqs,
|
||||
)
|
||||
|
||||
@@ -929,7 +926,7 @@ class ModelWorkerBatch:
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids.clone(),
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
seq_lens=self.seq_lens.clone(),
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
|
||||
Reference in New Issue
Block a user