Fix logit processor bugs (#427)
This commit is contained in:
@@ -42,26 +42,29 @@ class LogitsProcessor(nn.Module):
|
||||
for i in range(all_logprobs.shape[0]):
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[i].topk(k)
|
||||
v_cpu = t.values.cpu().tolist()
|
||||
p_cpu = t.indices.cpu().tolist()
|
||||
v_cpu = t.values.tolist()
|
||||
p_cpu = t.indices.tolist()
|
||||
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||
return None, decode_top_logprobs
|
||||
else:
|
||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||
pt = 0
|
||||
# NOTE: the GPU-CPU overhead can be reduced
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens
|
||||
for i in range(len(input_metadata.extend_seq_lens)):
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
|
||||
for i in range(len(extend_seq_lens_cpu)):
|
||||
if extend_seq_lens_cpu[i] == 0:
|
||||
prefill_top_logprobs.append([])
|
||||
decode_top_logprobs.append([])
|
||||
continue
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
||||
vs_cpu = t.values.cpu().tolist()
|
||||
ps_cpu = t.indices.cpu().tolist()
|
||||
vs_cpu = t.values.tolist()
|
||||
ps_cpu = t.indices.tolist()
|
||||
prefill_top_logprobs.append(
|
||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||
)
|
||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||
pt += extend_seq_lens_cpu[i]
|
||||
return prefill_top_logprobs, decode_top_logprobs
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||
@@ -99,20 +102,24 @@ class LogitsProcessor(nn.Module):
|
||||
all_logits = all_logits[:, : self.config.vocab_size]
|
||||
|
||||
all_logprobs = all_logits.float()
|
||||
all_logits = None
|
||||
del all_logits
|
||||
all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
|
||||
|
||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||
all_logprobs, input_metadata
|
||||
)
|
||||
return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
|
||||
if return_top_logprob:
|
||||
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||
all_logprobs, input_metadata
|
||||
)
|
||||
else:
|
||||
prefill_top_logprobs = decode_top_logprobs = None
|
||||
|
||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||
last_logprobs = all_logprobs
|
||||
return last_logits, (
|
||||
None,
|
||||
None,
|
||||
decode_top_logprobs,
|
||||
None,
|
||||
decode_top_logprobs,
|
||||
last_logprobs,
|
||||
)
|
||||
else:
|
||||
@@ -131,9 +138,9 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
return last_logits, (
|
||||
prefill_token_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
prefill_top_logprobs,
|
||||
decode_top_logprobs,
|
||||
normalized_prompt_logprobs,
|
||||
last_logprobs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user