diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 10f264677..e5794f052 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -14,6 +14,7 @@ """Logits processing.""" import dataclasses +import logging from typing import List, Optional, Union import torch @@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, ) +logger = logging.getLogger(__name__) + @dataclasses.dataclass class LogitsProcessorOutput: @@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module): logits_metadata.forward_mode.is_decode_or_idle() or logits_metadata.forward_mode.is_target_verify() ): - last_index = None - last_hidden = hidden_states - else: - last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - last_hidden = hidden_states[last_index] - - # Compute logits - last_logits = self._get_logits(last_hidden, lm_head) - if ( - not logits_metadata.extend_return_logprob - or logits_metadata.capture_hidden_mode.need_capture() + pruned_states = hidden_states + sample_indices = None + elif ( + logits_metadata.forward_mode.is_extend() + and not logits_metadata.extend_return_logprob ): - # Decode mode or extend mode without return_logprob. - return LogitsProcessorOutput( - next_token_logits=last_logits, - hidden_states=( - hidden_states - if logits_metadata.capture_hidden_mode.is_full() - else ( - last_hidden - if logits_metadata.capture_hidden_mode.is_last() - else None - ) - ), - ) + # Prefill without input logprobs. + last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 + pruned_states = hidden_states[last_index] + sample_indices = None else: # Slice the requested tokens to compute logprob + sample_index_pt = -1 + sample_indices = [] pt, pruned_states, pruned_input_ids = 0, [], [] for start_len, extend_len in zip( logits_metadata.extend_logprob_start_lens_cpu, logits_metadata.extend_seq_lens_cpu, ): pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) + sample_index_pt += extend_len - start_len + sample_indices.append(sample_index_pt) pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) pt += extend_len - # Compute the logits of all required tokens pruned_states = torch.cat(pruned_states) - del hidden_states - input_token_logits = self._get_logits(pruned_states, lm_head) - del pruned_states + + # Compute logits for both input and sampled tokens. + logits = self._get_logits(pruned_states, lm_head, logits_metadata) + sampled_logits = ( + logits[sample_indices] if sample_indices is not None else logits + ) + + if ( + not logits_metadata.extend_return_logprob + or logits_metadata.capture_hidden_mode.need_capture() + ): + # Decode mode or extend mode without return_logprob. + return LogitsProcessorOutput( + next_token_logits=sampled_logits, + hidden_states=( + hidden_states + if logits_metadata.capture_hidden_mode.is_full() + else ( + pruned_states + if logits_metadata.capture_hidden_mode.is_last() + else None + ) + ), + ) + else: + input_logprobs = logits + del hidden_states, logits # Normalize the logprob w/o temperature, top-p - input_logprobs = input_token_logits input_logprobs = self.compute_temp_top_p_normalized_logprobs( input_logprobs, logits_metadata ) @@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module): input_top_logprobs_val = input_top_logprobs_idx = None input_token_logprobs = input_logprobs[ - torch.arange(input_logprobs.shape[0], device="cuda"), + torch.arange(input_logprobs.shape[0], device=input_logprobs.device), torch.cat( [ torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device="cuda"), + torch.tensor([0], device=input_logprobs.device), ] ), ] return LogitsProcessorOutput( - next_token_logits=last_logits, + next_token_logits=sampled_logits, input_token_logprobs=input_token_logprobs, input_top_logprobs_val=input_top_logprobs_val, input_top_logprobs_idx=input_top_logprobs_idx, @@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module): self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, + logits_metadata: LogitsMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Get logits from hidden_states.""" + if hasattr(lm_head, "weight"): logits = torch.matmul(hidden_states, lm_head.weight.T) else: