Optimize retract (#440)

This commit is contained in:
Liangsheng Yin
2024-05-26 00:07:26 +08:00
committed by GitHub
parent 2cea6146d8
commit f06e90c2cf
7 changed files with 298 additions and 113 deletions

View File

@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
prefill_top_logprobs, decode_top_logprobs = [], []
pt = 0
# NOTE: the GPU-CPU overhead can be reduced
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
for i in range(len(extend_seq_lens_cpu)):
if extend_seq_lens_cpu[i] == 0:
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
continue
k = input_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_lens_cpu[i]
pt += extend_seq_len
return prefill_top_logprobs, decode_top_logprobs
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
)
if __name__ == "__main__":
def test():
all_logprobs = torch.tensor(
# s s s
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
@@ -173,3 +174,7 @@ if __name__ == "__main__":
print("start", start)
print("end", end)
print("sum_logp", sum_logp)
if __name__ == "__main__":
test()