From fb0d6dd17552f1d77b68b5b52331c0911a872e73 Mon Sep 17 00:00:00 2001
From: drslark <96540755+drslark@users.noreply.github.com>
Date: Thu, 12 Mar 2026 14:51:12 +0800
Subject: [PATCH] [main][bugfix] Fixed the problem of speculative decoding in
 FULL mode (#7148)

### What this PR does / why we need it?
Fixed the error in speculative decoding in FULL mode when `num_spec + 1` is not in `cudagraph_capture_sizes`.
Now speculative decoding can run in FULL mode, with the drafter running in eager mode.

It depends on https://github.com/vllm-project/vllm-ascend/pull/7144.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
The test code is shown below:

```python
from vllm import LLM, SamplingParams

prompts = [
    "1.Who are you?",
    "2. Who are you?",
]

sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=200)

llm = LLM(
    model="/home/some-model/Meta-Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    max_num_seqs=32,
    # enforce_eager=True,
    disable_log_stats=False,
    distributed_executor_backend="mp",
    gpu_memory_utilization=0.7,
    async_scheduling=True,
    speculative_config={
        "enforce_eager": True,
        "model": "/home/some-model/EAGLE3-LLaMA3.1-Instruct-8B",
        "disable_padded_drafter_batch": False,
        "method": "eagle3",
        "num_speculative_tokens": 2,
    },
    compilation_config={
        "cudagraph_mode": "FULL",
        "cudagraph_num_of_warmups": 1,
    },
    max_model_len=4096,
    enable_prefix_caching=False,
)

outputs = llm.generate(prompts, sampling_params)
```

The result before the fix:

```text
  File "/vllm-workspace/vllm/vllm/v1/cudagraph_dispatcher.py", line 140, in _create_padded_batch_descriptor
    assert num_tokens_padded % uniform_decode_query_len == 0
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```

The result after the fix:

```text
--------------------------------------------------
total_num_output_tokens: 400
num_drafts: 249
num_draft_tokens: 498
num_accepted_tokens: 149
mean acceptance length: 1.60
--------------------------------------------------
acceptance at token 0: 0.43
acceptance at token 1: 0.17
```

- vLLM version: v0.16.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

Signed-off-by: drslark
---
 vllm_ascend/patch/__init__.py               | 14 ++++++++
 vllm_ascend/patch/worker/__init__.py        |  1 +
 vllm_ascend/patch/worker/patch_cudagraph.py | 38 +++++++++++++++++++++
 vllm_ascend/spec_decode/eagle_proposer.py   |  6 ++--
 4 files changed, 57 insertions(+), 2 deletions(-)
 create mode 100644 vllm_ascend/patch/worker/patch_cudagraph.py

diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py
index d0d89846..b4d4148d 100644
--- a/vllm_ascend/patch/__init__.py
+++ b/vllm_ascend/patch/__init__.py
@@ -438,3 +438,17 @@
 # patch Qwen3_5GatedDeltaNet._forward_core to use triton ops like `fused_recurrent_gated_delta_rule`.
 # Future Plan:
 # Remove this patch when all ops in _forward_core support both Qwen3_5 and Qwen3Next.
+#
+# ** 20. File: worker/patch_cudagraph.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.v1.cudagraph_dispatcher.CudagraphDispatcher._create_padded_batch_descriptor`
+#    Why:
+#    vLLM's FULL cudagraph mode raises an AssertionError when the padded batch size is not a
+#    multiple of the uniform decode query length; with this patch, FULL mode can be enabled.
+#    How:
+#    Dynamically replace the `_create_padded_batch_descriptor` function at runtime
+#    and change the condition of the `if` statement.
+#    Related PR (if no, explain why):
+#    https://github.com/vllm-project/vllm/pull/34880
+#    Future Plan:
+#    Remove this patch when vLLM merges the PR.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 2493ecb2..1664c84a 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -43,3 +43,4 @@ import vllm_ascend.patch.worker.patch_routed_experts_capturer  # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25  # noqa
 import vllm_ascend.patch.worker.patch_draft_quarot  # noqa
+import vllm_ascend.patch.worker.patch_cudagraph  # noqa
diff --git a/vllm_ascend/patch/worker/patch_cudagraph.py b/vllm_ascend/patch/worker/patch_cudagraph.py
new file mode 100644
index 00000000..bff74fdd
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_cudagraph.py
@@ -0,0 +1,38 @@
+from vllm.config import CUDAGraphMode
+from vllm.forward_context import BatchDescriptor
+from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
+
+
+def _create_padded_batch_descriptor(
+    self,
+    num_tokens: int,
+    uniform_decode: bool,
+    has_lora: bool,
+    num_active_loras: int = 0,
+) -> BatchDescriptor:
+    max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+    uniform_decode_query_len = self.uniform_decode_query_len
+    num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
+
+    # FULL mode should not be treated as uniform decode
+    if (
+        uniform_decode
+        and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL)
+        and self.cudagraph_mode != CUDAGraphMode.FULL
+    ):
+        num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
+        assert num_tokens_padded % uniform_decode_query_len == 0
+    else:
+        uniform_decode = False
+        num_reqs = min(num_tokens_padded, max_num_seqs)
+
+    return BatchDescriptor(
+        num_tokens=num_tokens_padded,
+        num_reqs=num_reqs,
+        uniform=uniform_decode,
+        has_lora=has_lora,
+        num_active_loras=num_active_loras,
+    )
+
+
+CudagraphDispatcher._create_padded_batch_descriptor = _create_padded_batch_descriptor
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index b821524a..b8e0c2f9 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -89,7 +89,7 @@ class AscendEagleProposer(EagleProposer):
         self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
         self.decode_threshold = 1 + self.num_speculative_tokens
 
-        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 1, dtype=torch.int32)
+        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32)
         self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32)
 
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
@@ -362,7 +362,9 @@ class AscendEagleProposer(EagleProposer):
 
         model_positions = self._get_positions(num_tokens)
 
-        batch_size = num_tokens // (self.num_speculative_tokens + 1)  # if not is_profile else self.runner.max_num_reqs
+        batch_size = max(
+            num_tokens // (self.num_speculative_tokens + 1), 1
+        )  # if not is_profile else self.runner.max_num_reqs
 
         if is_profile:
             batch_size = min(batch_size, self.runner.max_num_reqs)
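
Note for reviewers: below is a minimal sketch (not part of the patch) of why the upstream assertion fires. It assumes power-of-two `cudagraph_capture_sizes` and a "round up to the next captured size" padding rule, with `num_speculative_tokens = 2` as in the test above; the `pad` helper is hypothetical and only illustrates the divisibility check that the patched `_create_padded_batch_descriptor` now skips for pure FULL mode.

```python
# Hypothetical illustration only; the capture sizes and the pad() helper are
# assumptions for this sketch, not vLLM's actual padding code.
capture_sizes = [1, 2, 4, 8, 16]      # assumed cudagraph_capture_sizes
uniform_decode_query_len = 2 + 1      # num_speculative_tokens + 1 = 3, not a captured size


def pad(num_tokens: int) -> int:
    # Round the batch up to the smallest captured size that can hold it.
    return next(size for size in capture_sizes if size >= num_tokens)


num_tokens = 2 * uniform_decode_query_len    # two uniform-decode requests -> 6 tokens
num_tokens_padded = pad(num_tokens)          # 6 pads up to 8

# Before the fix: the uniform-decode branch asserts 8 % 3 == 0 and fails.
print(num_tokens_padded % uniform_decode_query_len == 0)  # False

# After the fix: pure FULL mode takes the non-uniform branch instead, so
# num_reqs = min(num_tokens_padded, max_num_seqs) and no assertion fires.
```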