Simplify logits processor (#2974)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
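This change collapses the separate decode and extend code paths in LogitsProcessor into a single flow. Each forward mode now only selects a pruned_states tensor and, when input logprobs are requested, a sample_indices list marking where each request's last token lands in the pruned states. _get_logits is then called once on pruned_states with logits_metadata passed through, and the next-token logits are recovered as logits[sample_indices]. The input-logprob path also stops hard-coding device="cuda" in favor of input_logprobs.device. A small self-contained illustration of the index bookkeeping follows the last hunk.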
@@ -14,6 +14,7 @@
 """Logits processing."""
 
 import dataclasses
+import logging
 from typing import List, Optional, Union
 
 import torch
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module):
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-            last_index = None
-            last_hidden = hidden_states
-        else:
-            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-            last_hidden = hidden_states[last_index]
-
-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
-        if (
-            not logits_metadata.extend_return_logprob
-            or logits_metadata.capture_hidden_mode.need_capture()
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
         ):
-            # Decode mode or extend mode without return_logprob.
-            return LogitsProcessorOutput(
-                next_token_logits=last_logits,
-                hidden_states=(
-                    hidden_states
-                    if logits_metadata.capture_hidden_mode.is_full()
-                    else (
-                        last_hidden
-                        if logits_metadata.capture_hidden_mode.is_last()
-                        else None
-                    )
-                ),
-            )
+            # Prefill without input logprobs.
+            last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
         else:
             # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
             pt, pruned_states, pruned_input_ids = 0, [], []
             for start_len, extend_len in zip(
                 logits_metadata.extend_logprob_start_lens_cpu,
                 logits_metadata.extend_seq_lens_cpu,
             ):
                 pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
                 pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
                 pt += extend_len
 
-            # Compute the logits of all required tokens
             pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )
+
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
+            # Decode mode or extend mode without return_logprob.
+            return LogitsProcessorOutput(
+                next_token_logits=sampled_logits,
+                hidden_states=(
+                    hidden_states
+                    if logits_metadata.capture_hidden_mode.is_full()
+                    else (
+                        pruned_states
+                        if logits_metadata.capture_hidden_mode.is_last()
+                        else None
+                    )
+                ),
+            )
+        else:
+            input_logprobs = logits
+            del hidden_states, logits
 
             # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
@@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module):
                 input_top_logprobs_val = input_top_logprobs_idx = None
 
             input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                 torch.cat(
                     [
                         torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device="cuda"),
+                        torch.tensor([0], device=input_logprobs.device),
                     ]
                 ),
             ]
 
             return LogitsProcessorOutput(
-                next_token_logits=last_logits,
+                next_token_logits=sampled_logits,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module):
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
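A minimal, self-contained sketch of the two indexing tricks the new code relies on: torch.cumsum(extend_seq_lens, dim=0) - 1 picks each request's last token out of a packed hidden_states tensor, and sample_indices records where each request's last kept token ends up inside the pruned, concatenated states. The sequence lengths, logprob start offsets, and the stand-in lm_head weight below are made up for illustration.

import torch

# Packed batch: three requests with 3, 2, and 4 extend tokens, hidden size 8.
extend_seq_lens = torch.tensor([3, 2, 4])
hidden_states = torch.randn(int(extend_seq_lens.sum()), 8)

# Last token of each request: cumulative lengths minus one -> rows 2, 4, 8.
last_index = torch.cumsum(extend_seq_lens, dim=0) - 1
last_hidden = hidden_states[last_index]  # shape [3, 8]

# When input logprobs are requested from per-request start offsets, keep only
# the requested slices and remember where each request's last token lands in
# the concatenated result.
extend_logprob_start_lens = [1, 0, 2]  # made-up start offsets
pt, pruned, sample_indices, sample_index_pt = 0, [], [], -1
for start_len, extend_len in zip(extend_logprob_start_lens, extend_seq_lens.tolist()):
    pruned.append(hidden_states[pt + start_len : pt + extend_len])
    sample_index_pt += extend_len - start_len
    sample_indices.append(sample_index_pt)
    pt += extend_len
pruned_states = torch.cat(pruned)  # 6 kept rows in this example

# One projection for all kept tokens; the rows at sample_indices are the
# next-token logits for each request (random weight stands in for lm_head).
lm_head_weight = torch.randn(32, 8)  # vocab 32, hidden 8
logits = pruned_states @ lm_head_weight.T
sampled_logits = logits[sample_indices]  # shape [3, 32]
assert torch.allclose(sampled_logits, last_hidden @ lm_head_weight.T)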