Simplify logits processor (#2974)
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@
|
|||||||
"""Logits processing."""
|
"""Logits processing."""
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import logging
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
ForwardMode,
|
ForwardMode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LogitsProcessorOutput:
|
class LogitsProcessorOutput:
|
||||||
@@ -136,50 +139,61 @@ class LogitsProcessor(nn.Module):
|
|||||||
logits_metadata.forward_mode.is_decode_or_idle()
|
logits_metadata.forward_mode.is_decode_or_idle()
|
||||||
or logits_metadata.forward_mode.is_target_verify()
|
or logits_metadata.forward_mode.is_target_verify()
|
||||||
):
|
):
|
||||||
last_index = None
|
pruned_states = hidden_states
|
||||||
last_hidden = hidden_states
|
sample_indices = None
|
||||||
else:
|
elif (
|
||||||
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
logits_metadata.forward_mode.is_extend()
|
||||||
last_hidden = hidden_states[last_index]
|
and not logits_metadata.extend_return_logprob
|
||||||
|
|
||||||
# Compute logits
|
|
||||||
last_logits = self._get_logits(last_hidden, lm_head)
|
|
||||||
if (
|
|
||||||
not logits_metadata.extend_return_logprob
|
|
||||||
or logits_metadata.capture_hidden_mode.need_capture()
|
|
||||||
):
|
):
|
||||||
# Decode mode or extend mode without return_logprob.
|
# Prefill without input logprobs.
|
||||||
return LogitsProcessorOutput(
|
last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
|
||||||
next_token_logits=last_logits,
|
pruned_states = hidden_states[last_index]
|
||||||
hidden_states=(
|
sample_indices = None
|
||||||
hidden_states
|
|
||||||
if logits_metadata.capture_hidden_mode.is_full()
|
|
||||||
else (
|
|
||||||
last_hidden
|
|
||||||
if logits_metadata.capture_hidden_mode.is_last()
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# Slice the requested tokens to compute logprob
|
# Slice the requested tokens to compute logprob
|
||||||
|
sample_index_pt = -1
|
||||||
|
sample_indices = []
|
||||||
pt, pruned_states, pruned_input_ids = 0, [], []
|
pt, pruned_states, pruned_input_ids = 0, [], []
|
||||||
for start_len, extend_len in zip(
|
for start_len, extend_len in zip(
|
||||||
logits_metadata.extend_logprob_start_lens_cpu,
|
logits_metadata.extend_logprob_start_lens_cpu,
|
||||||
logits_metadata.extend_seq_lens_cpu,
|
logits_metadata.extend_seq_lens_cpu,
|
||||||
):
|
):
|
||||||
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
|
||||||
|
sample_index_pt += extend_len - start_len
|
||||||
|
sample_indices.append(sample_index_pt)
|
||||||
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
|
||||||
pt += extend_len
|
pt += extend_len
|
||||||
|
|
||||||
# Compute the logits of all required tokens
|
|
||||||
pruned_states = torch.cat(pruned_states)
|
pruned_states = torch.cat(pruned_states)
|
||||||
del hidden_states
|
|
||||||
input_token_logits = self._get_logits(pruned_states, lm_head)
|
# Compute logits for both input and sampled tokens.
|
||||||
del pruned_states
|
logits = self._get_logits(pruned_states, lm_head, logits_metadata)
|
||||||
|
sampled_logits = (
|
||||||
|
logits[sample_indices] if sample_indices is not None else logits
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
not logits_metadata.extend_return_logprob
|
||||||
|
or logits_metadata.capture_hidden_mode.need_capture()
|
||||||
|
):
|
||||||
|
# Decode mode or extend mode without return_logprob.
|
||||||
|
return LogitsProcessorOutput(
|
||||||
|
next_token_logits=sampled_logits,
|
||||||
|
hidden_states=(
|
||||||
|
hidden_states
|
||||||
|
if logits_metadata.capture_hidden_mode.is_full()
|
||||||
|
else (
|
||||||
|
pruned_states
|
||||||
|
if logits_metadata.capture_hidden_mode.is_last()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
input_logprobs = logits
|
||||||
|
del hidden_states, logits
|
||||||
|
|
||||||
# Normalize the logprob w/o temperature, top-p
|
# Normalize the logprob w/o temperature, top-p
|
||||||
input_logprobs = input_token_logits
|
|
||||||
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
input_logprobs = self.compute_temp_top_p_normalized_logprobs(
|
||||||
input_logprobs, logits_metadata
|
input_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
@@ -194,17 +208,17 @@ class LogitsProcessor(nn.Module):
|
|||||||
input_top_logprobs_val = input_top_logprobs_idx = None
|
input_top_logprobs_val = input_top_logprobs_idx = None
|
||||||
|
|
||||||
input_token_logprobs = input_logprobs[
|
input_token_logprobs = input_logprobs[
|
||||||
torch.arange(input_logprobs.shape[0], device="cuda"),
|
torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
|
||||||
torch.cat(
|
torch.cat(
|
||||||
[
|
[
|
||||||
torch.cat(pruned_input_ids)[1:],
|
torch.cat(pruned_input_ids)[1:],
|
||||||
torch.tensor([0], device="cuda"),
|
torch.tensor([0], device=input_logprobs.device),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
return LogitsProcessorOutput(
|
return LogitsProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=sampled_logits,
|
||||||
input_token_logprobs=input_token_logprobs,
|
input_token_logprobs=input_token_logprobs,
|
||||||
input_top_logprobs_val=input_top_logprobs_val,
|
input_top_logprobs_val=input_top_logprobs_val,
|
||||||
input_top_logprobs_idx=input_top_logprobs_idx,
|
input_top_logprobs_idx=input_top_logprobs_idx,
|
||||||
@@ -214,8 +228,11 @@ class LogitsProcessor(nn.Module):
|
|||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
lm_head: VocabParallelEmbedding,
|
lm_head: VocabParallelEmbedding,
|
||||||
|
logits_metadata: LogitsMetadata,
|
||||||
embedding_bias: Optional[torch.Tensor] = None,
|
embedding_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""Get logits from hidden_states."""
|
||||||
|
|
||||||
if hasattr(lm_head, "weight"):
|
if hasattr(lm_head, "weight"):
|
||||||
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
logits = torch.matmul(hidden_states, lm_head.weight.T)
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user