[Minor] Add some utility functions (#1671)
This commit is contained in:
@@ -392,6 +392,9 @@ class Req:
|
||||
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
||||
|
||||
|
||||
bid = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScheduleBatch:
|
||||
"""Store all inforamtion of a batch."""
|
||||
@@ -828,7 +831,11 @@ class ScheduleBatch:
|
||||
else:
|
||||
self.sampling_info.regex_fsms = None
|
||||
|
||||
global bid
|
||||
bid += 1
|
||||
|
||||
return ModelWorkerBatch(
|
||||
bid=bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
@@ -865,6 +872,8 @@ class ScheduleBatch:
|
||||
|
||||
@dataclass
|
||||
class ModelWorkerBatch:
|
||||
# The batch id
|
||||
bid: int
|
||||
# The forward mode
|
||||
forward_mode: ForwardMode
|
||||
# The input ids
|
||||
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo
|
||||
|
||||
def copy(self):
|
||||
return ModelWorkerBatch(
|
||||
bid=self.bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids.clone(),
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=self.extend_seq_lens,
|
||||
extend_prefix_lens=self.extend_prefix_lens,
|
||||
extend_logprob_start_lens=self.extend_logprob_start_lens,
|
||||
image_inputs=self.image_inputs,
|
||||
lora_paths=self.lora_paths,
|
||||
sampling_info=self.sampling_info.copy(),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user