Files
xc-llm-ascend/vllm_ascend/ops/triton/reject_sample.py

440 lines
17 KiB
Python
Raw Normal View History

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
def cal_grid_and_block_size(batch_size: int):
    """Choose a 1-D launch grid and per-program block size for ``batch_size`` rows.

    If the batch does not exceed the value returned by
    ``get_vectorcore_num()``, one program is launched per row with a block
    size of 1.  Otherwise the grid is capped at that value and each program
    gets ``ceil(batch_size / grid)`` rows, rounded up to the next power of two.

    Args:
        batch_size: Number of rows to distribute across programs.

    Returns:
        A ``(grid, block_size)`` tuple.
    """
    core_count = get_vectorcore_num()
    if batch_size > core_count:
        # More rows than programs: each program handles a power-of-two block.
        rows_per_program = triton.cdiv(batch_size, core_count)
        return core_count, triton.next_power_of_2(rows_per_program)
    return batch_size, 1
# NOTE(review): "max_spec_len" is not a parameter of this function; the
# do_not_specialize entry looks like a leftover - confirm it is ignored.
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew_1(
    bonus_token_ids_ptr,
    position,
    output_token_ids_ptr,
):
    """Copy the bonus token of row ``position`` into column 1 of its output row.

    The output buffer is addressed with a fixed row stride of 2 elements.
    """
    bonus_slot = position * 2 + 1
    tl.store(output_token_ids_ptr + bonus_slot, tl.load(bonus_token_ids_ptr + position))
