[main][bugfix] Fixed the problem of speculative decoding in FULL mode (#7148)

### What this PR does / why we need it?

Fixed an error in speculative decoding under FULL cudagraph mode that occurred when `num_spec + 1` is not in `cudagraph_capture_sizes`.

Speculative decoding now runs in FULL mode, with the drafter kept in eager mode.

This PR depends on https://github.com/vllm-project/vllm-ascend/pull/7144.

### Does this PR introduce _any_ user-facing change?

N/A

### How was this patch tested?

Test code is shown below:

```python
from vllm import LLM, SamplingParams

prompts = [
    "1.Who are you?",
    "2. Who are you?",
]

sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=200)
llm = LLM(
    model="/home/some-model/Meta-Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    max_num_seqs=32,
    # enforce_eager=True,
    disable_log_stats=False,
    distributed_executor_backend="mp",
    gpu_memory_utilization=0.7,
    async_scheduling=True,

    speculative_config={
        "enforce_eager": True,
        "model": "/home/some-model/EAGLE3-LLaMA3.1-Instruct-8B",
        "disable_padded_drafter_batch": False,
        "method": "eagle3",
        "num_speculative_tokens": 2,
    },
    
    compilation_config={
        "cudagraph_mode": "FULL",
        "cudagraph_num_of_warmups": 1,
    },

    max_model_len=4096, 
    enable_prefix_caching=False,
)

outputs = llm.generate(prompts, sampling_params)
```
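Note the key settings: the target model uses `"cudagraph_mode": "FULL"`, the drafter is forced to eager via `"enforce_eager": True` in `speculative_config`, and `num_speculative_tokens: 2` makes `num_spec + 1 = 3`, which is typically absent from the power-of-two `cudagraph_capture_sizes` and so triggers the original assertion.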

The result before the fix:

```text
   File "/vllm-workspace/vllm/vllm/v1/cudagraph_dispatcher.py", line 140, in _create_padded_batch_descriptor
     assert num_tokens_padded % uniform_decode_query_len == 0
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 AssertionError
```

The result after the fix:

```text
--------------------------------------------------
total_num_output_tokens: 400
num_drafts: 249
num_draft_tokens: 498
num_accepted_tokens: 149
mean acceptance length: 1.60
--------------------------------------------------
acceptance at token 0: 0.43
acceptance at token 1: 0.17
```
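As a sanity check, these counters are self-consistent: mean acceptance length = 1 + num_accepted_tokens / num_drafts = 1 + 149 / 249 ≈ 1.60, and the per-position rates sum to 0.43 + 0.17 = 0.60 ≈ 149 / 249.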

- vLLM version: v0.16.0
- vLLM main: 4034c3d32e

Signed-off-by: drslark <slarksblood@qq.com>

4 changed files with 57 additions and 2 deletions


```diff
@@ -438,3 +438,17 @@
 # patch Qwen3_5GatedDeltaNet._forward_core to use triton ops like `fused_recurrent_gated_delta_rule`.
 # Future Plan:
 # Remove this patch when all ops in _forward_core support both Qwen3_5 and Qwen3Next.
+#
+# ** 20. File: worker/patch_cudagraph.py**
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.v1.cudagraph_dispatcher.CudagraphDispatcher._create_padded_batch_descriptor`
+#    Why:
+#       vLLM's FULL cudagraph mode raises an error; we use a patch to avoid it.
+#       With the patch, FULL mode can be enabled.
+#    How:
+#       Dynamically replace the `_create_padded_batch_descriptor` function at runtime
+#       and change the `if` condition.
+#    Related PR (if no, explain why):
+#       https://github.com/vllm-project/vllm/pull/34880
+#    Future Plan:
+#       Remove this patch when vLLM merges the PR.
```


```diff
@@ -43,3 +43,4 @@ import vllm_ascend.patch.worker.patch_routed_experts_capturer # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
 import vllm_ascend.patch.worker.patch_draft_quarot # noqa
+import vllm_ascend.patch.worker.patch_cudagraph # noqa
```
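Importing the module is all that is needed to apply the patch: the new `patch_cudagraph` file below replaces the dispatcher method at import time.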


```diff
@@ -0,0 +1,38 @@
+from vllm.config import CUDAGraphMode
+from vllm.forward_context import BatchDescriptor
+from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
+
+
+def _create_padded_batch_descriptor(
+    self,
+    num_tokens: int,
+    uniform_decode: bool,
+    has_lora: bool,
+    num_active_loras: int = 0,
+) -> BatchDescriptor:
+    max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
+    uniform_decode_query_len = self.uniform_decode_query_len
+    num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]
+    # FULL mode should not be treated as uniform decode
+    if (
+        uniform_decode
+        and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL)
+        and self.cudagraph_mode != CUDAGraphMode.FULL
+    ):
+        num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
+        assert num_tokens_padded % uniform_decode_query_len == 0
+    else:
+        uniform_decode = False
+        num_reqs = min(num_tokens_padded, max_num_seqs)
+    return BatchDescriptor(
+        num_tokens=num_tokens_padded,
+        num_reqs=num_reqs,
+        uniform=uniform_decode,
+        has_lora=has_lora,
+        num_active_loras=num_active_loras,
+    )
+
+
+CudagraphDispatcher._create_padded_batch_descriptor = _create_padded_batch_descriptor
```
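For intuition, here is a minimal standalone sketch of why the unpatched assertion fires; the capture sizes below are assumed for illustration and are not taken from this PR:

```python
# Hypothetical repro of the divisibility check, not vllm-ascend code.
# With num_speculative_tokens = 2, a uniform decode step carries
# uniform_decode_query_len = 3 tokens per request, while padded graph
# sizes are typically powers of two, so the padded token count is
# usually not a multiple of 3.
num_speculative_tokens = 2
uniform_decode_query_len = num_speculative_tokens + 1  # 3
cudagraph_capture_sizes = [1, 2, 4, 8, 16]  # assumed sizes; 3 is absent

for num_tokens in (3, 6, 9):  # 1, 2, 3 uniform-decode requests
    # pad up to the next captured size, analogous to _bs_to_padded_graph_size
    num_tokens_padded = next(s for s in cudagraph_capture_sizes if s >= num_tokens)
    ok = num_tokens_padded % uniform_decode_query_len == 0
    print(f"{num_tokens} tokens -> padded to {num_tokens_padded}: "
          f"{'ok' if ok else 'AssertionError in the unpatched dispatcher'}")
```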


```diff
@@ -89,7 +89,7 @@ class AscendEagleProposer(EagleProposer):
         self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
         self.decode_threshold = 1 + self.num_speculative_tokens
-        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 1, dtype=torch.int32)
+        self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32)
         self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32)
         self.attn_mask_builder = AttentionMaskBuilder(self.device)
@@ -362,7 +362,9 @@ class AscendEagleProposer(EagleProposer):
         model_positions = self._get_positions(num_tokens)
-        batch_size = num_tokens // (self.num_speculative_tokens + 1) # if not is_profile else self.runner.max_num_reqs
+        batch_size = max(
+            num_tokens // (self.num_speculative_tokens + 1), 1
+        ) # if not is_profile else self.runner.max_num_reqs
         if is_profile:
             batch_size = min(batch_size, self.runner.max_num_reqs)
```
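These drafter-side changes keep the eager drafter consistent with padded batches: `query_start_loc` gains one extra slot (presumably to hold an entry for the padding request), and the `max(..., 1)` guard prevents `batch_size` from becoming zero when `num_tokens` is smaller than `num_speculative_tokens + 1`.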