[Minor] Add some utility functions (#1671)

This commit is contained in:
Lianmin Zheng
2024-10-14 20:08:03 -07:00
committed by GitHub
parent cd0be7489f
commit 4a292f670d
4 changed files with 42 additions and 2 deletions

View File

@@ -392,6 +392,9 @@ class Req:
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
bid = 0
@dataclass
class ScheduleBatch:
"""Store all inforamtion of a batch."""
@@ -828,7 +831,11 @@ class ScheduleBatch:
else:
self.sampling_info.regex_fsms = None
global bid
bid += 1
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
@@ -865,6 +872,8 @@ class ScheduleBatch:
@dataclass
class ModelWorkerBatch:
# The batch id
bid: int
# The forward mode
forward_mode: ForwardMode
# The input ids
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
def copy(self):
return ModelWorkerBatch(
bid=self.bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids.clone(),
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=self.extend_seq_lens,
extend_prefix_lens=self.extend_prefix_lens,
extend_logprob_start_lens=self.extend_logprob_start_lens,
image_inputs=self.image_inputs,
lora_paths=self.lora_paths,
sampling_info=self.sampling_info.copy(),
)