Clean up logits processor (#558)
@@ -1,5 +1,8 @@
 """Logits processing."""
 
+import dataclasses
+from typing import List
+
 import torch
 from torch import nn
 from vllm.distributed import (
@@ -10,6 +13,24 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+
+    # The logprob and id of the top-k tokens in prefill positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
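The new dataclass replaces the positional 5-tuple that `forward` used to return, so call sites read named fields instead of remembering slot order. A minimal sketch of how a consumer works against it (the shapes and the caller are illustrative, not from this commit; the `Optional[...]` annotations are my addition, the commit types the fields as plain tensors and lists even though several paths fill them with None):

    import dataclasses
    from typing import List, Optional

    import torch


    @dataclasses.dataclass
    class LogitProcessorOutput:
        next_token_logits: torch.Tensor
        next_token_logprobs: Optional[torch.Tensor]
        normalized_prompt_logprobs: Optional[torch.Tensor]
        prefill_token_logprobs: Optional[torch.Tensor]
        prefill_top_logprobs: Optional[List]
        decode_top_logprobs: Optional[List]


    # The "logprobs not requested" path: only next_token_logits is populated.
    out = LogitProcessorOutput(
        next_token_logits=torch.randn(2, 32000),  # [#seq, vocab_size]
        next_token_logprobs=None,
        normalized_prompt_logprobs=None,
        prefill_token_logprobs=None,
        prefill_top_logprobs=None,
        decode_top_logprobs=None,
    )
    next_tokens = out.next_token_logits.argmax(dim=-1)  # greedy pick, shape [#seq]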
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
             extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
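The `TODO: vectorize` note above flags the per-request Python loop. For the decode branch, where each request contributes exactly one row of `all_logprobs`, a single batched `torch.topk` could replace it. A hedged sketch (the function name is mine, and per-request k handling is simplified to one topk at the maximum k, not the commit's code):

    import torch


    def get_decode_top_logprobs_vectorized(all_logprobs, top_logprobs_nums):
        # all_logprobs: [#seq, vocab_size]; top_logprobs_nums: per-request k.
        max_k = max(top_logprobs_nums)
        values, indices = torch.topk(all_logprobs, max_k, dim=-1)  # [#seq, max_k]
        values, indices = values.tolist(), indices.tolist()
        # Trim each row back to its own k, pairing (logprob, token_id).
        return [
            list(zip(values[i][:k], indices[i][:k]))
            for i, k in enumerate(top_logprobs_nums)
        ]

One large topk plus a cheap Python-side trim launches a single kernel instead of one per request, which is where the loop's overhead comes from.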
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
-            last_hidden = hidden_states
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
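In extend (prefill) mode the hidden states of all requests are packed along one token dimension, so the last token of request `i` sits at `cumsum(extend_seq_lens)[i] - 1`. A small worked example of the indexing above (shapes are illustrative):

    import torch

    extend_seq_lens = torch.tensor([3, 2, 4])  # three requests, 9 packed tokens
    last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
    print(last_index)  # tensor([2, 4, 8]): last-token position of each request

    hidden_states = torch.randn(9, 16)       # [#token, hidden_size]
    last_hidden = hidden_states[last_index]  # [#seq, hidden_size]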
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
             hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
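Before this change, the no-logprob path handed back five Nones that every caller still had to unpack positionally; now a caller can test a single named field. A hypothetical consumer (not from this commit):

    def pick_next_tokens(output):
        # output: LogitProcessorOutput from LogitsProcessor.forward
        if output.next_token_logprobs is None:
            # return_logprob was False: only next_token_logits is populated.
            return output.next_token_logits.argmax(dim=-1)
        return output.next_token_logprobs.argmax(dim=-1)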
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
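The `del all_logits` followed by the in-place `all_logprobs[:] = ...` write is a memory pattern: [#token, vocab_size] tensors dominate here, and dropping the logits reference before the log-softmax allocates its result keeps fewer of them alive at once. The pattern in isolation, assuming `all_logprobs` was created as a float copy of the logits as the surrounding code suggests:

    import torch

    all_logits = torch.randn(9, 32000)  # [#token, vocab_size]
    all_logprobs = all_logits.float()   # may copy; the raw logits are now unneeded
    del all_logits                      # release the reference before the next alloc
    # log_softmax allocates its result, then [:] copies it into the existing
    # buffer, so no extra vocab-sized tensor stays referenced afterwards.
    all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)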
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
-                )
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
             else:
                 # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
-                )
+
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
+                )
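`_get_normalized_prompt_logprobs` itself is not shown in this diff, but the shapes in the dataclass ([#token, vocab_size] prefill logprobs reduced to one scalar per sequence) imply a length-normalized sum of per-token prompt logprobs. A hedged sketch of that idea for a single request (the function name is mine; the real method works over the packed batch via input_metadata):

    import torch


    def normalized_prompt_logprob(prompt_token_logprobs: torch.Tensor) -> torch.Tensor:
        # prompt_token_logprobs: [#prompt_tokens] -- the logprob the model
        # assigned to each prompt token. Dividing by length makes prompts of
        # different sizes comparable (longer prompts accumulate lower sums).
        return prompt_token_logprobs.sum() / prompt_token_logprobs.numel()


    print(normalized_prompt_logprob(torch.tensor([-1.2, -0.3, -2.0])))  # tensor(-1.1667)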