Clean up logits processor (#558)

Author: Lianmin Zheng
Date: 2024-06-22 00:25:24 -07:00
Committed by: GitHub
Parent: 92cb93f390
Commit: 303ef8883e
3 changed files with 130 additions and 116 deletions


@@ -1,5 +1,8 @@
"""Logits processing."""
import dataclasses
from typing import List
import torch
from torch import nn
from vllm.distributed import (
@@ -10,6 +13,24 @@ from vllm.distributed import (
 from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
 
 
+@dataclasses.dataclass
+class LogitProcessorOutput:
+    # The logits of the next tokens. shape: [#seq, vocab_size]
+    next_token_logits: torch.Tensor
+    # The logprobs of the next tokens. shape: [#seq, vocab_size]
+    next_token_logprobs: torch.Tensor
+    # The normalized logprobs of prompts. shape: [#seq]
+    normalized_prompt_logprobs: torch.Tensor
+    # The logprobs of prefill tokens. shape: [#token, vocab_size]
+    prefill_token_logprobs: torch.Tensor
+    # The logprob and id of the top-k tokens in prefill positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
+    prefill_top_logprobs: List
+    # The logprob and id of the top-k tokens in decode positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
+    decode_top_logprobs: List
+
+
 class LogitsProcessor(nn.Module):
     def __init__(self, config):
         super().__init__()
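
For orientation (not part of this commit): with the dataclass in place, a caller reads results by field name instead of unpacking the old `(logits, 5-tuple)` return. The values below are fabricated purely for illustration.

import torch

# Illustration only: construct the dataclass with dummy values for two
# sequences and a vocab of size 4, as the no-logprob path now does.
output = LogitProcessorOutput(
    next_token_logits=torch.randn(2, 4),
    next_token_logprobs=None,
    normalized_prompt_logprobs=None,
    prefill_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
)

# Named access cannot be silently broken by reordering, unlike tuple indices.
next_tokens = torch.argmax(output.next_token_logits, dim=-1)  # shape: [2]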
@@ -39,6 +60,7 @@ class LogitsProcessor(nn.Module):
         return normalized_prompt_logprobs
 
     def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        # TODO: vectorize the code below
         if input_metadata.forward_mode == ForwardMode.DECODE:
             decode_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
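
The new TODO above flags this per-sequence loop; one plausible shape of the vectorization (a sketch under the assumption that per-request k values can be served from a single batched top-k at the maximum k — not code from this commit):

import torch

def topk_decode_logprobs_vectorized(all_logprobs, top_logprobs_nums):
    # One batched top-k at the largest requested k, then trim per request.
    # This trades a little extra compute for a single kernel launch and a
    # single device-to-host copy instead of one per sequence.
    max_k = max(top_logprobs_nums)
    values, indices = torch.topk(all_logprobs, max_k, dim=-1)
    values, indices = values.tolist(), indices.tolist()
    return [
        list(zip(values[i][:k], indices[i][:k]))
        for i, k in enumerate(top_logprobs_nums)
    ]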
@@ -51,7 +73,6 @@ class LogitsProcessor(nn.Module):
         else:
             prefill_top_logprobs, decode_top_logprobs = [], []
             pt = 0
-            # NOTE: the GPU-CPU overhead can be reduced
             extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
@@ -71,18 +92,15 @@ class LogitsProcessor(nn.Module):
         return prefill_top_logprobs, decode_top_logprobs
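
To make the prefill bookkeeping concrete, a toy trace of the `pt` pointer walk with illustrative lengths:

# Three requests with extend lengths [3, 2, 4] occupy rows 0-2, 3-4, and 5-8
# of the flattened [#token, vocab_size] logprob tensor.
extend_seq_lens_cpu = [3, 2, 4]
pt = 0
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
    print(f"request {i}: rows {pt}..{pt + extend_seq_len - 1}")
    pt += extend_seq_len
# request 0: rows 0..2
# request 1: rows 3..4
# request 2: rows 5..8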
 
     def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
-        # Get last index for next token prediction, except for DECODE mode.
-        last_index = None
-        if input_metadata.forward_mode != ForwardMode.DECODE:
+        # Get the last hidden states and last logits for the next token prediction
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_index = None
+            last_hidden = hidden_states
+        else:
             last_index = (
                 torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
-
-        # Get the last hidden states and last logits
-        if input_metadata.forward_mode == ForwardMode.DECODE:
-            last_hidden = hidden_states
-        else:
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
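
A small worked example of the `last_index` arithmetic above, with illustrative lengths:

import torch

# With extend lengths [3, 2, 4], the flattened hidden_states has 9 rows and
# the last token of each request sits at cumsum - 1 = [2, 4, 8].
extend_seq_lens = torch.tensor([3, 2, 4])
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([2, 4, 8])

hidden_states = torch.randn(9, 16)       # [#token, hidden_size]
last_hidden = hidden_states[last_index]  # [#seq, hidden_size]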
@@ -92,8 +110,14 @@ class LogitsProcessor(nn.Module):
         # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-            hidden_states = None
-            return last_logits, (None, None, None, None, None)
+            return LogitProcessorOutput(
+                next_token_logits=last_logits,
+                next_token_logprobs=None,
+                normalized_prompt_logprobs=None,
+                prefill_token_logprobs=None,
+                prefill_top_logprobs=None,
+                decode_top_logprobs=None,
+            )
         else:
             # When logprob is requested, compute the logits for all tokens.
             if input_metadata.forward_mode == ForwardMode.DECODE:
@@ -108,6 +132,7 @@ class LogitsProcessor(nn.Module):
             del all_logits
             all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
 
+            # Get the logprob of top-k tokens
             return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
             if return_top_logprob:
                 prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
@@ -117,16 +142,15 @@ class LogitsProcessor(nn.Module):
                 prefill_top_logprobs = decode_top_logprobs = None
 
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logprobs = all_logprobs
-                return last_logits, (
-                    None,
-                    None,
-                    None,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=all_logprobs,
+                    normalized_prompt_logprobs=None,
+                    prefill_token_logprobs=None,
+                    prefill_top_logprobs=None,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
             else:
                 # Compute the logprobs for the last token of each request.
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
@@ -139,12 +163,14 @@ class LogitsProcessor(nn.Module):
                 normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                     prefill_token_logprobs, input_metadata
                 )
 
-                return last_logits, (
-                    prefill_token_logprobs,
-                    normalized_prompt_logprobs,
-                    prefill_top_logprobs,
-                    decode_top_logprobs,
-                    last_logprobs,
+                return LogitProcessorOutput(
+                    next_token_logits=last_logits,
+                    next_token_logprobs=last_logprobs,
+                    normalized_prompt_logprobs=normalized_prompt_logprobs,
+                    prefill_token_logprobs=prefill_token_logprobs,
+                    prefill_top_logprobs=prefill_top_logprobs,
+                    decode_top_logprobs=decode_top_logprobs,
                 )
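
`_get_normalized_prompt_logprobs` itself is not shown in this diff. Conceptually it produces a length-normalized prompt score, roughly the mean per-token logprob; the sketch below illustrates the idea and may differ from the real implementation in details such as which positions are counted:

import torch

def normalized_prompt_logprob(token_logprobs: torch.Tensor) -> float:
    # token_logprobs: the logprob assigned to each actual prompt token.
    # One scalar per sequence: total logprob divided by tokens scored.
    return (token_logprobs.sum() / token_logprobs.numel()).item()

print(normalized_prompt_logprob(torch.tensor([-0.5, -1.0, -1.5])))  # -1.0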