[main][bugfix] Fixed the problem of speculative decoding in FULL mode (#7148)
### What this PR does / why we need it?
Fixed the error of speculative decoding in FULL mode when `num_spec + 1`
not in `cudagraph_capture_sizes`.
Now, we can run speculative decoding in FULL mode, but with drafter as
eager.
It depends on https://github.com/vllm-project/vllm-ascend/pull/7144 .
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Test code is shown as below:
```python
prompts = [
"1.Who are you?",
"2. Who are you?",
]
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, top_k=40, max_tokens=200)
llm = LLM(
model="/home/some-model/Meta-Llama-3.1-8B-Instruct",
tensor_parallel_size=1,
max_num_seqs=32,
# enforce_eager=True,
disable_log_stats=False,
distributed_executor_backend="mp",
gpu_memory_utilization=0.7,
async_scheduling=True,
speculative_config={
"enforce_eager": True,
"model": "/home/some-model/EAGLE3-LLaMA3.1-Instruct-8B",
"disable_padded_drafter_batch": False,
"method": "eagle3",
"num_speculative_tokens": 2,
},
compilation_config={
"cudagraph_mode": "FULL",
"cudagraph_num_of_warmups": 1,
},
max_model_len=4096,
enable_prefix_caching=False,
)
outputs = llm.generate(prompts, sampling_params)
```
The result before:
```text
File "/vllm-workspace/vllm/vllm/v1/cudagraph_dispatcher.py", line 140, in _create_padded_batch_descriptor
assert num_tokens_padded % uniform_decode_query_len == 0
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError
```
The result after:
```text
--------------------------------------------------
total_num_output_tokens: 400
num_drafts: 249
num_draft_tokens: 498
num_accepted_tokens: 149
mean acceptance length: 1.60
--------------------------------------------------
acceptance at token 0: 0.43
acceptance at token 1: 0.17
```
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -89,7 +89,7 @@ class AscendEagleProposer(EagleProposer):
|
||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||
self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 1, dtype=torch.int32)
|
||||
self.query_start_loc = self.runner._make_buffer(self.runner.max_num_reqs + 2, dtype=torch.int32)
|
||||
self.arange_cpu = torch.arange(self.arange.shape[0], device="cpu", dtype=torch.int32)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
|
||||
@@ -362,7 +362,9 @@ class AscendEagleProposer(EagleProposer):
|
||||
|
||||
model_positions = self._get_positions(num_tokens)
|
||||
|
||||
batch_size = num_tokens // (self.num_speculative_tokens + 1) # if not is_profile else self.runner.max_num_reqs
|
||||
batch_size = max(
|
||||
num_tokens // (self.num_speculative_tokens + 1), 1
|
||||
) # if not is_profile else self.runner.max_num_reqs
|
||||
if is_profile:
|
||||
batch_size = min(batch_size, self.runner.max_num_reqs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user