add LogitsMetadata (#604)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Logits processing."""
|
||||
|
||||
import dataclasses
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -31,6 +31,27 @@ class LogitProcessorOutput:
|
||||
decode_top_logprobs: List
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class LogitsMetadata:
|
||||
forward_mode: ForwardMode
|
||||
extend_seq_lens: torch.Tensor
|
||||
extend_start_loc: torch.Tensor
|
||||
|
||||
# For logprobs
|
||||
return_logprob: bool
|
||||
top_logprobs_nums: List[int]
|
||||
|
||||
@classmethod
|
||||
def from_input_metadata(cls, input_metadata: InputMetadata):
|
||||
return cls(
|
||||
forward_mode=input_metadata.forward_mode,
|
||||
extend_seq_lens=input_metadata.extend_seq_lens,
|
||||
extend_start_loc=input_metadata.extend_start_loc,
|
||||
return_logprob=input_metadata.return_logprob,
|
||||
top_logprobs_nums=input_metadata.top_logprobs_nums,
|
||||
)
|
||||
|
||||
|
||||
class LogitsProcessor(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
@@ -38,14 +59,14 @@ class LogitsProcessor(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def _get_normalized_prompt_logprobs(
|
||||
self, prefill_token_logprobs, input_metadata: InputMetadata
|
||||
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
|
||||
):
|
||||
logprobs_cumsum = torch.cumsum(
|
||||
prefill_token_logprobs, dim=0, dtype=torch.float32
|
||||
)
|
||||
|
||||
start = input_metadata.extend_start_loc.clone()
|
||||
end = start + input_metadata.extend_seq_lens - 2
|
||||
start = logits_metadata.extend_start_loc.clone()
|
||||
end = start + logits_metadata.extend_seq_lens - 2
|
||||
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||
sum_logp = (
|
||||
@@ -54,17 +75,17 @@ class LogitsProcessor(nn.Module):
|
||||
+ prefill_token_logprobs[start]
|
||||
)
|
||||
normalized_prompt_logprobs = sum_logp / (
|
||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
)
|
||||
|
||||
return normalized_prompt_logprobs
|
||||
|
||||
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
|
||||
def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
|
||||
# TODO: vectorize the code below
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
decode_top_logprobs = []
|
||||
for i in range(all_logprobs.shape[0]):
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
k = logits_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[i].topk(k)
|
||||
v_cpu = t.values.tolist()
|
||||
p_cpu = t.indices.tolist()
|
||||
@@ -73,13 +94,13 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||
pt = 0
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||
if extend_seq_len == 0:
|
||||
prefill_top_logprobs.append([])
|
||||
decode_top_logprobs.append([])
|
||||
continue
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
k = logits_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||
vs_cpu = t.values.tolist()
|
||||
ps_cpu = t.indices.tolist()
|
||||
@@ -91,14 +112,24 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
return prefill_top_logprobs, decode_top_logprobs
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
hidden_states,
|
||||
weight,
|
||||
logits_metadata: Union[LogitsMetadata, InputMetadata],
|
||||
):
|
||||
if isinstance(logits_metadata, InputMetadata):
|
||||
logits_metadata = LogitsMetadata.from_input_metadata(logits_metadata)
|
||||
assert isinstance(logits_metadata, LogitsMetadata)
|
||||
|
||||
# Get the last hidden states and last logits for the next token prediction
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_index = None
|
||||
last_hidden = hidden_states
|
||||
else:
|
||||
last_index = (
|
||||
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
||||
torch.cumsum(logits_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
||||
- 1
|
||||
)
|
||||
last_hidden = hidden_states[last_index]
|
||||
@@ -114,7 +145,7 @@ class LogitsProcessor(nn.Module):
|
||||
last_logits *= self.config.final_logit_softcapping
|
||||
|
||||
# Return only last_logits if logprob is not requested
|
||||
if not input_metadata.return_logprob:
|
||||
if not logits_metadata.return_logprob:
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=None,
|
||||
@@ -125,7 +156,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
else:
|
||||
# When logprob is requested, compute the logits for all tokens.
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
all_logits = last_logits
|
||||
else:
|
||||
all_logits = torch.matmul(hidden_states, weight.T)
|
||||
@@ -138,15 +169,15 @@ class LogitsProcessor(nn.Module):
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
# Get the logprob of top-k tokens
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||
all_logprobs, input_metadata
|
||||
all_logprobs, logits_metadata
|
||||
)
|
||||
else:
|
||||
prefill_top_logprobs = decode_top_logprobs = None
|
||||
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=all_logprobs,
|
||||
@@ -166,7 +197,7 @@ class LogitsProcessor(nn.Module):
|
||||
]
|
||||
|
||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||
prefill_token_logprobs, input_metadata
|
||||
prefill_token_logprobs, logits_metadata
|
||||
)
|
||||
|
||||
return LogitProcessorOutput(
|
||||
|
||||
Reference in New Issue
Block a user