diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py
index 6194ca1d1..a05398812 100644
--- a/python/sglang/bench_latency.py
+++ b/python/sglang/bench_latency.py
@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch
 
 
 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.prepare_for_decode(input_token_ids)
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits
 
 
@@ -252,6 +253,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 
     # Load the model
@@ -279,8 +281,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(next_token_ids[i])
+            output_ids[i].append(next_token_ids_list[i])
 
     # Print
     for i in range(len(reqs)):
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 869c529e3..b4248d5ec 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -410,6 +410,8 @@ class ScheduleBatch:
     seq_lens: torch.Tensor = None
     out_cache_loc: torch.Tensor = None
 
+    output_ids: torch.Tensor = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -720,19 +722,12 @@ class ScheduleBatch:
 
         return jump_forward_reqs
 
-    def prepare_for_decode(self, input_ids=None):
+    def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
 
-        if input_ids is None:
-            input_ids = [
-                r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1]
-                for r in self.reqs
-            ]
-
-        self.input_ids = torch.tensor(
-            input_ids, dtype=torch.int32, device=self.seq_lens.device
-        )
+        self.input_ids = self.output_ids
         self.seq_lens.add_(1)
+        self.output_ids = None
 
         # Alloc mem
         bs = len(self.reqs)
@@ -759,6 +754,7 @@
         self.req_pool_indices = self.req_pool_indices[new_indices]
         self.seq_lens = self.seq_lens[new_indices]
         self.out_cache_loc = None
+        self.output_ids = self.output_ids[new_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)
         if self.return_logprob:
             self.top_logprobs_nums = [
@@ -783,6 +779,8 @@
         )
         self.seq_lens = torch.concat([self.seq_lens, other.seq_lens])
         self.out_cache_loc = None
+        if self.output_ids is not None:
+            self.output_ids = torch.concat([self.output_ids, other.output_ids])
         if self.return_logprob and other.return_logprob:
             self.top_logprobs_nums.extend(other.top_logprobs_nums)
         elif self.return_logprob:
@@ -838,7 +836,9 @@ class ScheduleBatch:
             token_to_kv_pool=self.token_to_kv_pool,
             tree_cache=self.tree_cache,
             forward_mode=self.forward_mode,
-            output_token_ids=self.output_token_ids,
+            output_ids=self.output_ids,
+            sampling_info=self.sampling_info,
+            decoding_reqs=self.decoding_reqs,
         )
 
     def __str__(self):
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 03ae37d66..e59679ffa 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -247,7 +247,7 @@ class Scheduler:
         )
 
     @torch.inference_mode()
-    def event_loop(self):
+    def event_loop_normal(self):
         self.last_batch = None
 
         while True:
@@ -411,9 +411,10 @@
        throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
+       num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
        logger.info(
            f"Decode batch. "
-           f"#running-req: {len(self.running_batch.reqs)}, "
+           f"#running-req: {num_running_reqs}, "
            f"#token: {num_used}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"gen throughput (token/s): {throughput:.2f}, "
@@ -659,6 +660,7 @@ class Scheduler:
                )
            else:
                next_token_ids = torch.full((batch.batch_size(),), 0)
+           batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids
        else:  # embedding or reward model
            assert batch.extend_num_tokens != 0
@@ -753,7 +755,7 @@
                # Inflight request would get a new req idx
                self.req_to_token_pool.free(req.req_pool_idx)
 
-        self.handle_finished_requests(batch)
+        self.stream_output(batch)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
         logits_output, next_token_ids = result
@@ -793,7 +795,7 @@
            if req.top_logprobs_num > 0:
                req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
-        self.handle_finished_requests(batch)
+        self.stream_output(batch)
 
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
@@ -872,7 +874,7 @@
 
         return num_input_logprobs
 
-    def handle_finished_requests(self, batch: ScheduleBatch):
+    def stream_output(self, batch: ScheduleBatch):
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -949,6 +951,9 @@
                }
                output_meta_info.append(meta_info)
 
+        # Remove finished reqs: update batch tensors
+        batch.filter_batch(unfinished_indices)
+
         # Send to detokenizer
         if output_rids:
             if self.is_generation:
@@ -976,9 +981,6 @@
                    )
                )
 
-        # Remove finished reqs: update batch tensors
-        batch.filter_batch(unfinished_indices)
-
     def flush_cache(self):
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
@@ -1060,7 +1062,7 @@ def run_scheduler_process(
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
         pipe_writer.send("ready")
-        scheduler.event_loop()
+        scheduler.event_loop_normal()
     except Exception:
         msg = get_exception_traceback()
         logger.error(msg)
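
For reviewers, a minimal sketch (not part of the patch) of the decode-loop contract this diff moves to: the sampler's output stays a tensor, is stashed on the batch as `output_ids`, `prepare_for_decode()` consumes it as the next step's `input_ids`, and the `.tolist()` device-to-host copy is deferred until the tokens are actually needed on the CPU. The `Batch` class and `sample` callable below are hypothetical stand-ins, not the real `ScheduleBatch` or `ModelRunner` API.

```python
import torch


class Batch:
    """Stand-in for ScheduleBatch; only the fields this sketch needs."""

    def __init__(self, input_ids: torch.Tensor, seq_lens: torch.Tensor):
        self.input_ids = input_ids    # tokens fed to the next forward pass
        self.seq_lens = seq_lens      # per-request sequence lengths
        self.output_ids = None        # last sampled tokens, kept as a tensor

    def prepare_for_decode(self):
        # New contract: consume the previously sampled tokens directly instead
        # of rebuilding an int32 tensor from a Python list every decode step.
        self.input_ids = self.output_ids
        self.seq_lens.add_(1)
        self.output_ids = None


def decode_step(batch: Batch, sample) -> torch.Tensor:
    batch.prepare_for_decode()
    # `sample` stands in for model forward + sampling; it returns a tensor.
    next_token_ids = sample(batch.input_ids)
    batch.output_ids = next_token_ids  # stays on device; no .tolist() here
    return next_token_ids


# The .tolist() copy happens only where the ids are actually consumed on CPU.
batch = Batch(torch.tensor([3, 7], dtype=torch.int32), torch.tensor([5, 9]))
batch.output_ids = torch.tensor([11, 12], dtype=torch.int32)  # e.g. from prefill
ids = decode_step(batch, lambda x: x + 1)
print(ids.tolist())
```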