Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:50:34 -07:00
committed by GitHub
parent 0a409bd438
commit 30db99b3d9
16 changed files with 188 additions and 184 deletions

View File

@@ -22,13 +22,13 @@ class LogitProcessorOutput:
# The normlaized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs: torch.Tensor
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs: List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs: List
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List
@dataclasses.dataclass
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
self, input_token_logprobs, logits_metadata: LogitsMetadata
):
logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32
)
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = logits_metadata.extend_start_loc.clone()
end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_token_logprobs[start]
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
output_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, output_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
input_top_logprobs, output_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
input_top_logprobs.append([])
output_top_logprobs.append([])
continue
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len
return prefill_top_logprobs, decode_top_logprobs
return input_top_logprobs, output_top_logprobs
def forward(
self,
@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
else:
# When logprob is requested, compute the logits for all tokens.
@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
decode_top_logprobs = self.get_top_logprobs(
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
decode_top_logprobs = None
output_top_logprobs = None
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=output_top_logprobs,
)
else:
all_logits = torch.matmul(hidden_states, weight.T)
@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None
input_top_logprobs = output_top_logprobs = None
last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_token_logprobs = all_logprobs[
input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, logits_metadata
input_token_logprobs, logits_metadata
)
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_top_logprobs=output_top_logprobs,
)