[feat][spec decode]Unified draft parallel (#6766)
### What this PR does / why we need it?
Implement unified parallel draft speculative decoding in vLLM Ascend, which can
simultaneously support parallel speculative inference schemes such as PARD and
P-Eagle. See
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078 for reference.
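For a quick local check without starting a server, the same speculative config can also be exercised offline. A minimal sketch, assuming the Python `LLM` entrypoint accepts the same `speculative_config` dict as the `--speculative-config` flag used in the serve commands below (model paths reuse the ones from those scripts):

```python
# Hedged offline sketch; not part of this PR's test scripts.
from vllm import LLM, SamplingParams

llm = LLM(
    model="/model/Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    max_model_len=4096,
    enable_prefix_caching=False,
    speculative_config={
        "model": "/model/PARD-Llama-3.2-1B",
        "method": "draft_model",
        "num_speculative_tokens": 8,
        "parallel_drafting": True,  # flag introduced by this PR
    },
)
outputs = llm.generate(
    ["Explain speculative decoding in one sentence."],
    SamplingParams(temperature=0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```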
### How was this patch tested?
Run with the parallel drafting script:

```shell
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811 \
    --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```
Base script (without speculative decoding):

```shell
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811
```
Benchmark script:

```shell
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234
```
Test results:

| Configuration | TTFT | TPOT | Output token throughput |
|---|---|---|---|
| Base (without spec decode) | 79.46 ms | 26.99 ms | 36.75 tok/s |
| This PR (with parallel drafting) | 72.24 ms | 13.45 ms | 72.98 tok/s |

Per-position acceptance rates (positions 0 to 7): 79.48%, 56.93%, 40%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%.
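The per-position acceptance rates give a quick sanity check on the TPOT gain. A small sketch, assuming each reported rate is the fraction of decode steps in which the draft token at that position was accepted (so later positions are already conditional on earlier acceptances):

```python
# Expected tokens emitted per target forward pass is roughly
# 1 (bonus/corrected token) + sum of per-position acceptance rates.
accept = [0.7948, 0.5693, 0.40, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]
expected_tokens_per_step = 1 + sum(accept)
print(f"~{expected_tokens_per_step:.2f} tokens per target step")  # ~3.57

# The measured TPOT improvement (26.99 ms -> 13.45 ms, ~2.0x) is below this
# upper bound, as expected, since each step also pays for the draft model's
# forward passes.
```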
----------------------------------------------------------------------
Script to run on a Qwen3 model:

```shell
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1
vllm serve $target \
    --tensor-parallel-size 1 \
    --max-model-len 4096 \
    --no-enable-prefix-caching \
    --port 8811 \
    --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```
cc @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main: 9562912cea
---------
Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
New file: `vllm_ascend/spec_decode/draft_proposer.py` (+71 lines)
```python
import torch
import torch.nn as nn
from typing_extensions import override

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.v1.spec_decode.utils import create_vllm_config_for_draft_model

from vllm_ascend.spec_decode.eagle_proposer import SpecDecodeBaseProposer

logger = init_logger(__name__)


class AscendDraftModelProposer(SpecDecodeBaseProposer):

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        runner=None,
    ):
        super().__init__(
            vllm_config=vllm_config,
            device=device,
            pass_hidden_states_to_model=False,
            runner=runner,
        )
        self._raise_if_vocab_size_mismatch()
        self._raise_if_draft_tp_mismatch()

    def _raise_if_vocab_size_mismatch(self):
        self.speculative_config.verify_equal_vocab_size_if_draft_model()

    def _raise_if_draft_tp_mismatch(self):
        # Note(Tomas Ruiz) If we run the target model with TP > 1 and
        # the draft model with TP = 1, then the different TP ranks collide.
        # Specifically when all ranks compile the draft model on rank 0
        # (because TP=1), then the torch compile cache is overwritten and corrupted.
        # We need a mechanism like this: https://github.com/vllm-project/vllm/pull/5414
        # To prevent this error, we assert that both TP sizes must be the same.
        spec_cfg = self.speculative_config
        tgt_tp = spec_cfg.target_parallel_config.tensor_parallel_size
        draft_tp = spec_cfg.draft_parallel_config.tensor_parallel_size
        if draft_tp != tgt_tp:
            raise ValueError(
                f"Currently, 'draft_tensor_parallel_size' and 'tensor_parallel_size' "
                f"must be the same. Got {draft_tp} and {tgt_tp}. "
                "Please pass 'draft_tensor_parallel_size' in the speculative_config."
            )

    def _get_model(self) -> nn.Module:
        # Draft models may be quantized or on different parallelism,
        # so we load them with a modified vllm config
        from vllm.compilation.backends import set_model_tag

        temp_vllm_config = create_vllm_config_for_draft_model(self.vllm_config)
        with set_model_tag("draft_model"):
            model = get_model(
                vllm_config=temp_vllm_config,
                prefix="draft_model",
            )
        return model

    @override
    def _maybe_share_embeddings(self, target_language_model: nn.Module) -> None:
        # Draft models don't share embeddings with the target model
        pass

    @override
    def _maybe_share_lm_head(self, target_language_model: nn.Module) -> None:
        # Draft models don't share lm_head with the target model
        pass
```
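As the `_raise_if_draft_tp_mismatch` guard above implies, when the target model is served with TP > 1 the draft model must be configured with the same tensor parallel size. A hedged sketch, assuming vLLM's standard `draft_tensor_parallel_size` key in the speculative config (paths and TP size are illustrative, not from this PR's tests):

```python
# Hypothetical TP > 1 setup: the draft TP must match the target TP, otherwise
# AscendDraftModelProposer raises the ValueError shown above.
from vllm import LLM

llm = LLM(
    model="/model/Llama-3.1-8B-Instruct",
    tensor_parallel_size=2,
    speculative_config={
        "model": "/model/PARD-Llama-3.2-1B",
        "method": "draft_model",
        "num_speculative_tokens": 8,
        "parallel_drafting": True,
        "draft_tensor_parallel_size": 2,  # must equal tensor_parallel_size
    },
)
```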