diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index f38397212..e9ed66f4f 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -1,7 +1,7 @@
 """Logits processing."""
 
 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union
 
 import torch
 from torch import nn
@@ -34,11 +34,11 @@ class LogitProcessorOutput:
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False
 
-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -79,7 +79,8 @@ class LogitsProcessor(nn.Module):
 
         return normalized_prompt_logprobs
 
-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
@@ -156,36 +157,48 @@ class LogitsProcessor(nn.Module):
         else:
             # When logprob is requested, compute the logits for all tokens.
             if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size].float()
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)
 
-            all_logprobs = all_logits
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
-
-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                 )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
+                if return_top_logprob:
+                    decode_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    decode_top_logprobs = None
 
-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                 return LogitProcessorOutput(
                     next_token_logits=last_logits,
-                    next_token_logprobs=all_logprobs,
+                    next_token_logprobs=last_logprobs,
                     normalized_prompt_logprobs=None,
                     prefill_token_logprobs=None,
                     prefill_top_logprobs=None,
                     decode_top_logprobs=decode_top_logprobs,
                 )
             else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                )
+                if return_top_logprob:
+                    prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    prefill_top_logprobs = decode_top_logprobs = None
+
+                last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
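
Note on the logits_processor.py change above: the single log-softmax-over-everything path is split in two. In decode mode the logprobs now come straight from the already-computed last-token logits, skipping the full matmul against the vocabulary weights, while the prefill path keeps the all-token computation. Below is a minimal sketch of the new decode path; the shapes and the `top_logprobs_nums` values are illustrative, and the top-k loop is only in the spirit of `get_top_logprobs`, not the actual sglang code.

import torch

batch_size, vocab_size = 4, 32000                  # illustrative sizes
last_logits = torch.randn(batch_size, vocab_size)  # one row per running request

# Decode mode skips the full-sequence matmul: log_softmax over the
# already-available last-token logits is sufficient.
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

# Per-request top-k extraction, in the spirit of get_top_logprobs'
# decode branch (k == 0 means the request did not ask for top logprobs).
top_logprobs_nums = [5, 0, 3, 5]                   # hypothetical per-request k
decode_top_logprobs = []
for i, k in enumerate(top_logprobs_nums):
    if k > 0:
        values, indices = last_logprobs[i].topk(k)
        decode_top_logprobs.append(list(zip(values.tolist(), indices.tolist())))
    else:
        decode_top_logprobs.append(None)

Turning `_get_top_logprobs` into the public `@staticmethod get_top_logprobs` is what lets the CUDA graph runner below call it without holding a LogitsProcessor instance.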
diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py
index 1095481ee..2bdb33cff 100644
--- a/python/sglang/srt/managers/controller/cuda_graph_runner.py
+++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py
@@ -9,7 +9,11 @@
 from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
     ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:
 
     def replay(self, batch: Batch):
         assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
         raw_bs = len(batch.reqs)
 
         # Pad
@@ -218,23 +221,29 @@
         output = self.output_buffers[bs]
 
         # Unpad
-        if bs == raw_bs:
-            return output
-        else:
+        if bs != raw_bs:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=(
-                    output.next_token_logprobs[:raw_bs]
-                    if output.next_token_logprobs is not None
-                    else None
-                ),
+                next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=(
-                    output.decode_top_logprobs[:raw_bs]
-                    if output.decode_top_logprobs is not None
-                    else None
-                ),
+                decode_top_logprobs=None,
             )
+
+        # Extract logprobs
+        if batch.return_logprob:
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
+            )
+            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            if return_top_logprob:
+                logits_metadata = LogitsMetadata(
+                    forward_mode=ForwardMode.DECODE,
+                    top_logprobs_nums=batch.top_logprobs_nums,
+                )
+                output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
+                )[1]
+
         return output
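
Note on the cuda_graph_runner.py change above: dropping `assert not batch.return_logprob` means replay no longer rejects logprob requests. The captured graph still produces only logits; logprob extraction runs eagerly after replay, outside the graph, so per-request options like `top_logprobs_nums` never have to be baked into the capture. A hedged sketch of that post-replay step follows, with stub inputs standing in for the real Batch and output buffers (nothing here is the actual runner code).

import torch

# Stub inputs: in the real runner these come from the replayed graph's
# output buffers and from the Batch object.
next_token_logits = torch.randn(2, 32000)
return_logprob = True
top_logprobs_nums = [3, 0]

next_token_logprobs = None
decode_top_logprobs = None
if return_logprob:
    # Runs eagerly after graph replay, on the unpadded logits.
    next_token_logprobs = torch.nn.functional.log_softmax(
        next_token_logits, dim=-1
    )
    if any(k > 0 for k in top_logprobs_nums):
        decode_top_logprobs = [
            logprobs.topk(k) if k > 0 else None
            for logprobs, k in zip(next_token_logprobs, top_logprobs_nums)
        ]

Keeping this step out of the graph is the design point: CUDA graphs replay a fixed kernel sequence, and the top-k work varies per request, so it belongs in eager mode after the replayed logits are unpadded.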