Vectorize logprobs computation (#787)
This commit is contained in:
@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
||||||
# TODO: vectorize the code below
|
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
output_top_logprobs = []
|
output_top_logprobs = []
|
||||||
for i in range(all_logprobs.shape[0]):
|
max_k = max(logits_metadata.top_logprobs_nums)
|
||||||
k = logits_metadata.top_logprobs_nums[i]
|
ret = all_logprobs.topk(max_k, dim=1)
|
||||||
t = all_logprobs[i].topk(k)
|
values = ret.values.tolist()
|
||||||
v_cpu = t.values.tolist()
|
indices = ret.indices.tolist()
|
||||||
p_cpu = t.indices.tolist()
|
for i, k in enumerate(logits_metadata.top_logprobs_nums):
|
||||||
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
|
||||||
return None, output_top_logprobs
|
return None, output_top_logprobs
|
||||||
else:
|
else:
|
||||||
|
# TODO: vectorize the code below
|
||||||
input_top_logprobs, output_top_logprobs = [], []
|
input_top_logprobs, output_top_logprobs = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
||||||
|
|
||||||
|
max_k = max(logits_metadata.top_logprobs_nums)
|
||||||
|
ret = all_logprobs.topk(max_k, dim=1)
|
||||||
|
values = ret.values.tolist()
|
||||||
|
indices = ret.indices.tolist()
|
||||||
|
|
||||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||||
if extend_seq_len == 0:
|
if extend_seq_len == 0:
|
||||||
input_top_logprobs.append([])
|
input_top_logprobs.append([])
|
||||||
output_top_logprobs.append([])
|
output_top_logprobs.append([])
|
||||||
continue
|
continue
|
||||||
k = logits_metadata.top_logprobs_nums[i]
|
k = logits_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
|
||||||
vs_cpu = t.values.tolist()
|
|
||||||
ps_cpu = t.indices.tolist()
|
|
||||||
input_top_logprobs.append(
|
input_top_logprobs.append(
|
||||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
[
|
||||||
|
list(zip(values[pt + j][:k], indices[pt + j][:k]))
|
||||||
|
for j in range(extend_seq_len - 1)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
output_top_logprobs.append(
|
||||||
|
list(
|
||||||
|
zip(
|
||||||
|
values[pt + extend_seq_len - 1][:k],
|
||||||
|
indices[pt + extend_seq_len - 1][:k],
|
||||||
|
)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
|
||||||
pt += extend_seq_len
|
pt += extend_seq_len
|
||||||
|
|
||||||
return input_top_logprobs, output_top_logprobs
|
return input_top_logprobs, output_top_logprobs
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import dataclasses
|
|||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import os
|
import os
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import transformers
|
import transformers
|
||||||
@@ -469,7 +469,9 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
|
def detokenize_logprob_tokens(
|
||||||
|
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
|
||||||
|
):
|
||||||
if not decode_to_text:
|
if not decode_to_text:
|
||||||
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
|
||||||
|
|
||||||
@@ -481,9 +483,13 @@ class TokenizerManager:
|
|||||||
]
|
]
|
||||||
|
|
||||||
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
|
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
|
||||||
for i, t in enumerate(top_logprobs):
|
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
|
||||||
if t:
|
# We should batch all top-k tokens in all positions.
|
||||||
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
|
for i, token_top_logprobs in enumerate(top_logprobs):
|
||||||
|
if token_top_logprobs:
|
||||||
|
top_logprobs[i] = self.detokenize_logprob_tokens(
|
||||||
|
token_top_logprobs, decode_to_text
|
||||||
|
)
|
||||||
return top_logprobs
|
return top_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user