# Adapt from https://github.com/vllm-project/vllm/blob/main/vllm/v1/worker/gpu/sample/logprob.py.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.

import torch
from vllm.triton_utils import tl, triton


@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    # Computes log-softmax values for a selected set of token ids, one
    # program instance per logits row (the row is chosen by program_id(0)).
    #
    # Layout expected by the pointer arithmetic below:
    #   * logits row `r` starts at `logits_ptr + r * logits_stride` and its
    #     elements are read at consecutive offsets 0..vocab_size-1.
    #   * `topk_ids_ptr` and `output_ptr` are addressed as densely packed
    #     rows of `topk` elements (`r * topk + k`).
    #
    # The computation is the usual numerically stable three-step form:
    #   1. row max, 2. log-sum-exp of (logits - max), 3. gather and subtract.
    req_idx = tl.program_id(0)
    row_ptr = logits_ptr + req_idx * logits_stride

    # Pass 1: running maximum over the row, BLOCK_SIZE elements at a time.
    # Out-of-range lanes are loaded as -inf so they never win the max.
    max_val = float("-inf")
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf"))
        max_val = tl.max(tl.maximum(logits, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    # Pass 2: accumulate sum(exp(logits - max_val)) over the row.
    se = 0.0
    for i in range(0, vocab_size, BLOCK_SIZE):
        block = i + tl.arange(0, BLOCK_SIZE)
        logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
        logits = logits.to(tl.float32)
        # NOTE(wangx700): tl.where does not support int64 so we cast it to float32.
        block = block.to(tl.float32)
        e = tl.exp(logits - max_val)
        # Out-of-range lanes were loaded as 0.0, so their exp() is not zero;
        # explicitly zero them before they reach the sum.
        e = tl.where(block < vocab_size, e, 0.0)
        se += tl.sum(e)
    lse = tl.log(se)

    # Gather the requested token ids for this row. PADDED_TOPK is the
    # compile-time width of the lane vector; only the first `topk` lanes are
    # loaded and stored.
    k_offset = tl.arange(0, PADDED_TOPK)
    k_mask = k_offset < topk
    topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0)

    # log_softmax(x)[id] = x[id] - max - log(sum(exp(x - max)))
    logits = tl.load(row_ptr + topk_ids, mask=k_mask)
    logits = logits.to(tl.float32)
    o = logits - max_val - lse
    tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask)


def compute_token_logprobs(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Return the log-softmax of ``logits`` evaluated at the given token ids.

    For every row ``i`` and column ``j`` the result holds
    ``log_softmax(logits[i])[token_ids[i, j]]``, computed in float32 by
    ``_topk_log_softmax_kernel`` with one kernel program per row.

    Args:
        logits: 2-D tensor of shape ``(batch_size, vocab_size)``.
        token_ids: 2-D integer tensor whose second dimension is the number of
            log-probabilities requested per row. It is indexed by the same
            row index as ``logits``, so its first dimension is expected to be
            ``batch_size``. Ids are not range-checked here; callers are
            expected to pass values in ``[0, vocab_size)``.

    Returns:
        A float32 tensor of shape ``(batch_size, token_ids.shape[1])``
        allocated via ``logits.new_empty``.
    """
    batch_size, vocab_size = logits.shape

    # The kernel reads each logits row at consecutive element offsets and is
    # only told the row stride, so a non-unit inner stride would make it read
    # the wrong elements. Copy only in that (uncommon) case.
    if logits.stride(1) != 1:
        logits = logits.contiguous()

    # The kernel addresses token ids as densely packed rows
    # (`row * num_logprobs + k`). `.to(torch.int64)` returns the input as-is
    # when the dtype already matches, so a strided view would otherwise be
    # passed through unchanged; `.contiguous()` is a no-op when the tensor is
    # already packed.
    token_ids = token_ids.to(torch.int64).contiguous()
    num_logprobs = token_ids.shape[1]

    logprobs = logits.new_empty((batch_size, num_logprobs), dtype=torch.float32)
    _topk_log_softmax_kernel[(batch_size,)](
        logprobs,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        # NOTE(wangx700): PADDED_TOPK must be at least 2 to avoid
        # num_logprobs=1 getting wrong results.
        PADDED_TOPK=max(triton.next_power_of_2(num_logprobs), 2),
    )
    return logprobs