diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index ec63b4a14..aad52403a 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -77,33 +77,46 @@ class LogitsProcessor(nn.Module):
 
     @staticmethod
     def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
-        # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
             output_top_logprobs = []
-            for i in range(all_logprobs.shape[0]):
-                k = logits_metadata.top_logprobs_nums[i]
-                t = all_logprobs[i].topk(k)
-                v_cpu = t.values.tolist()
-                p_cpu = t.indices.tolist()
-                output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            max_k = max(logits_metadata.top_logprobs_nums)
+            ret = all_logprobs.topk(max_k, dim=1)
+            values = ret.values.tolist()
+            indices = ret.indices.tolist()
+            for i, k in enumerate(logits_metadata.top_logprobs_nums):
+                output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
             return None, output_top_logprobs
         else:
+            # TODO: vectorize the code below
             input_top_logprobs, output_top_logprobs = [], []
             pt = 0
             extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
+
+            max_k = max(logits_metadata.top_logprobs_nums)
+            ret = all_logprobs.topk(max_k, dim=1)
+            values = ret.values.tolist()
+            indices = ret.indices.tolist()
+
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
                     input_top_logprobs.append([])
                     output_top_logprobs.append([])
                     continue
                 k = logits_metadata.top_logprobs_nums[i]
-                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
-                vs_cpu = t.values.tolist()
-                ps_cpu = t.indices.tolist()
                 input_top_logprobs.append(
-                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                    [
+                        list(zip(values[pt + j][:k], indices[pt + j][:k]))
+                        for j in range(extend_seq_len - 1)
+                    ]
+                )
+                output_top_logprobs.append(
+                    list(
+                        zip(
+                            values[pt + extend_seq_len - 1][:k],
+                            indices[pt + extend_seq_len - 1][:k],
+                        )
+                    )
                 )
-                output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                 pt += extend_seq_len
 
             return input_top_logprobs, output_top_logprobs
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
index 953520986..20853a645 100644
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -6,7 +6,7 @@ import dataclasses
 import logging
 import multiprocessing as mp
 import os
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import numpy as np
 import transformers
@@ -469,7 +469,9 @@ class TokenizerManager:
         )
         return ret
 
-    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
+    def detokenize_logprob_tokens(
+        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+    ):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
 
@@ -481,9 +483,13 @@
         ]
 
     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
-        for i, t in enumerate(top_logprobs):
-            if t:
-                top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
+        # We should batch all top-k tokens in all positions.
+        for i, token_top_logprobs in enumerate(top_logprobs):
+            if token_top_logprobs:
+                top_logprobs[i] = self.detokenize_logprob_tokens(
+                    token_top_logprobs, decode_to_text
+                )
         return top_logprobs
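Note (not part of the patch): a minimal standalone sketch of the batched-topk idea behind the decode-path change above, for readers skimming the diff. The tensor shape, vocab size, and per-request k values are made-up examples; only the topk/slicing pattern mirrors the patch.

# Sketch: replace one topk() call and one device-to-host copy per row with a
# single batched topk over the whole [num_seqs, vocab] logprob matrix, then
# slice each row down to its requested k after the copy.
import torch

all_logprobs = torch.log_softmax(torch.randn(4, 32000), dim=-1)
top_logprobs_nums = [5, 1, 3, 2]  # per-request k, as in logits_metadata

max_k = max(top_logprobs_nums)
ret = all_logprobs.topk(max_k, dim=1)  # one kernel launch for all rows
values = ret.values.tolist()  # one synchronizing copy instead of num_seqs copies
indices = ret.indices.tolist()

output_top_logprobs = [
    list(zip(values[i][:k], indices[i][:k]))
    for i, k in enumerate(top_logprobs_nums)
]
print(output_top_logprobs[0])  # five (logprob, token_id) pairs for request 0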