Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
This commit is contained in:
@@ -22,13 +22,13 @@ class LogitProcessorOutput:
|
||||
|
||||
# The normlaized logprobs of prompts. shape: [#seq]
|
||||
normalized_prompt_logprobs: torch.Tensor
|
||||
# The logprobs of prefill tokens. shape: [#token, vocab_size]
|
||||
prefill_token_logprobs: torch.Tensor
|
||||
# The logprobs of input tokens. shape: [#token, vocab_size]
|
||||
input_token_logprobs: torch.Tensor
|
||||
|
||||
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||
prefill_top_logprobs: List
|
||||
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||
decode_top_logprobs: List
|
||||
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||
input_top_logprobs: List
|
||||
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||
output_top_logprobs: List
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
def _get_normalized_prompt_logprobs(
|
||||
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
|
||||
self, input_token_logprobs, logits_metadata: LogitsMetadata
|
||||
):
|
||||
logprobs_cumsum = torch.cumsum(
|
||||
prefill_token_logprobs, dim=0, dtype=torch.float32
|
||||
)
|
||||
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
|
||||
|
||||
start = logits_metadata.extend_start_loc.clone()
|
||||
end = start + logits_metadata.extend_seq_lens - 2
|
||||
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
||||
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
||||
sum_logp = (
|
||||
logprobs_cumsum[end]
|
||||
- logprobs_cumsum[start]
|
||||
+ prefill_token_logprobs[start]
|
||||
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
|
||||
)
|
||||
normalized_prompt_logprobs = sum_logp / (
|
||||
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||
@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
|
||||
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
||||
# TODO: vectorize the code below
|
||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||
decode_top_logprobs = []
|
||||
output_top_logprobs = []
|
||||
for i in range(all_logprobs.shape[0]):
|
||||
k = logits_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[i].topk(k)
|
||||
v_cpu = t.values.tolist()
|
||||
p_cpu = t.indices.tolist()
|
||||
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||
return None, decode_top_logprobs
|
||||
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||
return None, output_top_logprobs
|
||||
else:
|
||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||
input_top_logprobs, output_top_logprobs = [], []
|
||||
pt = 0
|
||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||
if extend_seq_len == 0:
|
||||
prefill_top_logprobs.append([])
|
||||
decode_top_logprobs.append([])
|
||||
input_top_logprobs.append([])
|
||||
output_top_logprobs.append([])
|
||||
continue
|
||||
k = logits_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||
vs_cpu = t.values.tolist()
|
||||
ps_cpu = t.indices.tolist()
|
||||
prefill_top_logprobs.append(
|
||||
input_top_logprobs.append(
|
||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||
)
|
||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||
pt += extend_seq_len
|
||||
|
||||
return prefill_top_logprobs, decode_top_logprobs
|
||||
return input_top_logprobs, output_top_logprobs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=None,
|
||||
normalized_prompt_logprobs=None,
|
||||
prefill_token_logprobs=None,
|
||||
prefill_top_logprobs=None,
|
||||
decode_top_logprobs=None,
|
||||
input_token_logprobs=None,
|
||||
input_top_logprobs=None,
|
||||
output_top_logprobs=None,
|
||||
)
|
||||
else:
|
||||
# When logprob is requested, compute the logits for all tokens.
|
||||
@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
|
||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||
)
|
||||
if return_top_logprob:
|
||||
decode_top_logprobs = self.get_top_logprobs(
|
||||
output_top_logprobs = self.get_top_logprobs(
|
||||
last_logprobs, logits_metadata
|
||||
)[1]
|
||||
else:
|
||||
decode_top_logprobs = None
|
||||
output_top_logprobs = None
|
||||
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=last_logprobs,
|
||||
normalized_prompt_logprobs=None,
|
||||
prefill_token_logprobs=None,
|
||||
prefill_top_logprobs=None,
|
||||
decode_top_logprobs=decode_top_logprobs,
|
||||
input_token_logprobs=None,
|
||||
input_top_logprobs=None,
|
||||
output_top_logprobs=output_top_logprobs,
|
||||
)
|
||||
else:
|
||||
all_logits = torch.matmul(hidden_states, weight.T)
|
||||
@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
|
||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||
)
|
||||
if return_top_logprob:
|
||||
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
|
||||
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
|
||||
all_logprobs, logits_metadata
|
||||
)
|
||||
else:
|
||||
prefill_top_logprobs = decode_top_logprobs = None
|
||||
input_top_logprobs = output_top_logprobs = None
|
||||
|
||||
last_logprobs = all_logprobs[last_index]
|
||||
|
||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||
# Note that we pad a zero at the end of each sequence for easy computation.
|
||||
prefill_token_logprobs = all_logprobs[
|
||||
input_token_logprobs = all_logprobs[
|
||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||
]
|
||||
|
||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||
prefill_token_logprobs, logits_metadata
|
||||
input_token_logprobs, logits_metadata
|
||||
)
|
||||
|
||||
return LogitProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
next_token_logprobs=last_logprobs,
|
||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||
prefill_token_logprobs=prefill_token_logprobs,
|
||||
prefill_top_logprobs=prefill_top_logprobs,
|
||||
decode_top_logprobs=decode_top_logprobs,
|
||||
input_token_logprobs=input_token_logprobs,
|
||||
input_top_logprobs=input_top_logprobs,
|
||||
output_top_logprobs=output_top_logprobs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user