Vectorize logprobs computation (#787)

This commit is contained in:
Ying Sheng
2024-07-28 05:22:14 -07:00
committed by GitHub
parent bcb6611a46
commit c71880f896
2 changed files with 36 additions and 17 deletions

View File

@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module):
@staticmethod
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE:
output_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs
else:
# TODO: vectorize the code below
input_top_logprobs, output_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
input_top_logprobs.append([])
output_top_logprobs.append([])
continue
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
[
list(zip(values[pt + j][:k], indices[pt + j][:k]))
for j in range(extend_seq_len - 1)
]
)
output_top_logprobs.append(
list(
zip(
values[pt + extend_seq_len - 1][:k],
indices[pt + extend_seq_len - 1][:k],
)
)
)
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len
return input_top_logprobs, output_top_logprobs