# NOTE(review): "max_spec_len" is not a parameter of this kernel; the
# do_not_specialize entry looks like a leftover - confirm it is ignored.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_greedy_sample_spec_len_1_triton(
    output_token_ids_ptr,  # [batch_size, 2]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,
    vec_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling where draft tokens and output rows share one index.

    Each program covers ``BLOCK_SIZE`` consecutive rows.  For every row below
    ``vec_len`` the target argmax is written to column 0 of the output row;
    when the draft token equals the target argmax, the row's bonus token is
    additionally written to column 1 via ``bonus_renew_1``.
    """
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    rows = block_start + tl.arange(0, BLOCK_SIZE)
    in_range = rows < vec_len

    draft_ids = tl.load(draft_token_ids_ptr + rows, in_range)
    target_ids = tl.load(target_argmax_ptr + rows, in_range)

    # Column 0 receives the target argmax regardless of acceptance.
    tl.store(output_token_ids_ptr + rows * 2, target_ids, in_range)

    # Walk the block lane by lane; lanes past vec_len are skipped explicitly.
    for lane in tl.range(0, BLOCK_SIZE):
        row = block_start + lane
        if row < vec_len:
            draft_id = get_element(draft_ids, (lane,))
            target_id = get_element(target_ids, (lane,))
            if draft_id == target_id:
                # Draft accepted: append the bonus token to this row.
                bonus_renew_1(
                    bonus_token_ids_ptr,
                    row,
                    output_token_ids_ptr,
                )
@triton.jit(do_not_specialize=["max_spec_len"])
def bonus_renew(
    bonus_token_ids_ptr,
    position,
    output_token_ids_ptr,
    max_spec_len,
    num_tokens1,
):
    """Write row ``position``'s bonus token at column ``num_tokens1`` of its output row.

    Output rows are addressed with a stride of ``max_spec_len + 1`` elements.
    """
    row_base = position * (max_spec_len + 1)
    bonus = tl.load(bonus_token_ids_ptr + position)
    tl.store(output_token_ids_ptr + row_base + num_tokens1, bonus)
@triton.jit(do_not_specialize=["vec_len", "max_spec_len"])
def rejection_greedy_sample_triton(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    target_argmax_ptr,  # [num_tokens]
    bonus_token_ids_ptr,  # [batch_size]
    is_greedy_ptr,  # [batch_size] or None
    vec_len,
    max_spec_len,
    BLOCK_SIZE: tl.constexpr,
):
    """Greedy rejection sampling over a block of requests.

    Each program covers ``BLOCK_SIZE`` consecutive requests.  For a request
    that is in range (``< vec_len``) and flagged greedy (or every in-range
    request when ``is_greedy_ptr`` is None), the draft tokens are compared
    with the target argmax one by one.  The target argmax is written to the
    output row for every position up to and including the first mismatch.
    If no mismatch occurs, the bonus token is written right after the last
    draft position via ``bonus_renew``.

    Requests that are out of range or not greedy are left untouched: their
    draft-token count is forced to 0 so the per-token loop does not run.
    """
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < vec_len
    if is_greedy_ptr is None:
        is_greedy_mask = mask
    else:
        is_greedy = tl.load(is_greedy_ptr + offset, mask=mask, other=0)
        is_greedy_mask = mask & (is_greedy != 0)

    # Request i spans [cu[i - 1], cu[i]) in the flattened token arrays, with
    # request 0 starting at 0.  The index is clamped with tl.maximum so lane 0
    # never forms the address cu - 1.
    # `other=0` gives masked-out lanes a defined start/end of 0, hence a
    # draft-token count of 0; without it their values are undefined and the
    # per-token loop below could run with a garbage trip count.
    start_idx = tl.where(
        offset == 0,
        0,
        tl.load(cu_num_draft_tokens_ptr + tl.maximum(offset - 1, 0), mask=is_greedy_mask, other=0),
    )
    end_idx = tl.load(cu_num_draft_tokens_ptr + offset, mask=is_greedy_mask, other=0)
    num_draft_tokens = end_idx - start_idx

    for pos in tl.range(0, BLOCK_SIZE):
        num_tokens1 = get_element(num_draft_tokens, (pos,))
        rejected = False
        start_idx1 = get_element(start_idx, (pos,))
        is_greedy_mask1 = get_element(is_greedy_mask, (pos,))
        position = block_idx * BLOCK_SIZE + pos
        for i in range(num_tokens1):
            if not rejected:
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx1 + i)
                target_argmax_id = tl.load(target_argmax_ptr + start_idx1 + i)
                tl.store(
                    output_token_ids_ptr + position * (max_spec_len + 1) + i,
                    target_argmax_id,
                )
                if draft_token_id != target_argmax_id:
                    # Reject.
                    rejected = True
        if not rejected and is_greedy_mask1:
            # Every draft token matched: append the bonus token.
            bonus_renew(
                bonus_token_ids_ptr,
                position,
                output_token_ids_ptr,
                max_spec_len,
                num_tokens1,
            )
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_kernel(
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
output_token_ids_ptr, # [batch_size, max_spec_len + 1]
cu_num_draft_tokens_ptr, # [batch_size]
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
bonus_token_ids_ptr, # [batch_size]
recovered_token_ids_ptr, # [num_tokens]
uniform_probs_ptr, # [num_tokens]
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
vec_len,
NO_DRAFT_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
block_idx = tl.program_id(0)
offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < vec_len
is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
not_greedy_mask = is_greedy == 0
[Bugfix][0.18.0] fix kernels in sample when mask is not static or draft_token_id is invalid (#8531) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> The triton kernels in sample encounter some problems, scenarios are shown below: 1. 【expand_kernel/ rejection_random_sample_kernel/ prepare_inputs_padded_kernel】, these three operations will use ‘tl.load(prt + offsets -1, mask)’ in their implementations, but triton compiler reports that the masks in these scenarios are not static and contiguous. As a result, compiler will first access this memory and apply the mask. Therefore, I modified the code to ‘tl.load(prt +tl.maximum(offsets - 1, 0), mask)’ to ensure no -1 reads. 2. 【sample_recovered_tokens_kernel/ rejection_random_sample_kernel】, this kernel uses draft_token_id as an address offset for the load operation. In the PD separation scenario, if the pad token is -1, illegal memory reads and writes can occur. Therefore, i modified the kernel and so they can do well with -1 token. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. 
If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ppppeng <zepengliu912@qq.com> Co-authored-by: zepengliu912@qq.com <root@localhost.localdomain>
2026-04-23 23:04:19 +08:00
start_idxs = tl.where(
offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
)
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = get_element(not_greedy_mask, (req_i,))
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
if not_greedy:
rejected = False
start_idx = get_element(start_idxs, (req_i,))
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
for pos in range(num_draft_tokens):
if not rejected:
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
[Bugfix][0.18.0] fix kernels in sample when mask is not static or draft_token_id is invalid (#8531) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> The triton kernels in sample encounter some problems, scenarios are shown below: 1. 【expand_kernel/ rejection_random_sample_kernel/ prepare_inputs_padded_kernel】, these three operations will use ‘tl.load(prt + offsets -1, mask)’ in their implementations, but triton compiler reports that the masks in these scenarios are not static and contiguous. As a result, compiler will first access this memory and apply the mask. Therefore, I modified the code to ‘tl.load(prt +tl.maximum(offsets - 1, 0), mask)’ to ensure no -1 reads. 2. 【sample_recovered_tokens_kernel/ rejection_random_sample_kernel】, this kernel uses draft_token_id as an address offset for the load operation. In the PD separation scenario, if the pad token is -1, illegal memory reads and writes can occur. Therefore, i modified the kernel and so they can do well with -1 token. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. 
If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ppppeng <zepengliu912@qq.com> Co-authored-by: zepengliu912@qq.com <root@localhost.localdomain>
2026-04-23 23:04:19 +08:00
if draft_token_id < 0:
# Invalid draft (e.g., padded).
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
rejected = True
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
[Bugfix][0.18.0] fix kernels in sample when mask is not static or draft_token_id is invalid (#8531) <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> The triton kernels in sample encounter some problems, scenarios are shown below: 1. 【expand_kernel/ rejection_random_sample_kernel/ prepare_inputs_padded_kernel】, these three operations will use ‘tl.load(prt + offsets -1, mask)’ in their implementations, but triton compiler reports that the masks in these scenarios are not static and contiguous. As a result, compiler will first access this memory and apply the mask. Therefore, I modified the code to ‘tl.load(prt +tl.maximum(offsets - 1, 0), mask)’ to ensure no -1 reads. 2. 【sample_recovered_tokens_kernel/ rejection_random_sample_kernel】, this kernel uses draft_token_id as an address offset for the load operation. In the PD separation scenario, if the pad token is -1, illegal memory reads and writes can occur. Therefore, i modified the kernel and so they can do well with -1 token. ### Does this PR introduce _any_ user-facing change? <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. 
If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: ppppeng <zepengliu912@qq.com> Co-authored-by: zepengliu912@qq.com <root@localhost.localdomain>
2026-04-23 23:04:19 +08:00
else:
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
if draft_prob > 0 and target_prob / draft_prob >= uniform_prob:
# Accept.
token_id = draft_token_id
else:
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
bonus_token_id,
)
@triton.jit(do_not_specialize=["replace_from", "replace_to", "vec_len"])
def expand_kernel(
    output_ptr,  # [num_tokens]
    input_ptr,  # [batch_size]
    cu_num_tokens_ptr,  # [batch_size]
    replace_from,
    replace_to,
    vec_len,
    MAX_NUM_TOKENS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Expand one value per input entry into a run of identical values in the
    # output.
    #
    # Each program handles BLOCK_SIZE consecutive entries of `input_ptr`
    # (entries at index >= vec_len are masked off). For entry `k`, the run
    # occupies output[start:end] where `end = cu_num_tokens[k]` and
    # `start = cu_num_tokens[k - 1]` (0 for k == 0). Input values equal to
    # `replace_from` are written as `replace_to`. At most MAX_NUM_TOKENS
    # elements are written per entry.
    block_idx = tl.program_id(0)
    offset = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    len_mask = offset < vec_len

    # `tl.maximum(offset - 1, 0)` keeps the address non-negative for entry 0,
    # whose start is forced to 0 by the `tl.where` anyway.
    #
    # `other=0` gives masked-off lanes a defined value: without it their
    # contents are unspecified, and since they feed the store address and the
    # store mask below, the tail program could issue spurious writes. With
    # zeros, masked-off lanes get num_tokens == 0 and store nothing.
    prev_cu = tl.load(cu_num_tokens_ptr + tl.maximum(offset - 1, 0), mask=len_mask, other=0)
    start_idx = tl.where(offset == 0, 0, prev_cu)
    end_idx = tl.load(cu_num_tokens_ptr + offset, mask=len_mask, other=0)
    num_tokens = end_idx - start_idx

    src_val = tl.load(input_ptr + offset, mask=len_mask, other=0)
    src_val = tl.where(src_val == replace_from, replace_to, src_val)

    # Loop-invariant lane indices within one entry's output run.
    token_offset = tl.arange(0, MAX_NUM_TOKENS)
    for i in tl.range(0, BLOCK_SIZE):
        # Pull the i-th entry's scalars out of the block-wide vectors.
        num_tokens1 = get_element(num_tokens, (i,))
        start_idx1 = get_element(start_idx, (i,))
        src_val1 = get_element(src_val, (i,))
        tl.store(output_ptr + start_idx1 + token_offset, src_val1, mask=token_offset < num_tokens1)
@triton.jit
def sample_recovered_tokens_kernel(
    output_token_ids_ptr,  # [num_tokens]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    q_ptr,  # [batch_size, vocab_size]
    vocab_size,
    PADDED_VOCAB_SIZE: tl.constexpr,
    NO_DRAFT_PROBS: tl.constexpr,
    SUB_BLOCK: tl.constexpr,
):
    # Pick the recovered token for one (request, draft position) pair.
    #
    # Program axis 0 selects the request, axis 1 the draft position inside it.
    # The vocabulary is scanned in chunks of SUB_BLOCK; the token with the
    # largest `prob / q` over all chunks is written to
    # output_token_ids[start_idx + pos].
    # NOTE(review): `q` looks like per-request sampling noise supplied by the
    # caller -- confirm against the call site.
    req_idx = tl.program_id(0)
    pos = tl.program_id(1)

    # This request's tokens live at [start_idx, end_idx) in the flat arrays.
    start_idx = 0 if req_idx == 0 else tl.load(cu_num_draft_tokens_ptr + req_idx - 1)
    end_idx = tl.load(cu_num_draft_tokens_ptr + req_idx)
    num_draft_tokens = end_idx - start_idx

    # Early exit for out-of-range positions.
    if pos >= num_draft_tokens:
        return

    # Row starts inside the [*, vocab_size] matrices.
    token_row = (start_idx + pos) * vocab_size
    q_row = req_idx * vocab_size

    num_chunks = (vocab_size + SUB_BLOCK - 1) // SUB_BLOCK
    # Running best across chunks; the id stays -1 unless some chunk's best
    # score exceeds -1.0.
    best_token_id = -1
    best_score = -1.0
    if NO_DRAFT_PROBS:
        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
        for chunk_idx in range(num_chunks):
            chunk_start = chunk_idx * SUB_BLOCK
            vocab_ids = chunk_start + tl.arange(0, SUB_BLOCK)
            in_vocab = vocab_ids < vocab_size

            prob = tl.load(
                target_probs_ptr + token_row + vocab_ids,
                mask=in_vocab,
                other=0,
            )
            # Temporarily zero out the probability of the draft token.
            # This is essentially the same as target_prob - draft_prob, except that
            # n-gram does not have draft_prob. We regard it as 1.
            prob = tl.where(vocab_ids == draft_token_id, 0, prob)

            # Lanes past the vocabulary load prob = 0 and q = -inf.
            q = tl.load(q_ptr + q_row + vocab_ids, mask=in_vocab, other=float("-inf"))
            score = prob / q
            local_best = tl.argmax(score, axis=-1)
            local_score = get_element(score, (local_best,))
            if local_score > best_score:
                best_score = local_score
                best_token_id = chunk_start + local_best
    else:
        for chunk_idx in range(num_chunks):
            chunk_start = chunk_idx * SUB_BLOCK
            vocab_ids = chunk_start + tl.arange(0, SUB_BLOCK)
            in_vocab = vocab_ids < vocab_size

            draft_prob = tl.load(draft_probs_ptr + token_row + vocab_ids, mask=in_vocab, other=0)
            target_prob = tl.load(
                target_probs_ptr + token_row + vocab_ids,
                mask=in_vocab,
                other=0,
            )
            prob = tl.maximum(target_prob - draft_prob, 0)
            # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
            # `tl.argmax` will select the maximum value.

            # Lanes past the vocabulary load prob = 0 and q = -inf.
            q = tl.load(q_ptr + q_row + vocab_ids, mask=in_vocab, other=float("-inf"))
            score = prob / q
            local_best = tl.argmax(score, axis=-1)
            local_score = get_element(score, (local_best,))
            if local_score > best_score:
                best_score = local_score
                best_token_id = chunk_start + local_best

    tl.store(output_token_ids_ptr + start_idx + pos, best_token_id)
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
def rejection_greedy_sample_with_triton(
output_token_ids,
num_draft_tokens,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
max_spec_len,
grid,
block_size,
):
vec_len = output_token_ids.shape[0]
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
if min(num_draft_tokens) == 1 and max(num_draft_tokens) == 1 and is_greedy is None:
rejection_greedy_sample_spec_len_1_triton[(grid,)](
output_token_ids,
draft_token_ids,
target_argmax,
bonus_token_ids,
vec_len,
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
BLOCK_SIZE=block_size,
)
else:
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
rejection_greedy_sample_triton[(grid,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
is_greedy,
vec_len,
max_spec_len,
feat: implement high-performance Triton kernels for rejection sampling: optimization for rejection_random_sample_kernel (#5259) ### What this PR does / why we need it? This PR introduces optimized Triton implementations for the rejection_random_sample_kernel delivering superior performance compared to the existing Triton implementations. The new Triton kernels maintain full functional accuracy while delivering significant performance improvements across various batch sizes and MTP configurations. ### Does this PR introduce _any_ user-facing change? Yes, this PR modifies rejection_sampler.py to use optimized Triton kernels: rejection_random_sample_kernel is modified and optimized ### How was this patch tested? performance benchmark results: <html xmlns:v="urn:schemas-microsoft-com:vml" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:x="urn:schemas-microsoft-com:office:excel" xmlns="http://www.w3.org/TR/REC-html40"> <head> <meta name=Generator content="Microsoft Excel"> <!--[if !mso]> </head> <body> <!--StartFragment--> Batch Size | MTP | origin implementation(us) | optimized version(us) -- | -- | -- | -- 1 | 1 | 2.934 | 3.64 8 | 1 | 4.467 | 4 32 | 1 | 6.98 | 4.54 64 | 1 | 11.087 | 6.42 128 | 1 | 13.414 | 7.84 256 | 1 | 19.66 | 8.487 512 | 1 | 39.908 | 11.62 1024 | 1 | 81.781 | 18.16 2048 | 1 | 137.923 | 32.934 1 | 2 | 3.4 | 4.02 8 | 2 | 3.74 | 4.24 32 | 2 | 6.373 | 7.394 64 | 2 | 9.747 | 6.46 128 | 2 | 12.98 | 7.76 256 | 2 | 20.834 | 9.787 512 | 2 | 39.314 | 13.56 1024 | 2 | 83.135 | 22.387 2048 | 2 | 157.563 | 40.607 <!--EndFragment--> </body> </html> - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/ad32e3e19ccf0526cb6744a5fed09a138a5fb2f9 Signed-off-by: 1024daniel <xxltju324@gmail.com>
2026-01-05 16:03:02 +08:00
BLOCK_SIZE=block_size,
)
[Lint]Style: Convert `vllm-ascend/` to ruff format(Batch #12) (#6177) ### What this PR does / why we need it? **Scope of Changes**: | File Path | | :--- | | `vllm_ascend/ops/triton/activation/swiglu_quant.py` | | `vllm_ascend/ops/triton/batch_invariant/matmul.py` | | `vllm_ascend/ops/triton/batch_invariant/mean.py` | | `vllm_ascend/ops/triton/batch_invariant/rmsnorm.py` | | `vllm_ascend/ops/triton/fla/chunk.py` | | `vllm_ascend/ops/triton/fla/chunk_delta_h.py` | | `vllm_ascend/ops/triton/fla/chunk_o.py` | | `vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py` | | `vllm_ascend/ops/triton/fla/cumsum.py` | | `vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py` | | `vllm_ascend/ops/triton/fla/l2norm.py` | | `vllm_ascend/ops/triton/fla/layernorm_guard.py` | | `vllm_ascend/ops/triton/fla/sigmoid_gating.py` | | `vllm_ascend/ops/triton/fla/solve_tril.py` | | `vllm_ascend/ops/triton/fla/utils.py` | | `vllm_ascend/ops/triton/fla/wy_fast.py` | | `vllm_ascend/ops/triton/fused_gdn_gating.py` | | `vllm_ascend/ops/triton/layernorm_gated.py` | | `vllm_ascend/ops/triton/linearnorm/split_qkv_rmsnorm_rope.py` | | `vllm_ascend/ops/triton/mamba/causal_conv1d.py` | | `vllm_ascend/ops/triton/reject_sample.py` | | `vllm_ascend/ops/triton/rope.py` | | `vllm_ascend/ops/triton/spec_decode/utils.py` | | `vllm_ascend/ops/triton/triton_utils.py` | ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.14.0 - vLLM main: https://github.com/vllm-project/vllm/commit/d68209402ddab3f54a09bc1f4de9a9495a283b60 Signed-off-by: MrZ20 <2609716663@qq.com>
2026-01-23 14:59:19 +08:00
def expand_triton(batch_size, expanded_x, x, cu_num_tokens, replace_from, replace_to, max_num_tokens):
    """Launch ``expand_kernel`` on a 1-D grid sized from ``batch_size``.

    The number of programs and the per-program block size are both obtained
    from ``cal_grid_and_block_size(batch_size)``. ``batch_size`` is also
    forwarded to the kernel as its element-count argument.

    Args:
        batch_size: Sizes the launch and is passed to the kernel as the
            sixth positional argument.
        expanded_x: Forwarded unchanged as the first kernel argument
            (presumably the output buffer -- confirm against ``expand_kernel``).
        x: Forwarded unchanged as the second kernel argument.
        cu_num_tokens: Forwarded unchanged as the third kernel argument.
        replace_from: Forwarded unchanged as the fourth kernel argument.
        replace_to: Forwarded unchanged as the fifth kernel argument.
        max_num_tokens: Passed as the ``MAX_NUM_TOKENS`` keyword argument.
    """
    num_programs, elems_per_program = cal_grid_and_block_size(batch_size)
    launcher = expand_kernel[(num_programs,)]
    launcher(
        expanded_x,
        x,
        cu_num_tokens,
        replace_from,
        replace_to,
        batch_size,
        MAX_NUM_TOKENS=max_num_tokens,  # To avoid recompilation.
        BLOCK_SIZE=elems_per_program,
    )
# Block-verify variant of random rejection sampling.
#
# Each program instance handles BLOCK_SIZE consecutive requests. For every
# non-greedy request it walks all of that request's draft tokens, keeps a
# running product of the uniform samples and a running (clamped) product of
# target_prob / draft_prob, and remembers the LAST position at which the
# running ratio is still >= the running uniform product. Draft tokens up to
# that position are copied to the output row, followed by either a recovered
# token (if the final draft position failed the test) or the bonus token.
# Requests flagged as greedy are left untouched by this kernel.
@triton.jit(do_not_specialize=["max_spec_len"])
def rejection_random_sample_block_verify_kernel(
    output_token_ids_ptr,  # [batch_size, max_spec_len + 1]
    cu_num_draft_tokens_ptr,  # [batch_size]
    draft_token_ids_ptr,  # [num_tokens]
    draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
    bonus_token_ids_ptr,  # [batch_size]
    recovered_token_ids_ptr,  # [num_tokens]
    uniform_probs_ptr,  # [num_tokens]
    is_greedy_ptr,  # [batch_size]
    max_spec_len,
    vocab_size,
    vec_len,
    NO_DRAFT_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Request indices covered by this program.
    block_idx = tl.program_id(0)
    offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < vec_len
    # Lanes past vec_len load is_greedy as 1, so they count as greedy and are
    # skipped by the per-request loop below.
    is_greedy = tl.load(is_greedy_ptr + offsets, mask, other=1)
    not_greedy_mask = is_greedy == 0
    # Start of each request's token range: 0 for request 0, otherwise the
    # previous cumulative count. The index is clamped with tl.maximum so the
    # load never addresses element -1, even on lanes the mask would discard.
    start_idxs = tl.where(
        offsets == 0, 0, tl.load(cu_num_draft_tokens_ptr + tl.maximum(offsets - 1, 0), not_greedy_mask)
    )
    end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
    n_num_draft_tokens = end_idxs - start_idxs
    # Requests inside the block are processed one at a time; get_element is
    # used to pull the per-request scalar out of each block-wide vector.
    for req_i in range(BLOCK_SIZE):
        not_greedy = get_element(not_greedy_mask, (req_i,))
        if not_greedy:
            rejected = False
            pi = 1.0  # running product of target/draft ratios, clamped to 1.0
            uniform_prob = 1.0  # running product of the uniform samples
            last_accepted_token_pos = -1
            start_idx = get_element(start_idxs, (req_i,))
            req_idx = block_idx * BLOCK_SIZE + req_i
            num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
            # Every draft position is scanned (no early exit), so a later
            # position can pass the test after an earlier one failed.
            for pos in range(num_draft_tokens):
                draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                # NOTE(review): draft_token_id is used directly as a vocab
                # offset here and below; a negative pad id would address
                # outside the row -- confirm callers never pass one.
                target_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
                tmp_uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
                uniform_prob = uniform_prob * tmp_uniform_prob
                if NO_DRAFT_PROBS:
                    draft_prob = 1
                else:
                    draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id)
                # NOTE(review): the division happens before the draft_prob > 0
                # check below, so a zero draft_prob feeds inf/NaN into pi for
                # the following positions -- confirm this is intended.
                pi = min(pi * target_prob / draft_prob, 1.0)
                if draft_prob > 0 and pi >= uniform_prob:
                    last_accepted_token_pos = pos
                    rejected = False
                else:
                    rejected = True
            # Copy the accepted prefix of draft tokens into this request's
            # output row (row stride is max_spec_len + 1).
            if last_accepted_token_pos > -1:
                for pos in range(last_accepted_token_pos + 1):
                    token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
                    tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id)
            # `rejected` reflects only the final draft position: if it failed,
            # emit the recovered token right after the accepted prefix;
            # otherwise (including the zero-draft-token case) emit the bonus
            # token at index num_draft_tokens.
            if rejected:
                recovered_token_id = tl.load(recovered_token_ids_ptr + start_idx + last_accepted_token_pos + 1)
                tl.store(
                    output_token_ids_ptr + req_idx * (max_spec_len + 1) + last_accepted_token_pos + 1,
                    recovered_token_id,
                )
            else:
                bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
                tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens, bonus_token_id)