Vectorize logprobs computation (#787)

This commit is contained in:
Ying Sheng
2024-07-28 05:22:14 -07:00
committed by GitHub
parent bcb6611a46
commit c71880f896
2 changed files with 36 additions and 17 deletions

View File

@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module):
@staticmethod @staticmethod
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE: if logits_metadata.forward_mode == ForwardMode.DECODE:
output_top_logprobs = [] output_top_logprobs = []
for i in range(all_logprobs.shape[0]): max_k = max(logits_metadata.top_logprobs_nums)
k = logits_metadata.top_logprobs_nums[i] ret = all_logprobs.topk(max_k, dim=1)
t = all_logprobs[i].topk(k) values = ret.values.tolist()
v_cpu = t.values.tolist() indices = ret.indices.tolist()
p_cpu = t.indices.tolist() for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(v_cpu, p_cpu))) output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs return None, output_top_logprobs
else: else:
# TODO: vectorize the code below
input_top_logprobs, output_top_logprobs = [], [] input_top_logprobs, output_top_logprobs = [], []
pt = 0 pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
max_k = max(logits_metadata.top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu): for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0: if extend_seq_len == 0:
input_top_logprobs.append([]) input_top_logprobs.append([])
output_top_logprobs.append([]) output_top_logprobs.append([])
continue continue
k = logits_metadata.top_logprobs_nums[i] k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
input_top_logprobs.append( input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] [
list(zip(values[pt + j][:k], indices[pt + j][:k]))
for j in range(extend_seq_len - 1)
]
)
output_top_logprobs.append(
list(
zip(
values[pt + extend_seq_len - 1][:k],
indices[pt + extend_seq_len - 1][:k],
)
)
) )
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len pt += extend_seq_len
return input_top_logprobs, output_top_logprobs return input_top_logprobs, output_top_logprobs

View File

@@ -6,7 +6,7 @@ import dataclasses
import logging import logging
import multiprocessing as mp import multiprocessing as mp
import os import os
from typing import Dict, List from typing import Dict, List, Tuple
import numpy as np import numpy as np
import transformers import transformers
@@ -469,7 +469,9 @@ class TokenizerManager:
) )
return ret return ret
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool): def detokenize_logprob_tokens(
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
):
if not decode_to_text: if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs] return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
@@ -481,9 +483,13 @@ class TokenizerManager:
] ]
def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool): def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
for i, t in enumerate(top_logprobs): # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
if t: # We should batch all top-k tokens in all positions.
top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text) for i, token_top_logprobs in enumerate(top_logprobs):
if token_top_logprobs:
top_logprobs[i] = self.detokenize_logprob_tokens(
token_top_logprobs, decode_to_text
)
return top_logprobs return top_logprobs