[feat][spec decode]Unified draft parallel (#6766)

### What this PR does / why we need it?
Implement unified parallel drafting for speculative decoding in vLLM
Ascend, which can simultaneously support parallel speculative inference
schemes such as PARD and P-Eagle. See
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078 for background.
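
For context, a minimal offline-inference sketch of the new option; it assumes the same `speculative_config` keys that the serve commands below pass via `--speculative-config`, and the model paths are placeholders:

```python
# Minimal sketch: enabling parallel drafting through the speculative config.
# Model paths are placeholders; config keys mirror the serve commands below.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/model/Llama-3.1-8B-Instruct",
    speculative_config={
        "model": "/model/PARD-Llama-3.2-1B",  # parallel draft model
        "method": "draft_model",
        "num_speculative_tokens": 8,
        "parallel_drafting": True,            # option exercised by this PR
    },
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0))
print(outputs[0].outputs[0].text)
```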

### How was this patch tested?

Serve script with parallel drafting:
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

Baseline script (without speculative decoding):
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811

Benchmark script:
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234

Test results (MT-Bench, max concurrency 1):

| Configuration | TTFT | TPOT | Output token throughput |
| --- | --- | --- | --- |
| base (without spec decode) | 79.46 ms | 26.99 ms | 36.75 tok/s |
| this PR (with parallel drafting) | 72.24 ms | 13.45 ms | 72.98 tok/s |

Per-position acceptance rates (positions 0 through 7): 79.48%, 56.93%, 40%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%.
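
As a rough consistency check, assuming the per-position figures are cumulative acceptance probabilities (the chance that the draft token at position i survives verification), their sum estimates the accepted draft tokens per verification step:

```python
# Expected accepted draft tokens per step, assuming the per-position
# rates above are cumulative acceptance probabilities.
rates = [0.7948, 0.5693, 0.40, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]
expected_accepted = sum(rates)           # ~2.57 draft tokens accepted
tokens_per_step = 1 + expected_accepted  # plus the verified bonus token, ~3.57
print(f"{expected_accepted:.2f} accepted, {tokens_per_step:.2f} tokens/step")
```

About 3.57 tokens per step would bound the speedup near 3.6x if drafting were free; the observed ~2.0x TPOT improvement (26.99 ms to 13.45 ms) sits below that bound, consistent with the draft model's own forward cost.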

----------------------------------------------------------------------
Serve script for Qwen3:
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1

vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'

cc  @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>

@@ -597,6 +597,41 @@ void transpose_kv_cache_by_block(
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
npu_copy_and_expand_eagle_inputs(
    const at::Tensor &target_token_ids,
    const at::Tensor &target_positions,
    const at::Tensor &next_token_ids,
    const at::Tensor &query_start_loc,
    const at::Tensor &query_end_loc,
    int64_t padding_token_id,
    int64_t parallel_drafting_token_id,
    int64_t num_padding_slots_per_request,
    bool shift_input_ids,
    int64_t total_draft_tokens)
{
    int64_t total_input_tokens = target_token_ids.size(0);
    // query_start_loc stores num_reqs + 1 cumulative offsets, one per request boundary.
    int64_t num_reqs = query_start_loc.size(0) - 1;
    auto device = target_token_ids.device();
    // Expanded, flattened draft-model inputs: one slot per draft token.
    at::Tensor out_input_ids = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
    at::Tensor out_positions = at::empty({total_draft_tokens}, at::dtype(at::kInt).device(device));
    // Byte masks flagging rejected slots and parallel-drafting mask-token slots.
    at::Tensor out_is_rejected_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
    at::Tensor out_is_masked_token_mask = at::empty({total_draft_tokens}, at::dtype(at::kChar).device(device));
    at::Tensor out_new_token_indices = at::empty({num_reqs * num_padding_slots_per_request}, at::dtype(at::kInt).device(device));
    // Per-input-token index mapping used to gather the corresponding hidden states.
    at::Tensor out_hidden_state_mapping = at::empty({total_input_tokens}, at::dtype(at::kInt).device(device));
    EXEC_NPU_CMD(aclnnCopyAndExpandEagleInputs,
                 target_token_ids, target_positions, next_token_ids, query_start_loc, query_end_loc,
                 padding_token_id, parallel_drafting_token_id, num_padding_slots_per_request,
                 shift_input_ids, total_input_tokens,
                 out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
                 out_new_token_indices, out_hidden_state_mapping);
    return {out_input_ids, out_positions, out_is_rejected_token_mask, out_is_masked_token_mask,
            out_new_token_indices, out_hidden_state_mapping};
}
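
For illustration only, a hypothetical Python-side call into this op once the `_C_ascend` extension is loaded; the argument order follows the schema registered below, but the shapes, dtypes, and values here are guesses:

```python
# Hypothetical invocation; assumes torch_npu and the vllm_ascend extension
# are available so that torch.ops._C_ascend is registered.
import torch

device = "npu"
target_token_ids = torch.randint(0, 32000, (10,), dtype=torch.int32, device=device)
target_positions = torch.arange(10, dtype=torch.int32, device=device)
next_token_ids = torch.randint(0, 32000, (2,), dtype=torch.int32, device=device)
query_start_loc = torch.tensor([0, 4, 10], dtype=torch.int32, device=device)  # 2 requests
query_end_loc = torch.tensor([4, 10], dtype=torch.int32, device=device)       # guessed layout

(input_ids, positions, rejected_mask, masked_mask,
 new_token_indices, hidden_state_mapping) = torch.ops._C_ascend.npu_copy_and_expand_eagle_inputs(
    target_token_ids, target_positions, next_token_ids,
    query_start_loc, query_end_loc,
    0,     # padding_token_id (placeholder)
    1,     # parallel_drafting_token_id (placeholder)
    8,     # num_padding_slots_per_request
    True,  # shift_input_ids
    26,    # total_draft_tokens, guessed as 10 input tokens + 2 * 8 padding slots
)
```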
at::Tensor causal_conv1d_fn(
    const at::Tensor& mixed_qkv_non_spec_T,
    const at::Tensor& conv_weights,
@@ -849,6 +884,16 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"transpose_kv_cache_by_block(Tensor[] kCache, Tensor[] vCache, Tensor blockIDs, int blockSize, int headNum, int headDim, int splitNum, int layerNum) -> ()"
);
ops.impl("transpose_kv_cache_by_block", torch::kPrivateUse1, &vllm_ascend::transpose_kv_cache_by_block);
ops.def(
"npu_copy_and_expand_eagle_inputs(Tensor target_token_ids, Tensor target_positions, "
"Tensor next_token_ids, Tensor query_start_loc, Tensor query_end_loc, "
"int padding_token_id, int parallel_drafting_token_id, int num_padding_slots_per_request, "
"bool shift_input_ids, int total_draft_tokens) -> "
"(Tensor out_input_ids, Tensor out_positions, Tensor out_is_rejected_token_mask, "
"Tensor out_is_masked_token_mask, Tensor out_new_token_indices, Tensor out_hidden_state_mapping)"
);
ops.impl("npu_copy_and_expand_eagle_inputs", torch::kPrivateUse1, &vllm_ascend::npu_copy_and_expand_eagle_inputs);
// causal_conv1d_fn
ops.def(
"causal_conv1d_fn(Tensor mixed_qkv_non_spec_T, "