diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 7a03e162f..ac6b1fb6f 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -260,7 +260,7 @@ def correctness_test( # Decode output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))] - for _ in range(bench_args.output_len[0]): + for _ in range(bench_args.output_len[0] - 1): next_token_ids, _ = decode(next_token_ids, batch, model_runner) for i in range(len(reqs)): output_ids[i].append(next_token_ids[i]) @@ -311,7 +311,7 @@ def latency_test_run_once( # Decode decode_latencies = [] - for i in range(output_len): + for i in range(output_len - 1): torch.cuda.synchronize() tic = time.time() next_token_ids, _ = decode(next_token_ids, batch, model_runner) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2ab041726..c4c91c711 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -429,7 +429,7 @@ class ScheduleBatch: def prepare_for_extend(self, vocab_size: int): self.forward_mode = ForwardMode.EXTEND - bs = self.batch_size() + bs = len(self.reqs) reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] extend_num_tokens = sum(len(ids) for ids in input_ids) @@ -509,7 +509,7 @@ class ScheduleBatch: self.extend_logprob_start_lens_cpu.extend([0] * running_bs) def check_decode_mem(self): - bs = self.batch_size() + bs = len(self.reqs) if self.token_to_kv_pool.available_size() >= bs: return True @@ -680,14 +680,12 @@ class ScheduleBatch: r.output_ids[-1] if r.output_ids else r.origin_input_ids[-1] for r in self.reqs ] - else: - self.sampling_info.penalizer_orchestrator.cumulate_input_tokens(input_ids) self.input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda") self.seq_lens.add_(1) # Alloc mem - bs = self.batch_size() + bs = len(self.reqs) self.out_cache_loc = self.alloc_token_slots(bs) self.req_to_token_pool.req_to_token[ diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index fe9afc9f3..414424e5b 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -215,6 +215,7 @@ class ModelTpServer: self.new_token_ratio_decay = global_config.new_token_ratio_decay self.do_not_get_new_batch = False + @torch.inference_mode() def exposed_step(self, recv_reqs: List): try: # Recv requests @@ -246,7 +247,6 @@ class ModelTpServer: self.out_pyobjs = [] return ret - @torch.inference_mode() def forward_step(self): if self.do_not_get_new_batch and self.current_inflight_req is None: new_batch = None diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 4815fbc56..4e81abec1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -97,14 +97,12 @@ class InputMetadata: self.modalities = [r.modalities for r in reqs] def compute_positions(self, batch: ScheduleBatch): - position_ids_offsets = batch.position_ids_offsets - if self.forward_mode.is_decode(): if True: self.positions = self.seq_lens - 1 else: # Deprecated - self.positions = (self.seq_lens - 1) + position_ids_offsets + self.positions = (self.seq_lens - 1) + batch.position_ids_offsets else: if True: self.positions = torch.tensor( @@ -119,7 +117,7 @@ class InputMetadata: ) else: # Deprecated - position_ids_offsets_cpu = position_ids_offsets.cpu().numpy() + position_ids_offsets_cpu = batch.position_ids_offsets.cpu().numpy() self.positions = torch.tensor( np.concatenate( [ diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 049a43840..5096257be 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -467,7 +467,6 @@ class ModelRunner: logger.info("Capture cuda graph begin. This can take up to several minutes.") self.cuda_graph_runner = CudaGraphRunner(self) - @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): if self.server_args.lora_paths is not None: self.lora_manager.prepare_lora_batch(batch) @@ -481,7 +480,6 @@ class ModelRunner: batch.input_ids, input_metadata.positions, input_metadata ) - @torch.inference_mode() def forward_extend(self, batch: ScheduleBatch): input_metadata = InputMetadata.from_schedule_batch(self, batch) if self.server_args.lora_paths is not None: @@ -500,7 +498,6 @@ class ModelRunner: get_embedding=True, ) - @torch.inference_mode() def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata = InputMetadata.from_schedule_batch(self, batch) return self.model.forward( diff --git a/scripts/playground/reference_hf.py b/scripts/playground/reference_hf.py index 1eb7b0dd2..56c06a174 100644 --- a/scripts/playground/reference_hf.py +++ b/scripts/playground/reference_hf.py @@ -45,7 +45,7 @@ def normal_text(args): "The capital of the United Kindom is", "Today is a sunny day and I like", ] - max_new_tokens = 17 + max_new_tokens = 16 torch.cuda.set_device(0)