Add output_ids into ScheduleBatch (#1659)
@@ -247,7 +247,7 @@ class Scheduler:
         )
 
     @torch.inference_mode()
-    def event_loop(self):
+    def event_loop_normal(self):
         self.last_batch = None
 
         while True:
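
Note: event_loop is renamed to event_loop_normal here (and at its call site in the last hunk), presumably to make room for alternative loop variants. For orientation, a minimal sketch of what a scheduler loop of this shape typically does; the helper names (recv_requests, process_input_requests, get_next_batch_to_run, run_batch, process_batch_result) are assumptions, not necessarily the real methods at this commit:

    import torch

    class SchedulerSketch:
        @torch.inference_mode()
        def event_loop_normal(self):
            self.last_batch = None
            while True:
                reqs = self.recv_requests()            # hypothetical: pull new requests
                self.process_input_requests(reqs)      # hypothetical: enqueue them
                batch = self.get_next_batch_to_run()   # hypothetical: prefill or decode
                if batch is not None:
                    result = self.run_batch(batch)     # forward pass + sampling
                    self.process_batch_result(batch, result)
                self.last_batch = batch
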
@@ -411,9 +411,10 @@ class Scheduler:
         throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
+        num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
         logger.info(
             f"Decode batch. "
-            f"#running-req: {len(self.running_batch.reqs)}, "
+            f"#running-req: {num_running_reqs}, "
             f"#token: {num_used}, "
             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
             f"gen throughput (token/s): {throughput:.2f}, "
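
The new num_running_reqs binding is a None-guard: between decode rounds self.running_batch can be None, and the old f-string dereferenced it unconditionally. A self-contained illustration of the failure mode and the fix:

    class _Demo:
        running_batch = None  # no decode batch is currently scheduled

    demo = _Demo()

    # Old form: AttributeError when running_batch is None.
    #   len(demo.running_batch.reqs)
    # New form: falls back to 0 safely.
    num_running_reqs = len(demo.running_batch.reqs) if demo.running_batch else 0
    assert num_running_reqs == 0
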
@@ -659,6 +660,7 @@ class Scheduler:
                 )
             else:
                 next_token_ids = torch.full((batch.batch_size(),), 0)
+            batch.output_ids = next_token_ids
             ret = logits_output, next_token_ids
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
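
This hunk is the change the commit title names: the sampled token ids are stored on the ScheduleBatch itself, so later stages can read them from the batch rather than having them threaded through separately. A minimal sketch of the idea; the real ScheduleBatch carries many more fields, and this dataclass is illustrative only:

    from dataclasses import dataclass, field
    from typing import List, Optional

    import torch

    @dataclass
    class ScheduleBatchSketch:
        reqs: List[object] = field(default_factory=list)
        # New in this commit: the token ids sampled for the current step,
        # filled in right after sampling (or with zeros on the dummy path).
        output_ids: Optional[torch.Tensor] = None

        def batch_size(self) -> int:
            return len(self.reqs)

    batch = ScheduleBatchSketch(reqs=[object(), object()])
    next_token_ids = torch.full((batch.batch_size(),), 0)
    batch.output_ids = next_token_ids  # mirrors the added line above
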
@@ -753,7 +755,7 @@ class Scheduler:
                     # Inflight request would get a new req idx
                     self.req_to_token_pool.free(req.req_pool_idx)
 
-        self.handle_finished_requests(batch)
+        self.stream_output(batch)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
@@ -793,7 +795,7 @@ class Scheduler:
             if req.top_logprobs_num > 0:
                 req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        self.handle_finished_requests(batch)
+        self.stream_output(batch)
 
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -872,7 +874,7 @@ class Scheduler:
 
         return num_input_logprobs
 
-    def handle_finished_requests(self, batch: ScheduleBatch):
+    def stream_output(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
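
handle_finished_requests becomes stream_output, matching what the call sites above show: the method runs after every prefill and decode step and emits output for in-flight requests too, not just terminal cleanup. A hedged sketch of the collection pattern visible in the diff; everything beyond the three accumulator names is an assumption:

    def stream_output(self, batch):
        output_rids = []              # request ids, as in the diff
        output_meta_info = []         # per-request metadata dicts
        output_finished_reason = []   # None while a request is still streaming

        for req in batch.reqs:
            output_rids.append(req.rid)                         # assumed field
            output_meta_info.append({"rid": req.rid})           # placeholder meta info
            output_finished_reason.append(req.finished_reason)  # assumed field

        if output_rids:
            # Hypothetical transport call standing in for whatever message
            # the scheduler actually sends to the detokenizer process.
            self.send_to_detokenizer(
                (output_rids, output_meta_info, output_finished_reason)
            )
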
@@ -949,6 +951,9 @@ class Scheduler:
                 }
                 output_meta_info.append(meta_info)
 
+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)
+
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
@@ -976,9 +981,6 @@ class Scheduler:
                 )
             )
 
-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
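
Taken together, the previous two hunks move batch.filter_batch(unfinished_indices) from after the detokenizer send to before it, so the batch tensors are pruned as soon as outputs are collected. A small, self-contained illustration of that style of index-based filtering (the real ScheduleBatch internals are assumed):

    import torch

    # Toy batch state: four requests, two of which just finished.
    reqs = ["r0", "r1", "r2", "r3"]
    finished = [False, True, False, True]
    seq_lens = torch.tensor([5, 9, 7, 3])

    # Indices of requests that survive into the next step.
    unfinished_indices = [i for i, done in enumerate(finished) if not done]

    # filter_batch-style pruning: keep only unfinished rows in every
    # per-request list/tensor the batch carries.
    reqs = [reqs[i] for i in unfinished_indices]
    seq_lens = seq_lens[torch.tensor(unfinished_indices)]

    assert reqs == ["r0", "r2"] and seq_lens.tolist() == [5, 7]
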
@@ -1060,7 +1062,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop()
+        scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)