[Minor] Add some utility functions (#1671)
This commit is contained in:
@@ -587,6 +587,8 @@ async def benchmark(
|
|||||||
else:
|
else:
|
||||||
print("Initial test run completed. Starting main benchmark run...")
|
print("Initial test run completed. Starting main benchmark run...")
|
||||||
|
|
||||||
|
time.sleep(1.5)
|
||||||
|
|
||||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
|
|||||||
@@ -392,6 +392,9 @@ class Req:
|
|||||||
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
||||||
|
|
||||||
|
|
||||||
|
bid = 0
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScheduleBatch:
|
class ScheduleBatch:
|
||||||
"""Store all inforamtion of a batch."""
|
"""Store all inforamtion of a batch."""
|
||||||
@@ -828,7 +831,11 @@ class ScheduleBatch:
|
|||||||
else:
|
else:
|
||||||
self.sampling_info.regex_fsms = None
|
self.sampling_info.regex_fsms = None
|
||||||
|
|
||||||
|
global bid
|
||||||
|
bid += 1
|
||||||
|
|
||||||
return ModelWorkerBatch(
|
return ModelWorkerBatch(
|
||||||
|
bid=bid,
|
||||||
forward_mode=self.forward_mode,
|
forward_mode=self.forward_mode,
|
||||||
input_ids=self.input_ids,
|
input_ids=self.input_ids,
|
||||||
req_pool_indices=self.req_pool_indices,
|
req_pool_indices=self.req_pool_indices,
|
||||||
@@ -865,6 +872,8 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ModelWorkerBatch:
|
class ModelWorkerBatch:
|
||||||
|
# The batch id
|
||||||
|
bid: int
|
||||||
# The forward mode
|
# The forward mode
|
||||||
forward_mode: ForwardMode
|
forward_mode: ForwardMode
|
||||||
# The input ids
|
# The input ids
|
||||||
@@ -893,3 +902,21 @@ class ModelWorkerBatch:
|
|||||||
|
|
||||||
# Sampling info
|
# Sampling info
|
||||||
sampling_info: SamplingBatchInfo
|
sampling_info: SamplingBatchInfo
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return ModelWorkerBatch(
|
||||||
|
bid=self.bid,
|
||||||
|
forward_mode=self.forward_mode,
|
||||||
|
input_ids=self.input_ids.clone(),
|
||||||
|
req_pool_indices=self.req_pool_indices,
|
||||||
|
seq_lens=self.seq_lens,
|
||||||
|
out_cache_loc=self.out_cache_loc,
|
||||||
|
return_logprob=self.return_logprob,
|
||||||
|
top_logprobs_nums=self.top_logprobs_nums,
|
||||||
|
extend_seq_lens=self.extend_seq_lens,
|
||||||
|
extend_prefix_lens=self.extend_prefix_lens,
|
||||||
|
extend_logprob_start_lens=self.extend_logprob_start_lens,
|
||||||
|
image_inputs=self.image_inputs,
|
||||||
|
lora_paths=self.lora_paths,
|
||||||
|
sampling_info=self.sampling_info.copy(),
|
||||||
|
)
|
||||||
|
|||||||
@@ -710,7 +710,7 @@ class Scheduler:
|
|||||||
next_token_ids
|
next_token_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
if logits_output:
|
if batch.return_logprob:
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if logits_output.next_token_logprobs is not None:
|
if logits_output.next_token_logprobs is not None:
|
||||||
logits_output.next_token_logprobs = (
|
logits_output.next_token_logprobs = (
|
||||||
@@ -786,7 +786,7 @@ class Scheduler:
|
|||||||
self.num_generated_tokens += len(batch.reqs)
|
self.num_generated_tokens += len(batch.reqs)
|
||||||
|
|
||||||
# Move logprobs to cpu
|
# Move logprobs to cpu
|
||||||
if logits_output.next_token_logprobs is not None:
|
if batch.return_logprob:
|
||||||
next_token_logprobs = logits_output.next_token_logprobs[
|
next_token_logprobs = logits_output.next_token_logprobs[
|
||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
|
|||||||
@@ -202,3 +202,14 @@ class SamplingBatchInfo:
|
|||||||
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
|
||||||
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
self.logit_bias, other.logit_bias, len(self), len(other), self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
return SamplingBatchInfo(
|
||||||
|
temperatures=self.temperatures,
|
||||||
|
top_ps=self.top_ps,
|
||||||
|
top_ks=self.top_ks,
|
||||||
|
min_ps=self.min_ps,
|
||||||
|
need_min_p_sampling=self.need_min_p_sampling,
|
||||||
|
vocab_size=self.vocab_size,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user