diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 4afccf73a..2f07973b8 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -587,6 +587,8 @@ async def benchmark( else: print("Initial test run completed. Starting main benchmark run...") + await asyncio.sleep(1.5) + pbar = None if disable_tqdm else tqdm(total=len(input_requests)) benchmark_start_time = time.perf_counter() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7f0c243b7..b84fc5562 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -392,6 +392,9 @@ class Req: return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, " +bid = 0 + + @dataclass class ScheduleBatch: """Store all inforamtion of a batch.""" @@ -828,7 +831,11 @@ class ScheduleBatch: else: self.sampling_info.regex_fsms = None + global bid + bid += 1 + return ModelWorkerBatch( + bid=bid, forward_mode=self.forward_mode, input_ids=self.input_ids, req_pool_indices=self.req_pool_indices, @@ -865,6 +872,8 @@ class ScheduleBatch: @dataclass class ModelWorkerBatch: + # The batch id + bid: int # The forward mode forward_mode: ForwardMode # The input ids @@ -893,3 +902,21 @@ class ModelWorkerBatch: # Sampling info sampling_info: SamplingBatchInfo + + def copy(self): + return ModelWorkerBatch( + bid=self.bid, + forward_mode=self.forward_mode, + input_ids=self.input_ids.clone(), + req_pool_indices=self.req_pool_indices, + seq_lens=self.seq_lens, + out_cache_loc=self.out_cache_loc, + return_logprob=self.return_logprob, + top_logprobs_nums=self.top_logprobs_nums, + extend_seq_lens=self.extend_seq_lens, + extend_prefix_lens=self.extend_prefix_lens, + extend_logprob_start_lens=self.extend_logprob_start_lens, + image_inputs=self.image_inputs, + lora_paths=self.lora_paths, + sampling_info=self.sampling_info.copy(), + ) diff --git a/python/sglang/srt/managers/scheduler.py 
b/python/sglang/srt/managers/scheduler.py index c124a0d5d..40463d016 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -710,7 +710,7 @@ class Scheduler: next_token_ids ) - if logits_output: + if batch.return_logprob: # Move logprobs to cpu if logits_output.next_token_logprobs is not None: logits_output.next_token_logprobs = ( @@ -786,7 +786,7 @@ class Scheduler: self.num_generated_tokens += len(batch.reqs) # Move logprobs to cpu - if logits_output.next_token_logprobs is not None: + if batch.return_logprob: next_token_logprobs = logits_output.next_token_logprobs[ torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 4a38fc087..779af5101 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -202,3 +202,14 @@ class SamplingBatchInfo: self.logit_bias = SamplingBatchInfo.merge_bias_tensor( self.logit_bias, other.logit_bias, len(self), len(other), self.device ) + + def copy(self): + return SamplingBatchInfo( + temperatures=self.temperatures, + top_ps=self.top_ps, + top_ks=self.top_ks, + min_ps=self.min_ps, + need_min_p_sampling=self.need_min_p_sampling, + vocab_size=self.vocab_size, + device=self.device, + )