Logprobs Refactor (#331)
@@ -13,76 +13,127 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
 
-    def forward(self, input_ids, hidden_states, weight, input_metadata):
-        last_index = None
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, input_metadata: InputMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )
+
+        # Compute the last index (the first decode token) of each request
+        # if we are in prefill or extend mode.
+        start = input_metadata.extend_start_loc.clone()
+        end = start + input_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
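
The helper folds per-request prompt-logprob summation into two lookups on one cumulative sum: the inclusive segment sum over [start, end] equals cumsum[end] - cumsum[start] + x[start]. end sits at extend_seq_lens - 2 because the token logprobs are shifted by one position, so the slot after that already belongs to the first decode token. A standalone sketch with toy values (CPU tensors, names invented for illustration; not part of the commit):

import torch

token_logprobs = torch.tensor([-1.0, -2.0, -3.0, -0.5, -0.5, -1.0, -2.0])
extend_seq_lens = torch.tensor([3, 4])
extend_start_loc = torch.tensor([0, 3])  # flat offset of each request

cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)

# cumsum[end] - cumsum[start] + x[start] == sum of x over [start, end]
sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(sum_logp)    # tensor([-3., -2.])  = sums over slots [0, 1] and [3, 5]
print(normalized)  # tensor([-1.5000, -0.6667])
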
+
+    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.cpu().tolist()
+                p_cpu = t.indices.cpu().tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            # NOTE: the GPU-CPU overhead can be reduced
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens
+            for i in range(len(input_metadata.extend_seq_lens)):
+                if extend_seq_lens_cpu[i] == 0:
+                    continue
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                vs_cpu = t.values.cpu().tolist()
+                ps_cpu = t.indices.cpu().tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
+            return prefill_top_logprobs, decode_top_logprobs
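
Each row of the [num_tokens, vocab] logprob matrix becomes a list of (logprob, token_id) pairs; in the extend branch the last row of each request's slice (vs_cpu[-1]) is routed to decode_top_logprobs and the earlier rows to prefill_top_logprobs. The topk pattern in isolation (toy shapes, illustrative only):

import torch

all_logprobs = torch.log_softmax(torch.randn(5, 32), dim=-1)  # 5 positions, vocab 32
t = all_logprobs.topk(3)  # top-3 along the vocab dim
vs, ps = t.values.tolist(), t.indices.tolist()
top = [list(zip(vs[j], ps[j])) for j in range(len(vs))]
print(top[0])  # e.g. [(-1.93, 17), (-2.24, 4), (-2.81, 25)]
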
+
+    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+        # Get last index for next token prediction, except for DECODE mode.
+        last_index = None
         if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
+                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )
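
Requests are packed back-to-back along the token axis, so the flat index of each request's last token is just the running total of extend lengths minus one. A quick check (illustration, not from the commit):

import torch

extend_seq_lens = torch.tensor([3, 4, 2])
last_index = torch.cumsum(extend_seq_lens, dim=0, dtype=torch.long) - 1
print(last_index)  # tensor([2, 6, 8]): requests occupy slots 0-2, 3-6, 7-8
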
 
-        if not input_metadata.return_logprob:
-            # When logprob is not requested, only compute the last logits.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_hidden = hidden_states[last_index]
-                hidden_states = None
+        # Get the last hidden states and last logits
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_hidden = hidden_states
+        else:
+            last_hidden = hidden_states[last_index]
 
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None, None)
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        # Return only last_logits if logprob is not requested
+        if not input_metadata.return_logprob:
+            hidden_states = None
+            return last_logits, (None, None, None, None, None)
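
The logprob part of the return value grows from a 3-tuple to a fixed 5-slot tuple, so the fast path and the full path hand back the same shape. A sketch of the new contract from the caller's side (values faked for illustration; the real tuple comes out of forward, and any slot may be None):

logprob_ret = (None, None, [[(-1.2, 17), (-2.0, 4)]], None, None)
(
    prefill_token_logprobs,      # per-token logprobs of the prompt
    prefill_top_logprobs,        # top-k pairs per prompt position
    decode_top_logprobs,         # top-k pairs for the next token, per request
    normalized_prompt_logprobs,  # length-normalized prompt logprob
    last_logprobs,               # log-distribution at the last position
) = logprob_ret
print(decode_top_logprobs)
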
         else:
             # When logprob is requested, compute the logits for all tokens.
-            logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
-                logits = tensor_model_parallel_all_gather(logits)
-            logits = logits[:, : self.config.vocab_size]
-            all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
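
all_logprobs is log(softmax(x) + 1e-6) rather than a plain log-softmax: the epsilon keeps the result finite when a token's probability underflows to zero in float32, at the cost of flooring every logprob at about log(1e-6) ≈ -13.8. A small check (illustrative):

import torch

logits = torch.tensor([[0.0, 200.0]])
probs = torch.softmax(logits.float(), dim=-1)  # underflows to [0., 1.]
print(torch.log(probs))                        # tensor([[-inf, 0.]])
print(torch.log(probs + 1e-6))                 # ≈ tensor([[-13.8155, 0.0]])
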
 
+            prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                all_logprobs, input_metadata
+            )
+
             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logits = logits
                 last_logprobs = all_logprobs
-                prefill_logprobs = normalized_logprobs = None
+                return last_logits, (
+                    None,
+                    None,
+                    decode_top_logprobs,
+                    None,
+                    last_logprobs,
+                )
             else:
                 # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]
 
                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_logprobs = all_logprobs[
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
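
The gather pairs position i with input token i + 1, since the distribution at position i predicts the next token; a 0 is padded so the final row (whose target is the first decode token, scored elsewhere) still indexes something valid. A toy version (illustrative):

import torch

n, vocab = 3, 4
all_logprobs = torch.log_softmax(torch.randn(n, vocab), dim=-1)
input_ids = torch.tensor([2, 1, 3])

next_ids = torch.cat([input_ids[1:], torch.tensor([0])])  # [1, 3, 0]
token_logprobs = all_logprobs[torch.arange(n), next_ids]
# token_logprobs[0] scores input_ids[1], token_logprobs[1] scores
# input_ids[2], and the last entry is a dummy for the padded 0.
print(token_logprobs)
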
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )
 
-                start = input_metadata.extend_start_loc.clone()
-                end = start + input_metadata.extend_seq_lens - 2
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
-                )
-                normalized_logprobs = sum_logp / (
-                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
-                )
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, input_metadata
+                )
 
-            return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
+                return last_logits, (
+                    prefill_token_logprobs,
+                    prefill_top_logprobs,
+                    decode_top_logprobs,
+                    normalized_prompt_logprobs,
+                    last_logprobs,
+                )
 
 
 if __name__ == "__main__":
     all_logprobs = torch.tensor(
@@ -93,23 +144,22 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
 
-    logprobs = all_logprobs[
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
 
     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
 
     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs", logprobs)
+    print("token logprobs", token_logprobs)
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)