xc-llm-ascend/vllm_ascend/ops/triton/spec_decode/utils.py

# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/spec_decode/utils.py
from vllm.triton_utils import tl, triton


@triton.jit(do_not_specialize=["num_reqs"])
def prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,
    num_reqs,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)
    # Grid-Stride Loop:
    block_start_step = num_programs * BLOCK_SIZE
    for block_start in tl.range(pid * BLOCK_SIZE, num_reqs, block_start_step):
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_reqs
        # Calculate num_draft_tokens from cu_num_draft_tokens, which is an inclusive
        # cumulative sum (first entry is the first value, not zero).
        cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
        prev_indices = tl.maximum(offsets - 1, 0)
        has_prev = offsets > 0
        cu_draft_prev = tl.load(
            cu_num_draft_tokens_ptr + prev_indices,
            mask=mask & has_prev,
            other=0,
        )
        num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev, cu_draft_curr)
        valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets, mask=mask)
        num_rejected = num_draft_tokens + 1 - valid_count
        num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)
        # query_start_loc[req_idx + 1] is the start position of the next request,
        # which is one past the last token of this request.
        q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1, mask=mask) - 1
        index_to_sample = q_last_tok_idx - num_rejected
        tl.store(token_indices_to_sample_ptr + offsets, index_to_sample, mask=mask)
        tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected, mask=mask)
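

# Illustrative host-side reference (not part of the original file): a minimal
# pure-Python sketch of the per-request arithmetic the kernel above performs,
# which may help when checking kernel outputs by hand. The function name
# `prepare_inputs_padded_reference` and the sample inputs are hypothetical;
# the sketch assumes each request contributes num_draft_tokens + 1 query tokens.

```python
def prepare_inputs_padded_reference(
    cu_num_draft_tokens,  # inclusive cumulative sum, length num_reqs
    valid_sampled_tokens_count,  # length num_reqs
    query_start_loc,  # length num_reqs + 1
):
    """Hypothetical CPU reference mirroring the kernel's per-request math."""
    token_indices_to_sample = []
    num_rejected_tokens = []
    for i in range(len(cu_num_draft_tokens)):
        # Inclusive cumsum: request 0 has no predecessor, so subtract 0.
        prev = cu_num_draft_tokens[i - 1] if i > 0 else 0
        num_draft = cu_num_draft_tokens[i] - prev
        # Requests without draft tokens never reject anything.
        if num_draft > 0:
            num_rejected = num_draft + 1 - valid_sampled_tokens_count[i]
        else:
            num_rejected = 0
        # query_start_loc[i + 1] is one past the last token of request i.
        last_tok_idx = query_start_loc[i + 1] - 1
        token_indices_to_sample.append(last_tok_idx - num_rejected)
        num_rejected_tokens.append(num_rejected)
    return token_indices_to_sample, num_rejected_tokens


# Three requests with 2, 0 and 3 draft tokens (assumed sample data).
indices, rejected = prepare_inputs_padded_reference(
    cu_num_draft_tokens=[2, 2, 5],
    valid_sampled_tokens_count=[1, 1, 4],
    query_start_loc=[0, 3, 4, 8],
)
print(indices, rejected)
```

# In this sample, request 0 keeps one of its three tokens (two rejected),
# request 1 has no drafts, and request 2 accepts every draft token.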