Optimize retract (#440)
This commit is contained in:
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
|
||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||
pt = 0
|
||||
# NOTE: the GPU-CPU overhead can be reduced
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
|
||||
for i in range(len(extend_seq_lens_cpu)):
|
||||
if extend_seq_lens_cpu[i] == 0:
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||
if extend_seq_len == 0:
|
||||
prefill_top_logprobs.append([])
|
||||
decode_top_logprobs.append([])
|
||||
continue
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||
vs_cpu = t.values.tolist()
|
||||
ps_cpu = t.indices.tolist()
|
||||
prefill_top_logprobs.append(
|
||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||
)
|
||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||
pt += extend_seq_lens_cpu[i]
|
||||
pt += extend_seq_len
|
||||
|
||||
return prefill_top_logprobs, decode_top_logprobs
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test():
|
||||
all_logprobs = torch.tensor(
|
||||
# s s s
|
||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||
@@ -173,3 +174,7 @@ if __name__ == "__main__":
|
||||
print("start", start)
|
||||
print("end", end)
|
||||
print("sum_logp", sum_logp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
Reference in New Issue
Block a user