Fix logit processor bugs (#427)

This commit is contained in:
Lianmin Zheng
2024-05-12 04:54:07 -07:00
committed by GitHub
parent 7023f413c6
commit aee4f523cf
26 changed files with 166 additions and 257 deletions

View File

@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
for i in range(all_logprobs.shape[0]):
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.cpu().tolist()
p_cpu = t.indices.cpu().tolist()
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens
for i in range(len(input_metadata.extend_seq_lens)):
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
for i in range(len(extend_seq_lens_cpu)):
if extend_seq_lens_cpu[i] == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
vs_cpu = t.values.cpu().tolist()
ps_cpu = t.indices.cpu().tolist()
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_lens_cpu[i]
return prefill_top_logprobs, decode_top_logprobs
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
all_logits = all_logits[:, : self.config.vocab_size]
all_logprobs = all_logits.float()
all_logits = None
del all_logits
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
all_logprobs, input_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None
if input_metadata.forward_mode == ForwardMode.DECODE:
last_logprobs = all_logprobs
return last_logits, (
None,
None,
decode_top_logprobs,
None,
decode_top_logprobs,
last_logprobs,
)
else:
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
)
return last_logits, (
prefill_token_logprobs,
normalized_prompt_logprobs,
prefill_top_logprobs,
decode_top_logprobs,
normalized_prompt_logprobs,
last_logprobs,
)