diff --git a/python/sglang/srt/layers/get_selected_logprob.py b/python/sglang/srt/layers/get_selected_logprob.py
deleted file mode 100644
index 60e5b3ba2..000000000
--- a/python/sglang/srt/layers/get_selected_logprob.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import torch
-import triton
-import triton.language as tl
-from sglang.srt.utils import wrap_kernel_launcher
-
-
-@triton.jit
-def _fwd_segmented_gather(
-    all_logits,
-    len_add_1,
-    cum_len,
-    input_ids,
-    logprobs,
-    max_seq_len,
-    voc_size: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-):
-    cur_req = tl.program_id(0)
-    cur_l = tl.load(len_add_1 + cur_req)
-    cum_l = tl.load(cum_len + cur_req)
-
-    for i in range(0, (max_seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE):
-        off = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-        mask = off < cur_l - 1
-
-        idx = tl.load(input_ids + cum_l - cur_l + off + 1, mask=mask)
-        data = tl.load(all_logits + (cum_l - cur_l + off) * voc_size + idx, mask=mask)
-        tl.store(logprobs + cum_l - cur_l - cur_req + off, data, mask=mask)
-
-
-cached_kernel = None
-
-
-def get_selected_logprob(all_logits, len_add_1, input_ids, logprobs):
-    cum_len = torch.cumsum(len_add_1, dtype=torch.int32, dim=0)
-    voc_size = all_logits.shape[1]
-    grid = (len_add_1.shape[0], 1, 1)
-    max_seq_len = len_add_1.max().item()
-
-    global cached_kernel
-    if cached_kernel:
-        cached_kernel(
-            grid,
-            4,
-            all_logits,
-            len_add_1,
-            cum_len,
-            input_ids,
-            logprobs,
-            max_seq_len,
-        )
-        return
-
-    _fwd_segmented_gather[grid](
-        all_logits,
-        len_add_1,
-        cum_len,
-        input_ids,
-        logprobs,
-        max_seq_len,
-        voc_size,
-        BLOCK_SIZE=128,
-    )
-    cached_kernel = wrap_kernel_launcher(_fwd_segmented_gather)
-
-
-if __name__ == "__main__":
-    all_logits = torch.tensor(
-        #  s  s        s
-        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
-        dtype=torch.float32,
-        device="cuda",
-    )
-    len_add_1 = torch.tensor([2, 3], dtype=torch.int32, device="cuda")
-    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.empty((3), dtype=torch.float32, device="cuda")
-    get_selected_logprob(all_logits, len_add_1, input_ids, logprobs)
-    print(logprobs)
-    # assert logprobs == [2, 2, 4]
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index 1442b6db7..7c819c34c 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -1,5 +1,4 @@
 import torch
-from sglang.srt.layers.get_selected_logprob import get_selected_logprob
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
 from vllm.model_executor.parallel_utils.communication_op import (
@@ -54,25 +53,56 @@ class LogitsProcessor(nn.Module):
             normalized_logprobs = compute_normalized_logprobs(
                 all_logprobs,
-                input_metadata.seq_lens - input_metadata.prefix_lens,
                 input_ids,
+                input_metadata.extend_seq_lens,
+                input_metadata.extend_start_loc,
             )
 
         last_logits = logits[last_index]
         return last_logits, normalized_logprobs
 
 
-def compute_normalized_logprobs(all_logprobs, len_add_1, input_ids):
-    # assert all_logprobs.shape[0] == input_ids.shape[0] == torch.sum(len_add_1)
-    logprobs = torch.zeros(
-        (all_logprobs.shape[0] - len_add_1.shape[0]), dtype=torch.float32, device="cuda"
-    )
-    get_selected_logprob(all_logprobs, len_add_1, input_ids, logprobs)
-    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
-    end = torch.cumsum(len_add_1.sub_(1), dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), end[:-1]), 0)
-    end.sub_(1)
-    torch.cuda.synchronize()
-    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]
-    res = sum_logp / len_add_1
-    return res
+def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    start = start_loc.clone()
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    return sum_logp / ((seq_lens - 1).clamp(min=1))
+
+
+if __name__ == "__main__":
+    all_logprobs = torch.tensor(
+        #  s  s        s
+        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
+        dtype=torch.float32,
+        device="cuda",
+    )
+    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
+    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
+    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
+
+    logprobs = all_logprobs[
+        torch.arange(all_logprobs.shape[0], device="cuda"),
+        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
+    ]
+    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+
+    len_cumsum = torch.cumsum(seq_lens, dim=0)
+    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
+    end = start + seq_lens - 2
+    start.clamp_(min=0, max=logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+
+    # assert logprobs == [2, _, 2, 4, _]
+    print("logprobs", logprobs)
+    print("start", start)
+    print("end", end)
+    print("sum_logp", sum_logp)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index b200a7295..071ec4efe 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -1,7 +1,7 @@
+import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from typing import List
-import logging
 
 import numpy as np
 import torch
@@ -13,7 +13,6 @@
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.model_loader import _set_default_torch_dtype
 from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
 
-
 logger = logging.getLogger("model_runner")
 
@@ -112,7 +111,7 @@ class InputMetadata:
     def init_extend_args(self):
         self.extend_seq_lens = self.seq_lens - self.prefix_lens
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], 0)
+        self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.max_extend_len = int(torch.max(self.extend_seq_lens))
 
     @classmethod
@@ -262,7 +261,7 @@ class ModelRunner:
         if model_class is None:
             raise ValueError(f"Unsupported architectures: {architectures}")
 
-        logger.info("load weight begin.")
+        logger.info(f"Rank {self.tp_rank}: load weight begin.")
 
         # Load weights
         linear_method = None
@@ -287,7 +286,7 @@ class ModelRunner:
         )
         self.model = model.eval()
 
-        logger.info("load weight end.")
+        logger.info(f"Rank {self.tp_rank}: load weight end.")
 
     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
@@ -308,8 +307,9 @@ class ModelRunner:
         self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
         if self.max_total_num_token <= 0:
-            raise RuntimeError("Not enought memory. "
-                               "Please try to increase --mem-fraction-static.")
+            raise RuntimeError(
+                "Not enough memory. " "Please try to increase --mem-fraction-static."
+            )
         self.req_to_token_pool = ReqToTokenPool(
             int(self.max_total_num_token / self.model_config.context_len * 256),
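
Note on the logits_processor.py change: the deleted Triton kernel and the new
compute_normalized_logprobs both compute, per packed request, the average logprob the model
assigned to that request's own tokens; the new version replaces the custom kernel with a
shifted gather plus a cumulative-sum segment trick. A minimal CPU sketch of that trick
follows (the function name and test tensors are illustrative, not part of the diff):

import torch

def normalized_logprobs_sketch(all_logprobs, input_ids, seq_lens, start_loc):
    # Position i predicts token input_ids[i + 1], so gather the logprob of the
    # *next* token at each position; the final position gets a dummy index 0
    # and is excluded from every segment sum below.
    n = all_logprobs.shape[0]
    shifted_ids = torch.cat([input_ids[1:], torch.tensor([0])])
    logprobs = all_logprobs[torch.arange(n), shifted_ids]

    # Inclusive segment sum via cumsum:
    # sum(logprobs[start : end + 1]) == cumsum[end] - cumsum[start] + logprobs[start]
    cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
    start = start_loc.clamp(min=0, max=n - 1)
    end = (start_loc + seq_lens - 2).clamp(min=0, max=n - 1)
    sum_logp = cumsum[end] - cumsum[start] + logprobs[start]

    # Each request contributes seq_len - 1 predicted tokens.
    return sum_logp / (seq_lens - 1).clamp(min=1)

# Two requests packed back to back: lengths 2 and 3, start offsets 0 and 2.
all_logprobs = torch.tensor(
    [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
    dtype=torch.float32,
)
input_ids = torch.tensor([1, 2, 3, 0, 1])
seq_lens = torch.tensor([2, 3])
start_loc = torch.tensor([0, 2])
print(normalized_logprobs_sketch(all_logprobs, input_ids, seq_lens, start_loc))
# tensor([2., 3.]): request 0 averages {2.0}; request 1 averages {2.0, 4.0}.

The cumsum formulation turns one kernel launch per batch (plus a host-side sync in the old
code) into a handful of vectorized torch ops, which is why the Triton file could be deleted.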