[feat][spec decode] Unified draft parallel (#6766)
### What this PR does / why we need it?
Implement a unified, parallelized speculative decoding path in vLLM
Ascend that can simultaneously support parallel speculative inference
schemes such as PARD, P-Eagle, etc. See
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078 for background.
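Conceptually, parallel drafting lets the draft model propose all `num_speculative_tokens` draft tokens in a single forward pass (e.g. by filling PARD-style mask/placeholder positions), instead of running one autoregressive draft step per token. The sketch below is illustrative only and is not the code added by this PR; `draft_model` is assumed to be an HF-style causal LM adapted for parallel drafting, and `MASK_TOKEN_ID` is a placeholder for its special token:

```python
import torch

# Illustrative only: a PARD-style draft model trained to fill mask slots can
# propose k future tokens in ONE forward pass instead of k sequential steps.
MASK_TOKEN_ID = 0  # placeholder; the real id comes from the draft model's tokenizer


def propose_parallel(draft_model, input_ids: torch.Tensor, k: int = 8) -> torch.Tensor:
    """Return [batch, k] greedy draft tokens from a single draft forward pass."""
    masks = torch.full(
        (input_ids.shape[0], k), MASK_TOKEN_ID,
        dtype=input_ids.dtype, device=input_ids.device,
    )
    # One forward over [context, k mask slots]; the logits at the mask
    # positions are read out as draft tokens for offsets 0..k-1.
    logits = draft_model(torch.cat([input_ids, masks], dim=1)).logits
    return logits[:, -k:, :].argmax(dim=-1)


def propose_autoregressive(draft_model, input_ids: torch.Tensor, k: int = 8) -> torch.Tensor:
    """Baseline draft_model behaviour: k sequential draft forward passes."""
    tokens = input_ids
    for _ in range(k):
        next_tok = draft_model(tokens).logits[:, -1, :].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_tok], dim=1)
    return tokens[:, -k:]
```

Either way, the target model then verifies the k drafted tokens in one pass, as with a conventional draft model.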
### How was this patch tested?
Run with the parallel drafting script:

```bash
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6

vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811 \
    --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```
Base script:

```bash
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6

vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811
```
Benchmark script:

```bash
MAX_CONCURRENCY=1
NUM_PROMPTS=80

vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234
```
Test results:

| Config | TTFT | TPOT | Output token throughput |
|---|---|---|---|
| Base (without spec decode) | 79.46 ms | 26.99 ms | 36.75 tok/s |
| This PR (with parallel drafting) | 72.24 ms | 13.45 ms | 72.98 tok/s |

Per-position acceptance (positions 0 to 7): 79.48%, 56.93%, 40%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%.
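As a rough sanity check on these numbers (back-of-the-envelope only; this assumes each per-position acceptance rate is the fraction of verification steps in which that draft position was accepted, so their sum approximates the expected accepted draft tokens per target forward pass):

```python
# Back-of-the-envelope arithmetic on the reported benchmark numbers.
acceptance = [0.7948, 0.5693, 0.40, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]

expected_accepted = sum(acceptance)      # ~2.57 accepted draft tokens per step
tokens_per_step = 1 + expected_accepted  # plus the bonus token sampled by the target

tpot_speedup = 26.99 / 13.45             # ~2.01x per-output-token latency improvement
throughput_speedup = 72.98 / 36.75       # ~1.99x output-token throughput improvement

print(f"expected accepted draft tokens per step: {expected_accepted:.2f}")
print(f"generated tokens per target step:        {tokens_per_step:.2f}")
print(f"TPOT speedup:                            {tpot_speedup:.2f}x")
print(f"throughput speedup:                      {throughput_speedup:.2f}x")
```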
----------------------------------------------------------------------
Run on a Qwen3 model:

```bash
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1

vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811 \
    --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```
cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
Diff excerpt (`NPUModelRunner`):

```diff
@@ -108,6 +108,7 @@ from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights
 from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
 from vllm_ascend.sample.sampler import AscendSampler
 from vllm_ascend.spec_decode import get_spec_decode_method
+from vllm_ascend.spec_decode.draft_proposer import AscendDraftModelProposer
 from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
 from vllm_ascend.spec_decode.medusa_proposer import AscendMedusaProposer
 from vllm_ascend.spec_decode.ngram_proposer import AscendNgramProposer
@@ -406,7 +407,12 @@ class NPUModelRunner(GPUModelRunner):
     def _set_up_drafter(self):
         # Set up speculative decoding.
         self.drafter: (
-            AscendNgramProposer | AscendEagleProposer | AscendSuffixDecodingProposer | AscendMedusaProposer | None
+            AscendNgramProposer
+            | AscendEagleProposer
+            | AscendDraftModelProposer
+            | AscendSuffixDecodingProposer
+            | AscendMedusaProposer
+            | None
         ) = None
         self.actual_seq_lengths_q: list[int] = []
         self.decode_token_per_req = 1
@@ -971,7 +977,7 @@ class NPUModelRunner(GPUModelRunner):
             draft_token_ids = self.drafter.propose(
                 valid_sampled_token_ids, sampling_metadata, spec_decode_metadata, sample_hidden_states
             )
-        elif self.speculative_config.use_eagle():
+        elif self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model():
            common_attn_metadata = spec_decode_common_attn_metadata
            sampled_token_ids = valid_sampled_token_ids
 
@@ -1018,6 +1024,8 @@ class NPUModelRunner(GPUModelRunner):
         long_seq_metadata = None  # type: ignore
         num_prefill_reqs = 0
         num_decode_reqs = 0
+
+        num_rejected_tokens_gpu = None
         if spec_decode_metadata is None:
             # update pcp related params
             if self.pcp_size > 1:
@@ -1053,8 +1061,10 @@ class NPUModelRunner(GPUModelRunner):
             )
         else:
             assert self.drafter is not None
-            common_attn_metadata, token_indices, token_indices_to_sample = self.drafter.prepare_inputs_padded(
-                common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
-            )
+            common_attn_metadata, token_indices, token_indices_to_sample, num_rejected_tokens_gpu = (
+                self.drafter.prepare_inputs_padded(
+                    common_attn_metadata, spec_decode_metadata, valid_sampled_tokens_count
+                )
+            )
             if self.pcp_size > 1:
                 target_token_ids = input_ids_pcp_full[token_indices]
@@ -1075,7 +1085,7 @@ class NPUModelRunner(GPUModelRunner):
                 target_positions=target_positions,
                 target_hidden_states=target_hidden_states,
                 next_token_ids=next_token_ids,
-                last_token_indices=token_indices_to_sample,
+                token_indices_to_sample=token_indices_to_sample,
                 common_attn_metadata=common_attn_metadata,
                 sampling_metadata=sampling_metadata,
                 req_scheduled_tokens=req_scheduled_tokens,
@@ -1084,6 +1094,7 @@ class NPUModelRunner(GPUModelRunner):
                 num_decode_reqs=num_decode_reqs,
                 scheduler_output=scheduler_output,
                 num_scheduled_tokens=num_scheduled_tokens,
+                num_rejected_tokens_gpu=num_rejected_tokens_gpu,
             )
         else:
             raise ValueError(f"Unknown speculative decoding method: {self.speculative_config.method}")
@@ -1516,16 +1527,16 @@ class NPUModelRunner(GPUModelRunner):
 
         with record_function_or_nullcontext("draft_token"):
             if self.speculative_config:
-                use_padded_batch_for_eagle = (
+                use_padded_batch = (
                     self.speculative_config
-                    and self.speculative_config.use_eagle()
+                    and (self.speculative_config.use_eagle() or self.speculative_config.uses_draft_model())
                     and not self.speculative_config.disable_padded_drafter_batch
                 )
-            if use_padded_batch_for_eagle:
+            if use_padded_batch:
                 # EAGLE speculative decoding can use the GPU sampled tokens
                 # as inputs, and does not need to wait for bookkeeping to finish.
                 propose_draft_token_ids(sampler_output.sampled_token_ids)
-            if self.speculative_config and not use_padded_batch_for_eagle:
+            if self.speculative_config and not use_padded_batch:
                 # ngram and other speculative decoding methods use the sampled
                 # tokens on the CPU, so they are run after bookkeeping.
                 propose_draft_token_ids(valid_sampled_token_ids)
@@ -2165,7 +2176,7 @@ class NPUModelRunner(GPUModelRunner):
             if kv_cache_gid > 0:
                 cm.block_table_tensor, cm.slot_mapping = _get_block_table_and_slot_mapping(kv_cache_gid)
             if self.speculative_config and spec_decode_common_attn_metadata is None:
-                if isinstance(self.drafter, AscendEagleProposer):
+                if isinstance(self.drafter, AscendEagleProposer | AscendDraftModelProposer):
                     if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
                         spec_decode_common_attn_metadata = cm
             else:
```