Update CI threshold & Improve code style (#2159)
This commit is contained in:
@@ -437,9 +437,12 @@ class ScheduleBatch:
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# For utility
|
||||
# Batch configs
|
||||
model_config: ModelConfig = None
|
||||
forward_mode: ForwardMode = None
|
||||
enable_overlap: bool = False
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo = None
|
||||
next_batch_sampling_info: SamplingBatchInfo = None
|
||||
|
||||
@@ -488,10 +491,11 @@ class ScheduleBatch:
|
||||
def init_new(
|
||||
cls,
|
||||
reqs: List[Req],
|
||||
req_to_token_pool,
|
||||
token_to_kv_pool,
|
||||
tree_cache,
|
||||
model_config,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: ReqToTokenPool,
|
||||
tree_cache: BasePrefixCache,
|
||||
model_config: ModelConfig,
|
||||
enable_overlap: bool,
|
||||
):
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
@@ -499,6 +503,7 @@ class ScheduleBatch:
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
enable_overlap=enable_overlap,
|
||||
return_logprob=any(req.return_logprob for req in reqs),
|
||||
has_stream=any(req.stream for req in reqs),
|
||||
has_grammar=any(req.grammar for req in reqs),
|
||||
@@ -612,7 +617,7 @@ class ScheduleBatch:
|
||||
|
||||
assert len(self.out_cache_loc) == self.extend_num_tokens
|
||||
|
||||
def prepare_for_extend(self, enable_overlap_schedule: bool = False):
|
||||
def prepare_for_extend(self):
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
|
||||
bs = len(self.reqs)
|
||||
@@ -706,7 +711,7 @@ class ScheduleBatch:
|
||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
self,
|
||||
self.model_config.vocab_size,
|
||||
enable_overlap_schedule=enable_overlap_schedule,
|
||||
enable_overlap_schedule=self.enable_overlap,
|
||||
)
|
||||
|
||||
def mix_with_running(self, running_batch: "ScheduleBatch"):
|
||||
@@ -897,7 +902,7 @@ class ScheduleBatch:
|
||||
self.seq_lens_sum = 0
|
||||
self.extend_num_tokens = 0
|
||||
|
||||
def prepare_for_decode(self, enable_overlap: bool = False):
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
|
||||
self.input_ids = self.output_ids
|
||||
@@ -914,7 +919,7 @@ class ScheduleBatch:
|
||||
else:
|
||||
locs = self.seq_lens
|
||||
|
||||
if enable_overlap:
|
||||
if self.enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, locs), self.out_cache_loc
|
||||
|
||||
Reference in New Issue
Block a user