Logprobs Refractor (#331)

This commit is contained in:
Liangsheng Yin
2024-03-28 14:34:49 +08:00
committed by GitHub
parent 24e59f5350
commit 3842eba5fa
14 changed files with 385 additions and 152 deletions

View File

@@ -13,76 +13,127 @@ class LogitsProcessor(nn.Module):
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
def forward(self, input_ids, hidden_states, weight, input_metadata):
last_index = None
def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, input_metadata: InputMetadata
):
logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32
)
# Compute the last index (the first decode token) of each requeast
# if we are in prefill or extend mode.
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
return normalized_prompt_logprobs
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
if input_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.cpu().tolist()
p_cpu = t.indices.cpu().tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens
for i in range(len(input_metadata.extend_seq_lens)):
if extend_seq_lens_cpu[i] == 0:
continue
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
vs_cpu = t.values.cpu().tolist()
ps_cpu = t.indices.cpu().tolist()
prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
return prefill_top_logprobs, decode_top_logprobs
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
# Get last index for next token prediction, except for DECODE mode.
last_index = None
if input_metadata.forward_mode != ForwardMode.DECODE:
last_index = (
torch.cumsum(
input_metadata.seq_lens - input_metadata.prefix_lens,
dim=0,
dtype=torch.long,
)
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
- 1
)
if not input_metadata.return_logprob:
# When logprob is not requested, only compute the last logits.
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_hidden = hidden_states[last_index]
hidden_states = None
# Get the last hidden states and last logits
if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states
else:
last_hidden = hidden_states[last_index]
last_logits = torch.matmul(last_hidden, weight.T)
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, (None, None, None)
last_logits = torch.matmul(last_hidden, weight.T)
if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size]
# Return only last_logits if logprob is not requested
if not input_metadata.return_logprob:
hidden_states = None
return last_logits, (None, None, None, None, None)
else:
# When logprob is requested, compute the logits for all tokens.
logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
logits = tensor_model_parallel_all_gather(logits)
logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
if input_metadata.forward_mode == ForwardMode.DECODE:
all_logits = last_logits
else:
all_logits = torch.matmul(hidden_states, weight.T)
if self.tp_size > 1:
all_logits = tensor_model_parallel_all_gather(all_logits)
all_logits = all_logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
if input_metadata.forward_mode == ForwardMode.DECODE:
last_logits = logits
last_logprobs = all_logprobs
prefill_logprobs = normalized_logprobs = None
return last_logits, (
None,
None,
decode_top_logprobs,
None,
last_logprobs,
)
else:
# Compute the logprobs for the last token of each request.
last_logits = logits[last_index]
last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_logprobs = all_logprobs[
prefill_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(
prefill_logprobs, dim=0, dtype=torch.float32
)
start = input_metadata.extend_start_loc.clone()
end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_logprobs[start]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, input_metadata
)
normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
return last_logits, (
prefill_token_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs,
)
return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
if __name__ == "__main__":
all_logprobs = torch.tensor(
@@ -93,23 +144,22 @@ if __name__ == "__main__":
)
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
logprobs = all_logprobs[
token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
len_cumsum = torch.cumsum(seq_lens, dim=0)
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
end = start + seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
# assert logprobs == [2, _, 2, 4, _]
print("logprobs", logprobs)
print("token logprobs", token_logprobs)
print("start", start)
print("end", end)
print("sum_logp", sum_logp)