Simplify logits processor (#2974)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-01-18 23:03:49 -08:00
committed by GitHub
parent 93b77c8e8a
commit 23196d5254

View File

@@ -14,6 +14,7 @@
"""Logits processing.""" """Logits processing."""
import dataclasses import dataclasses
import logging
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode, ForwardMode,
) )
logger = logging.getLogger(__name__)
@dataclasses.dataclass @dataclasses.dataclass
class LogitsProcessorOutput: class LogitsProcessorOutput:
@@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module):
logits_metadata.forward_mode.is_decode_or_idle() logits_metadata.forward_mode.is_decode_or_idle()
or logits_metadata.forward_mode.is_target_verify() or logits_metadata.forward_mode.is_target_verify()
): ):
last_index = None pruned_states = hidden_states
last_hidden = hidden_states sample_indices = None
else: elif (
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 logits_metadata.forward_mode.is_extend()
last_hidden = hidden_states[last_index] and not logits_metadata.extend_return_logprob
# Compute logits
last_logits = self._get_logits(last_hidden, lm_head)
if (
not logits_metadata.extend_return_logprob
or logits_metadata.capture_hidden_mode.need_capture()
): ):
# Decode mode or extend mode without return_logprob. # Prefill without input logprobs.
return LogitsProcessorOutput( last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
next_token_logits=last_logits, pruned_states = hidden_states[last_index]
hidden_states=( sample_indices = None
hidden_states
if logits_metadata.capture_hidden_mode.is_full()
else (
last_hidden
if logits_metadata.capture_hidden_mode.is_last()
else None
)
),
)
else: else:
# Slice the requested tokens to compute logprob # Slice the requested tokens to compute logprob
sample_index_pt = -1
sample_indices = []
pt, pruned_states, pruned_input_ids = 0, [], [] pt, pruned_states, pruned_input_ids = 0, [], []
for start_len, extend_len in zip( for start_len, extend_len in zip(
logits_metadata.extend_logprob_start_lens_cpu, logits_metadata.extend_logprob_start_lens_cpu,
logits_metadata.extend_seq_lens_cpu, logits_metadata.extend_seq_lens_cpu,
): ):
pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
sample_index_pt += extend_len - start_len
sample_indices.append(sample_index_pt)
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
pt += extend_len pt += extend_len
# Compute the logits of all required tokens
pruned_states = torch.cat(pruned_states) pruned_states = torch.cat(pruned_states)
del hidden_states
input_token_logits = self._get_logits(pruned_states, lm_head) # Compute logits for both input and sampled tokens.
del pruned_states logits = self._get_logits(pruned_states, lm_head, logits_metadata)
sampled_logits = (
logits[sample_indices] if sample_indices is not None else logits
)
if (
not logits_metadata.extend_return_logprob
or logits_metadata.capture_hidden_mode.need_capture()
):
# Decode mode or extend mode without return_logprob.
return LogitsProcessorOutput(
next_token_logits=sampled_logits,
hidden_states=(
hidden_states
if logits_metadata.capture_hidden_mode.is_full()
else (
pruned_states
if logits_metadata.capture_hidden_mode.is_last()
else None
)
),
)
else:
input_logprobs = logits
del hidden_states, logits
# Normalize the logprob w/o temperature, top-p # Normalize the logprob w/o temperature, top-p
input_logprobs = input_token_logits
input_logprobs = self.compute_temp_top_p_normalized_logprobs( input_logprobs = self.compute_temp_top_p_normalized_logprobs(
input_logprobs, logits_metadata input_logprobs, logits_metadata
) )
@@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module):
input_top_logprobs_val = input_top_logprobs_idx = None input_top_logprobs_val = input_top_logprobs_idx = None
input_token_logprobs = input_logprobs[ input_token_logprobs = input_logprobs[
torch.arange(input_logprobs.shape[0], device="cuda"), torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
torch.cat( torch.cat(
[ [
torch.cat(pruned_input_ids)[1:], torch.cat(pruned_input_ids)[1:],
torch.tensor([0], device="cuda"), torch.tensor([0], device=input_logprobs.device),
] ]
), ),
] ]
return LogitsProcessorOutput( return LogitsProcessorOutput(
next_token_logits=last_logits, next_token_logits=sampled_logits,
input_token_logprobs=input_token_logprobs, input_token_logprobs=input_token_logprobs,
input_top_logprobs_val=input_top_logprobs_val, input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx, input_top_logprobs_idx=input_top_logprobs_idx,
@@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module):
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding, lm_head: VocabParallelEmbedding,
logits_metadata: LogitsMetadata,
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Get logits from hidden_states."""
if hasattr(lm_head, "weight"): if hasattr(lm_head, "weight"):
logits = torch.matmul(hidden_states, lm_head.weight.T) logits = torch.matmul(hidden_states, lm_head.weight.T)
else: else: