### What this PR does / why we need it?
Some parameters of the Triton operators are unnecessarily marked with the `tl.constexpr` qualifier. Triton treats `constexpr` arguments as compile-time constants, so every new value of such a parameter triggers a recompilation, which significantly hurts model performance. This PR turns these parameters into regular runtime arguments and, for `num_reqs`, lists it in `do_not_specialize` so Triton does not specialize on its value either.
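To make the failure mode concrete, here is a minimal sketch (hypothetical kernels, not from this PR): in the first variant `n` is a compile-time constant, so every distinct value of `n` compiles and caches a fresh binary; in the second it is an ordinary runtime argument that `do_not_specialize` also shields from Triton's integer-value specialization (e.g. on `n == 1` or `n % 16 == 0`), so a single compilation serves all calls.

```python
from vllm.triton_utils import tl, triton


@triton.jit
def count_before(out_ptr, n: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    # `n` is baked into the binary: each new value recompiles the kernel.
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + offs, offs, mask=offs < n)


@triton.jit(do_not_specialize=["n"])
def count_after(out_ptr, n, BLOCK_SIZE: tl.constexpr):
    # `n` is a runtime argument and is excluded from value specialization,
    # so one compiled kernel is reused no matter how `n` varies per call.
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    tl.store(out_ptr + offs, offs, mask=offs < n)
```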
- vLLM version: v0.17.0
- vLLM main: 8b6325758c
Signed-off-by: HarpSealCC <844291270@qq.com>
Signed-off-by: l30072083 <liuchengzhuo1@h-partners.com>
Co-authored-by: l30072083 <liuchengzhuo1@h-partners.com>
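For reference, the updated kernel: `num_reqs` is now a plain runtime argument, excluded from value specialization via `do_not_specialize`, and only `BLOCK_SIZE` remains `tl.constexpr`.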
```python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/spec_decode/utils.py

from vllm.triton_utils import tl, triton


@triton.jit(do_not_specialize=["num_reqs"])
def prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)

    # Grid-stride loop:
    block_start_step = num_programs * BLOCK_SIZE

    for block_start in tl.range(pid * BLOCK_SIZE, num_reqs, block_start_step):
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_reqs

        # Calculate num_draft_tokens from cu_num_draft_tokens, which is an
        # inclusive cumulative sum (the first entry is the first value,
        # not zero).
        cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)

        prev_indices = offsets - 1
        has_prev = offsets > 0
        cu_draft_prev = tl.load(
            cu_num_draft_tokens_ptr + prev_indices,
            mask=mask & has_prev,
            other=0,
        )

        num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
                                    cu_draft_curr)

        valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
                              mask=mask)
        num_rejected = num_draft_tokens + 1 - valid_count
        num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)

        # query_start_loc[req_idx + 1] is the start position of the next
        # request, which is one past the last token of this request.
        q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
                                 mask=mask) - 1

        index_to_sample = q_last_tok_idx - num_rejected
        tl.store(token_indices_to_sample_ptr + offsets, index_to_sample,
                 mask=mask)
        tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected,
                 mask=mask)
```
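For context, a host-side launch might look like the sketch below. The wrapper name, signature, and default block size are assumptions, not taken from this PR; only the kernel's argument order is. Because `num_reqs` is in `do_not_specialize`, the kernel compiles once per `BLOCK_SIZE` even as the number of requests changes every step.

```python
import torch
from vllm.triton_utils import triton


def prepare_inputs_padded(
    cu_num_draft_tokens: torch.Tensor,  # [num_reqs], inclusive cumsum
    valid_sampled_tokens_count: torch.Tensor,  # [num_reqs]
    query_start_loc: torch.Tensor,  # [num_reqs + 1]
    block_size: int = 128,
):
    num_reqs = cu_num_draft_tokens.shape[0]
    token_indices_to_sample = torch.empty_like(cu_num_draft_tokens)
    num_rejected_tokens = torch.empty_like(cu_num_draft_tokens)

    # One program per BLOCK_SIZE chunk of requests; the kernel's
    # grid-stride loop would also tolerate a smaller, capped grid.
    grid = (triton.cdiv(num_reqs, block_size),)
    prepare_inputs_padded_kernel[grid](
        cu_num_draft_tokens,
        valid_sampled_tokens_count,
        query_start_loc,
        token_indices_to_sample,
        num_rejected_tokens,
        num_reqs,
        BLOCK_SIZE=block_size,
    )
    return token_indices_to_sample, num_rejected_tokens
```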