sglang/python/sglang/srt/layers/logits_processor.py
"""Logits processing."""
import dataclasses
from typing import List, Optional, Union
import torch
from torch import nn
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata


@dataclasses.dataclass
class LogitProcessorOutput:
    # The logits of the next tokens. shape: [#seq, vocab_size]
    next_token_logits: torch.Tensor
    # The logprobs of the next tokens. shape: [#seq, vocab_size]
    next_token_logprobs: torch.Tensor

    # The normalized logprobs of prompts. shape: [#seq]
    normalized_prompt_logprobs: torch.Tensor
    # The logprobs of input tokens. shape: [#token, vocab_size]
    input_token_logprobs: torch.Tensor

    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] of Tuple(logprob, token_id)
    input_top_logprobs: List
    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k] of Tuple(logprob, token_id)
    output_top_logprobs: List
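
# For illustration only (made-up values): with two sequences and k = 2,
# output_top_logprobs might look like
#   [[(-0.1, 523), (-2.3, 12)], [(-0.5, 7), (-1.9, 42)]]
# i.e. one list of (logprob, token_id) pairs per sequence, while
# input_top_logprobs carries one such list per *prompt position* per sequence.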


@dataclasses.dataclass
class LogitsMetadata:
    forward_mode: ForwardMode
    return_logprob: bool = False

    extend_seq_lens: Optional[torch.Tensor] = None
    extend_start_loc: Optional[torch.Tensor] = None
    top_logprobs_nums: Optional[List[int]] = None

    @classmethod
    def from_input_metadata(cls, input_metadata: InputMetadata):
        return cls(
            forward_mode=input_metadata.forward_mode,
            extend_seq_lens=input_metadata.extend_seq_lens,
            extend_start_loc=input_metadata.extend_start_loc,
            return_logprob=input_metadata.return_logprob,
            top_logprobs_nums=input_metadata.top_logprobs_nums,
        )


class LogitsProcessor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()

    def _get_normalized_prompt_logprobs(
        self, input_token_logprobs, logits_metadata: LogitsMetadata
    ):
        logprobs_cumsum = torch.cumsum(
            input_token_logprobs, dim=0, dtype=torch.float32
        )

        # Sum each sequence's token logprobs with a cumulative-sum difference,
        # then average over its (seq_len - 1) predicted positions.
        start = logits_metadata.extend_start_loc.clone()
        end = start + logits_metadata.extend_seq_lens - 2
        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
        sum_logp = (
            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
        )
        normalized_prompt_logprobs = sum_logp / (
            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
        )

        return normalized_prompt_logprobs
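
    # Worked example (hypothetical numbers): with input_token_logprobs
    # [a0, a1, b0, b1, b2] for two prompts of extended lengths 2 and 3,
    # extend_start_loc = [0, 2], so start = [0, 2], end = [0, 3], and
    # sum_logp = [a0, b0 + b1]. Each sum covers seq_len - 1 positions:
    # input_token_logprobs[i] is the logprob of input_ids[i + 1], so the entry
    # at each sequence's last position indexes the next sequence's first token
    # (or the zero pad) and is excluded.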

    @staticmethod
    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
        if logits_metadata.forward_mode == ForwardMode.DECODE:
            output_top_logprobs = []
            max_k = max(logits_metadata.top_logprobs_nums)
            ret = all_logprobs.topk(max_k, dim=1)
            values = ret.values.tolist()
            indices = ret.indices.tolist()
            for i, k in enumerate(logits_metadata.top_logprobs_nums):
                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
            return None, output_top_logprobs
        else:
            # TODO: vectorize the code below
            input_top_logprobs, output_top_logprobs = [], []
            pt = 0
            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()

            max_k = max(logits_metadata.top_logprobs_nums)
            ret = all_logprobs.topk(max_k, dim=1)
            values = ret.values.tolist()
            indices = ret.indices.tolist()

            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                if extend_seq_len == 0:
                    input_top_logprobs.append([])
                    output_top_logprobs.append([])
                    continue
                k = logits_metadata.top_logprobs_nums[i]
                input_top_logprobs.append(
                    [
                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
                        for j in range(extend_seq_len - 1)
                    ]
                )
                output_top_logprobs.append(
                    list(
                        zip(
                            values[pt + extend_seq_len - 1][:k],
                            indices[pt + extend_seq_len - 1][:k],
                        )
                    )
                )
                pt += extend_seq_len

            return input_top_logprobs, output_top_logprobs
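
    # For illustration (made-up sizes): in extend mode with
    # extend_seq_lens = [3] and top_logprobs_nums = [2], the first two prompt
    # positions each contribute a top-2 list to input_top_logprobs[0], and the
    # last position's top-2 list becomes output_top_logprobs[0], since that
    # position predicts the first output token.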

    def forward(
        self,
        input_ids,
        hidden_states,
        weight,
        logits_metadata: Union[LogitsMetadata, InputMetadata],
    ):
        if isinstance(logits_metadata, InputMetadata):
            logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
        assert isinstance(logits_metadata, LogitsMetadata)

        # Get the last hidden states and last logits for the next token prediction
        if logits_metadata.forward_mode == ForwardMode.DECODE:
            last_index = None
            last_hidden = hidden_states
        else:
            last_index = (
                torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                - 1
            )
            last_hidden = hidden_states[last_index]

        last_logits = torch.matmul(last_hidden, weight.T)
        if self.tp_size > 1:
            last_logits = tensor_model_parallel_all_gather(last_logits)
        last_logits = last_logits[:, : self.config.vocab_size].float()
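
        # Soft-cap the final logits with tanh so they stay in (-cap, cap):
        # logits = cap * tanh(logits / cap). Models such as Gemma-2 set
        # `final_logit_softcapping` in their config.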
        if hasattr(self.config, "final_logit_softcapping"):
            last_logits /= self.config.final_logit_softcapping
            last_logits = torch.tanh(last_logits)
            last_logits *= self.config.final_logit_softcapping

        # Return only last_logits if logprob is not requested
        if not logits_metadata.return_logprob:
            return LogitProcessorOutput(
                next_token_logits=last_logits,
                next_token_logprobs=None,
                normalized_prompt_logprobs=None,
                input_token_logprobs=None,
                input_top_logprobs=None,
                output_top_logprobs=None,
            )
        else:
            # When logprob is requested, compute the logits for all tokens.
            if logits_metadata.forward_mode == ForwardMode.DECODE:
                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

                # Get the logprob of top-k tokens
                return_top_logprob = any(
                    x > 0 for x in logits_metadata.top_logprobs_nums
                )
                if return_top_logprob:
                    output_top_logprobs = self.get_top_logprobs(
                        last_logprobs, logits_metadata
                    )[1]
                else:
                    output_top_logprobs = None

                return LogitProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
                    normalized_prompt_logprobs=None,
                    input_token_logprobs=None,
                    input_top_logprobs=None,
                    output_top_logprobs=output_top_logprobs,
                )
            else:
                all_logits = torch.matmul(hidden_states, weight.T)
                if self.tp_size > 1:
                    all_logits = tensor_model_parallel_all_gather(all_logits)
                all_logits = all_logits[:, : self.config.vocab_size].float()

                # Reuse the logits buffer for the logprobs so only one
                # [#token, vocab_size] tensor stays alive.
                all_logprobs = all_logits
                del all_logits
                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)

                # Get the logprob of top-k tokens
                return_top_logprob = any(
                    x > 0 for x in logits_metadata.top_logprobs_nums
                )
                if return_top_logprob:
                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                        all_logprobs, logits_metadata
                    )
                else:
                    input_top_logprobs = output_top_logprobs = None

                last_logprobs = all_logprobs[last_index]

                # Compute the logprobs and normalized logprobs for the prefill tokens.
                # Note that we pad a zero at the end of each sequence for easy computation.
                input_token_logprobs = all_logprobs[
                    torch.arange(all_logprobs.shape[0], device="cuda"),
                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                ]

                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
                    input_token_logprobs, logits_metadata
                )

                return LogitProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
                    normalized_prompt_logprobs=normalized_prompt_logprobs,
                    input_token_logprobs=input_token_logprobs,
                    input_top_logprobs=input_top_logprobs,
                    output_top_logprobs=output_top_logprobs,
                )
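

# A minimal usage sketch (hypothetical shapes; assumes a config exposing
# `vocab_size` and a single-GPU setup, with `lm_head_weight` and
# `input_metadata` supplied by the caller):
#
#   processor = LogitsProcessor(config)
#   out = processor(input_ids, hidden_states, lm_head_weight, input_metadata)
#   next_token = out.next_token_logits.argmax(dim=-1)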


def test():
    all_logprobs = torch.tensor(
        # s s s
        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
        dtype=torch.float32,
        device="cuda",
    )
    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")

    token_logprobs = all_logprobs[
        torch.arange(all_logprobs.shape[0], device="cuda"),
        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
    ]
    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)

    len_cumsum = torch.cumsum(seq_lens, dim=0)
    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
    end = start + seq_lens - 2
    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]

    # assert token_logprobs == [2, _, 2, 4, _]
    # ("_" marks each sequence's last position, which does not contribute)
    print("token logprobs", token_logprobs)
    print("start", start)
    print("end", end)
    print("sum_logp", sum_logp)


if __name__ == "__main__":
    test()