diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 04fd5ffaf..5c9be2095 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -407,9 +407,7 @@ class ModelRpcServer: prefill_logprobs, normalized_logprobs, last_logprobs, - ) = self.model_runner.forward( - batch, ForwardMode.EXTEND, batch.return_logprob - ) + ) = self.model_runner.forward(batch, ForwardMode.EXTEND) if prefill_logprobs is not None: logprobs = prefill_logprobs.cpu().tolist() normalized_logprobs = normalized_logprobs.cpu().tolist() @@ -496,9 +494,7 @@ class ModelRpcServer: # Forward logits, (_, _, last_logprobs) = self.model_runner.forward( - batch, - ForwardMode.DECODE, - batch.return_logprob, + batch, ForwardMode.DECODE ) next_token_ids, _ = batch.sample(logits) next_token_ids = next_token_ids.cpu().tolist() diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py index f6d9adc3f..f349819f3 100644 --- a/python/sglang/srt/managers/router/model_runner.py +++ b/python/sglang/srt/managers/router/model_runner.py @@ -367,148 +367,88 @@ class ModelRunner: ) @torch.inference_mode() - def forward_prefill( - self, - input_ids, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - return_logprob, - ): + def forward_prefill(self, batch: Batch): input_metadata = InputMetadata.create( self, forward_mode=ForwardMode.PREFILL, tp_size=self.tp_size, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - prefix_lens=prefix_lens, - position_ids_offsets=position_ids_offsets, - out_cache_loc=out_cache_loc, - return_logprob=return_logprob, + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + prefix_lens=batch.prefix_lens, + position_ids_offsets=batch.position_ids_offsets, + out_cache_loc=batch.out_cache_loc, + return_logprob=batch.return_logprob, + ) + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata ) - return self.model.forward(input_ids, input_metadata.positions, input_metadata) @torch.inference_mode() - def forward_extend( - self, - input_ids, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - return_logprob, - ): + def forward_extend(self, batch: Batch): input_metadata = InputMetadata.create( self, forward_mode=ForwardMode.EXTEND, tp_size=self.tp_size, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - prefix_lens=prefix_lens, - position_ids_offsets=position_ids_offsets, - out_cache_loc=out_cache_loc, - return_logprob=return_logprob, + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + prefix_lens=batch.prefix_lens, + position_ids_offsets=batch.position_ids_offsets, + out_cache_loc=batch.out_cache_loc, + return_logprob=batch.return_logprob, + ) + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata ) - return self.model.forward(input_ids, input_metadata.positions, input_metadata) @torch.inference_mode() - def forward_decode( - self, - input_ids, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - out_cache_cont_start, - out_cache_cont_end, - return_logprob, - ): + def forward_decode(self, batch: Batch): input_metadata = InputMetadata.create( self, forward_mode=ForwardMode.DECODE, tp_size=self.tp_size, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - prefix_lens=prefix_lens, - position_ids_offsets=position_ids_offsets, - out_cache_loc=out_cache_loc, - out_cache_cont_start=out_cache_cont_start, - out_cache_cont_end=out_cache_cont_end, - return_logprob=return_logprob, + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + prefix_lens=batch.prefix_lens, + position_ids_offsets=batch.position_ids_offsets, + out_cache_loc=batch.out_cache_loc, + out_cache_cont_start=batch.out_cache_cont_start, + out_cache_cont_end=batch.out_cache_cont_end, + return_logprob=batch.return_logprob, + ) + return self.model.forward( + batch.input_ids, input_metadata.positions, input_metadata ) - return self.model.forward(input_ids, input_metadata.positions, input_metadata) @torch.inference_mode() - def forward_extend_multi_modal( - self, - input_ids, - pixel_values, - image_sizes, - image_offsets, - req_pool_indices, - seq_lens, - prefix_lens, - position_ids_offsets, - out_cache_loc, - return_logprob, - ): + def forward_extend_multi_modal(self, batch: Batch): input_metadata = InputMetadata.create( self, forward_mode=ForwardMode.EXTEND, tp_size=self.tp_size, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - prefix_lens=prefix_lens, - position_ids_offsets=position_ids_offsets, - out_cache_loc=out_cache_loc, - return_logprob=return_logprob, + req_pool_indices=batch.req_pool_indices, + seq_lens=batch.seq_lens, + prefix_lens=batch.prefix_lens, + position_ids_offsets=batch.position_ids_offsets, + out_cache_loc=batch.out_cache_loc, + return_logprob=batch.return_logprob, ) return self.model.forward( - input_ids, + batch.input_ids, input_metadata.positions, input_metadata, - pixel_values, - image_sizes, - image_offsets, + batch.pixel_values, + batch.image_sizes, + batch.image_offsets, ) - def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False): + def forward(self, batch: Batch, forward_mode: ForwardMode): if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: - kwargs = { - "input_ids": batch.input_ids, - "pixel_values": batch.pixel_values, - "image_sizes": batch.image_sizes, - "image_offsets": batch.image_offsets, - "req_pool_indices": batch.req_pool_indices, - "seq_lens": batch.seq_lens, - "prefix_lens": batch.prefix_lens, - "position_ids_offsets": batch.position_ids_offsets, - "out_cache_loc": batch.out_cache_loc, - "return_logprob": return_logprob, - } - return self.forward_extend_multi_modal(**kwargs) - else: - kwargs = { - "input_ids": batch.input_ids, - "req_pool_indices": batch.req_pool_indices, - "seq_lens": batch.seq_lens, - "prefix_lens": batch.prefix_lens, - "position_ids_offsets": batch.position_ids_offsets, - "out_cache_loc": batch.out_cache_loc, - "return_logprob": return_logprob, - } - - if forward_mode == ForwardMode.DECODE: - kwargs["out_cache_cont_start"] = batch.out_cache_cont_start - kwargs["out_cache_cont_end"] = batch.out_cache_cont_end - return self.forward_decode(**kwargs) + return self.forward_extend_multi_modal(batch) + elif forward_mode == ForwardMode.DECODE: + return self.forward_decode(batch) elif forward_mode == ForwardMode.EXTEND: - return self.forward_extend(**kwargs) + return self.forward_extend(batch) elif forward_mode == ForwardMode.PREFILL: - return self.forward_prefill(**kwargs) + return self.forward_prefill(batch) else: raise ValueError(f"Invaid forward mode: {forward_mode}")