### What this PR does / why we need it?
Fixes an `AssertionError` raised during speculative decoding in FULL
cudagraph mode when `num_spec + 1` is not in `cudagraph_capture_sizes`.
With this fix, speculative decoding can run in FULL mode, with the drafter
running in eager mode.
This PR depends on https://github.com/vllm-project/vllm-ascend/pull/7144.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested with the following script:
```python
# Reproduction script: speculative decoding (EAGLE3 drafter) with the target
# model in FULL cudagraph mode.
from vllm import LLM, SamplingParams

prompts = [
    "1.Who are you?",
    "2. Who are you?",
]
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=200)
llm = LLM(
    model="/home/some-model/Meta-Llama-3.1-8B-Instruct",
    tensor_parallel_size=1,
    max_num_seqs=32,
    # enforce_eager=True,
    disable_log_stats=False,
    distributed_executor_backend="mp",
    gpu_memory_utilization=0.7,
    async_scheduling=True,
    speculative_config={
        # The drafter runs eagerly; only the target model uses FULL
        # cudagraphs (see compilation_config below).
        "enforce_eager": True,
        "model": "/home/some-model/EAGLE3-LLaMA3.1-Instruct-8B",
        "disable_padded_drafter_batch": False,
        "method": "eagle3",
        "num_speculative_tokens": 2,
    },
    compilation_config={
        "cudagraph_mode": "FULL",
        "cudagraph_num_of_warmups": 1,
    },
    max_model_len=4096,
    enable_prefix_caching=False,
)
outputs = llm.generate(prompts, sampling_params)
```
Output before this fix:
```text
File "/vllm-workspace/vllm/vllm/v1/cudagraph_dispatcher.py", line 140, in _create_padded_batch_descriptor
assert num_tokens_padded % uniform_decode_query_len == 0
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```
Output after this fix:
```text
--------------------------------------------------
total_num_output_tokens: 400
num_drafts: 249
num_draft_tokens: 498
num_accepted_tokens: 149
mean acceptance length: 1.60
--------------------------------------------------
acceptance at token 0: 0.43
acceptance at token 1: 0.17
```
- vLLM version: v0.16.0
- vLLM main: 4034c3d32e
Signed-off-by: drslark <slarksblood@qq.com>
39 lines · 1.2 KiB · Python
from vllm.config import CUDAGraphMode
|
|
from vllm.forward_context import BatchDescriptor
|
|
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
|
|
|
|
|
|
def _create_padded_batch_descriptor(
    self,
    num_tokens: int,
    uniform_decode: bool,
    has_lora: bool,
    num_active_loras: int = 0,
) -> BatchDescriptor:
    """Build the padded ``BatchDescriptor`` used to dispatch a cudagraph.

    Installed over ``CudagraphDispatcher._create_padded_batch_descriptor`` at
    the bottom of this module. In pure ``CUDAGraphMode.FULL`` the batch is
    always described as non-uniform. Speculative decoding can produce a padded
    size that is not a multiple of ``uniform_decode_query_len``, for example
    when ``num_spec + 1`` is not in ``cudagraph_capture_sizes``. Treating such
    a batch as uniform decode would fail the divisibility check below.

    Args:
        num_tokens: Unpadded token count; looked up in
            ``self._bs_to_padded_graph_size`` to get the padded size.
        uniform_decode: Whether the caller considers the batch a uniform
            decode batch. It is forced to ``False`` unless the uniform path
            below is taken.
        has_lora: Forwarded unchanged to the descriptor.
        num_active_loras: Forwarded unchanged to the descriptor.

    Returns:
        A ``BatchDescriptor`` with the following fields:

        - ``num_tokens``: the padded token count.
        - ``num_reqs``: capped at ``scheduler_config.max_num_seqs``.
        - ``uniform``: set only on the uniform-decode path.

    Raises:
        AssertionError: On the uniform-decode path, if the padded token count
            is not a multiple of ``self.uniform_decode_query_len``.
    """
    max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
    uniform_decode_query_len = self.uniform_decode_query_len
    num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

    # The uniform-decode path applies only to modes for which
    # has_mode(FULL) holds but that are not CUDAGraphMode.FULL itself.
    # Pure FULL mode is deliberately excluded (see docstring).
    if (
        uniform_decode
        and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL)
        and self.cudagraph_mode != CUDAGraphMode.FULL
    ):
        # Check the invariant before deriving the request count from it, and
        # report the sizes involved so a failure is diagnosable.
        assert num_tokens_padded % uniform_decode_query_len == 0, (
            f"padded token count {num_tokens_padded} (from {num_tokens}) is "
            f"not a multiple of uniform_decode_query_len "
            f"{uniform_decode_query_len}"
        )
        num_reqs = min(num_tokens_padded // uniform_decode_query_len, max_num_seqs)
    else:
        # Every other batch is described as non-uniform, one token per request.
        uniform_decode = False
        num_reqs = min(num_tokens_padded, max_num_seqs)

    return BatchDescriptor(
        num_tokens=num_tokens_padded,
        num_reqs=num_reqs,
        uniform=uniform_decode,
        has_lora=has_lora,
        num_active_loras=num_active_loras,
    )


# Monkey-patch: replace the method on the CudagraphDispatcher class itself,
# so the override takes effect as soon as this module is imported.
CudagraphDispatcher._create_padded_batch_descriptor = _create_padded_batch_descriptor