Files
xc-llm-ascend/vllm_ascend/ops/triton/spec_decode/utils.py
kx df1ee8070d [feat][spec decode]Unified draft parallel (#6766)
### What this PR does / why we need it?
Implement unified, parallelized speculative decoding in vLLM
Ascend, which can simultaneously support parallel speculative inference
schemes such as PARD, P-Eagle, etc. See
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078.

### How was this patch tested?

Script to run with parallel drafting:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

Baseline script (without speculative decoding):
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811

Benchmark script:
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234

Test results:
base (without spec decode): TTFT 79.46 ms, TPOT 26.99 ms,
output token throughput 36.75 tok/s
this PR (with parallel drafting): TTFT 72.24 ms, TPOT 13.45 ms,
output token throughput 72.98 tok/s
per-position acceptance (positions 0 to 7):
79.48%, 56.93%, 40.00%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%.
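
As a rough consistency check (a back-of-the-envelope sketch, treating the
per-position rates above as marginal acceptance probabilities):
# Per-position acceptance rates reported above (positions 0..7).
rates = [0.7948, 0.5693, 0.4000, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]
# Each verification step emits one target-model token plus one token per
# accepted draft position, so by linearity of expectation:
expected_tokens_per_step = 1 + sum(rates)
print(round(expected_tokens_per_step, 2))  # ~3.57
# The observed ~2x TPOT gain (26.99 ms -> 13.45 ms) is below 3.57x, as
# expected once the draft model's own runtime is accounted for.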

----------------------------------------------------------------------
Script to run on the Qwen3 model:
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1

vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main:
9562912cea

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
2026-03-13 14:07:35 +08:00


# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/v1/spec_decode/utils.py
from vllm.triton_utils import tl, triton


@triton.jit
def prepare_inputs_padded_kernel(
    cu_num_draft_tokens_ptr,  # [num_reqs]
    valid_sampled_tokens_count_ptr,  # [num_reqs]
    query_start_loc_gpu_ptr,  # [num_reqs + 1]
    token_indices_to_sample_ptr,  # [num_reqs] (output)
    num_rejected_tokens_gpu_ptr,  # [num_reqs] (output)
    num_reqs,  # tl.int32
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_programs = tl.num_programs(axis=0)
    # Grid-stride loop: each program processes every num_programs-th block.
    block_start_step = num_programs * BLOCK_SIZE
    for block_start in tl.range(pid * BLOCK_SIZE, num_reqs, block_start_step):
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < num_reqs
        # Recover per-request num_draft_tokens from cu_num_draft_tokens, which
        # is an inclusive cumulative sum (first entry is the first value, not
        # zero).
        cu_draft_curr = tl.load(cu_num_draft_tokens_ptr + offsets, mask=mask)
        prev_indices = offsets - 1
        has_prev = offsets > 0
        cu_draft_prev = tl.load(
            cu_num_draft_tokens_ptr + prev_indices,
            mask=mask & has_prev,
            other=0,
        )
        num_draft_tokens = tl.where(has_prev, cu_draft_curr - cu_draft_prev,
                                    cu_draft_curr)
        valid_count = tl.load(valid_sampled_tokens_count_ptr + offsets,
                              mask=mask)
        # Each request samples num_draft_tokens + 1 tokens; everything after
        # the first rejected draft token is discarded.
        num_rejected = num_draft_tokens + 1 - valid_count
        num_rejected = tl.where(num_draft_tokens > 0, num_rejected, 0)
        # query_start_loc[req_idx + 1] is the start position of the next
        # request, which is one past the last token of this request.
        q_last_tok_idx = tl.load(query_start_loc_gpu_ptr + offsets + 1,
                                 mask=mask) - 1
        index_to_sample = q_last_tok_idx - num_rejected
        tl.store(token_indices_to_sample_ptr + offsets, index_to_sample,
                 mask=mask)
        tl.store(num_rejected_tokens_gpu_ptr + offsets, num_rejected,
                 mask=mask)
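
For orientation, a minimal host-side launch sketch (a hypothetical wrapper,
not part of the file above; actual call sites in vllm-ascend may differ):
import torch

def prepare_inputs_padded(
    cu_num_draft_tokens: torch.Tensor,         # [num_reqs], inclusive cumsum
    valid_sampled_tokens_count: torch.Tensor,  # [num_reqs]
    query_start_loc: torch.Tensor,             # [num_reqs + 1]
) -> tuple[torch.Tensor, torch.Tensor]:
    num_reqs = cu_num_draft_tokens.shape[0]
    device = cu_num_draft_tokens.device
    token_indices_to_sample = torch.empty(num_reqs, dtype=torch.int32,
                                          device=device)
    num_rejected_tokens = torch.empty(num_reqs, dtype=torch.int32,
                                      device=device)
    BLOCK_SIZE = 128  # assumed block size; tune per hardware
    # One program per block of requests; the kernel's grid-stride loop also
    # tolerates a smaller grid. `triton` is the import at the top of the file.
    grid = (triton.cdiv(num_reqs, BLOCK_SIZE),)
    prepare_inputs_padded_kernel[grid](
        cu_num_draft_tokens,
        valid_sampled_tokens_count,
        query_start_loc,
        token_indices_to_sample,
        num_rejected_tokens,
        num_reqs,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return token_indices_to_sample, num_rejected_tokens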