[Minor] Add some utility functions (#1671)

This commit is contained in:
Lianmin Zheng
2024-10-14 20:08:03 -07:00
committed by GitHub
parent cd0be7489f
commit 4a292f670d
4 changed files with 42 additions and 2 deletions

View File

@@ -587,6 +587,8 @@ async def benchmark(
else:
print("Initial test run completed. Starting main benchmark run...")
time.sleep(1.5)
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter()

View File

@@ -392,6 +392,9 @@ class Req:
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
bid = 0
@dataclass
class ScheduleBatch:
"""Store all information of a batch."""
@@ -828,7 +831,11 @@ class ScheduleBatch:
else:
self.sampling_info.regex_fsms = None
global bid
bid += 1
return ModelWorkerBatch(
bid=bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids,
req_pool_indices=self.req_pool_indices,
@@ -865,6 +872,8 @@ class ScheduleBatch:
@dataclass
class ModelWorkerBatch:
# The batch id
bid: int
# The forward mode
forward_mode: ForwardMode
# The input ids
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
# Sampling info
sampling_info: SamplingBatchInfo
def copy(self):
return ModelWorkerBatch(
bid=self.bid,
forward_mode=self.forward_mode,
input_ids=self.input_ids.clone(),
req_pool_indices=self.req_pool_indices,
seq_lens=self.seq_lens,
out_cache_loc=self.out_cache_loc,
return_logprob=self.return_logprob,
top_logprobs_nums=self.top_logprobs_nums,
extend_seq_lens=self.extend_seq_lens,
extend_prefix_lens=self.extend_prefix_lens,
extend_logprob_start_lens=self.extend_logprob_start_lens,
image_inputs=self.image_inputs,
lora_paths=self.lora_paths,
sampling_info=self.sampling_info.copy(),
)

View File

@@ -710,7 +710,7 @@ class Scheduler:
next_token_ids
)
if logits_output:
if batch.return_logprob:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
@@ -786,7 +786,7 @@ class Scheduler:
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs[
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,

View File

@@ -202,3 +202,14 @@ class SamplingBatchInfo:
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
self.logit_bias, other.logit_bias, len(self), len(other), self.device
)
def copy(self):
return SamplingBatchInfo(
temperatures=self.temperatures,
top_ps=self.top_ps,
top_ks=self.top_ks,
min_ps=self.min_ps,
need_min_p_sampling=self.need_min_p_sampling,
vocab_size=self.vocab_size,
device=self.device,
)