[Minor] Add some utility functions (#1671)
This commit is contained in:
@@ -587,6 +587,8 @@ async def benchmark(
|
||||
else:
|
||||
print("Initial test run completed. Starting main benchmark run...")
|
||||
|
||||
time.sleep(1.5)
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
|
||||
@@ -392,6 +392,9 @@ class Req:
|
||||
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
||||
|
||||
|
||||
bid = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScheduleBatch:
|
||||
"""Store all information of a batch."""
|
||||
@@ -828,7 +831,11 @@ class ScheduleBatch:
|
||||
else:
|
||||
self.sampling_info.regex_fsms = None
|
||||
|
||||
global bid
|
||||
bid += 1
|
||||
|
||||
return ModelWorkerBatch(
|
||||
bid=bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids,
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
@@ -865,6 +872,8 @@ class ScheduleBatch:
|
||||
|
||||
@dataclass
|
||||
class ModelWorkerBatch:
|
||||
# The batch id
|
||||
bid: int
|
||||
# The forward mode
|
||||
forward_mode: ForwardMode
|
||||
# The input ids
|
||||
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
|
||||
|
||||
# Sampling info
|
||||
sampling_info: SamplingBatchInfo
|
||||
|
||||
def copy(self):
|
||||
return ModelWorkerBatch(
|
||||
bid=self.bid,
|
||||
forward_mode=self.forward_mode,
|
||||
input_ids=self.input_ids.clone(),
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=self.extend_seq_lens,
|
||||
extend_prefix_lens=self.extend_prefix_lens,
|
||||
extend_logprob_start_lens=self.extend_logprob_start_lens,
|
||||
image_inputs=self.image_inputs,
|
||||
lora_paths=self.lora_paths,
|
||||
sampling_info=self.sampling_info.copy(),
|
||||
)
|
||||
|
||||
@@ -710,7 +710,7 @@ class Scheduler:
|
||||
next_token_ids
|
||||
)
|
||||
|
||||
if logits_output:
|
||||
if batch.return_logprob:
|
||||
# Move logprobs to cpu
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
logits_output.next_token_logprobs = (
|
||||
@@ -786,7 +786,7 @@ class Scheduler:
|
||||
self.num_generated_tokens += len(batch.reqs)
|
||||
|
||||
# Move logprobs to cpu
|
||||
if logits_output.next_token_logprobs is not None:
|
||||
if batch.return_logprob:
|
||||
next_token_logprobs = logits_output.next_token_logprobs[
|
||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||
next_token_ids,
|
||||
|
||||
@@ -202,3 +202,14 @@ class SamplingBatchInfo:
|
||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||
)
|
||||
|
||||
def copy(self):
|
||||
return SamplingBatchInfo(
|
||||
temperatures=self.temperatures,
|
||||
top_ps=self.top_ps,
|
||||
top_ks=self.top_ks,
|
||||
min_ps=self.min_ps,
|
||||
need_min_p_sampling=self.need_min_p_sampling,
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user