[bugfix] fix capture shape in sp_eagle_fullgraph (#6846)

### What this PR does / why we need it?

This change was meant to be merged as part of #6536, but I accidentally restored a commit there. You can find the relevant discussion [here](https://github.com/vllm-project/vllm-ascend/pull/6536#issuecomment-3882883471).

Since `self.pass_config.enable_sp` is forcibly set to `False` in the [source code](f176443446/vllm/config/compilation.py (L1066)), that section no longer verifies whether the generated cudagraph shapes are multiples of both `uniform_decode_query_len` (`num_speculative_tokens + 1`) and `tensor_parallel_size`.

This PR enables the `num_speculative_tokens + 1` and `tensor_parallel_size` divisibility check upfront, so the runner no longer silently rounds up the `cudagraph_size` and then throws a cryptic error at the user.

A typical example of this cryptic error looks like:
```
ValueError: could not broadcast input array from shape (196,) into shape (14,)
```
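
For reference, here is a minimal sketch of the kind of upfront validation this PR enables. The helper name and error text are hypothetical, not the actual vllm-ascend implementation; the point is that every captured graph size must be divisible by both `uniform_decode_query_len` and `tensor_parallel_size`, and a bad size now fails loudly with the offending value named:

```python
# Hypothetical sketch of the upfront check; names and error text are
# illustrative, not the exact vllm-ascend code.
def validate_capture_sizes(capture_sizes: list[int],
                           num_speculative_tokens: int,
                           tensor_parallel_size: int) -> None:
    uniform_decode_query_len = num_speculative_tokens + 1
    for size in capture_sizes:
        if size % uniform_decode_query_len or size % tensor_parallel_size:
            raise ValueError(
                f"capture size {size} must be a multiple of both "
                f"uniform_decode_query_len ({uniform_decode_query_len}) and "
                f"tensor_parallel_size ({tensor_parallel_size})")


# e.g. with num_speculative_tokens=13 (query len 14) and TP=2,
# a size of 196 passes, while 100 would raise immediately.
validate_capture_sizes([196], 13, 2)
```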

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

All tests have passed.

- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: guozr <guozr1997@hotmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: drslark <slarksblood@qq.com>
Co-authored-by: guozr <guozr1997@hotmail.com>
```diff
@@ -2958,7 +2958,8 @@ class NPUModelRunner(GPUModelRunner):
         attention_backends: list[set[type[AttentionBackend]]],
         kv_cache_groups: list[KVCacheGroupSpec],
     ) -> None:
-        super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups)
+        with update_pass_config(self):
+            super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups)
         # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
         # we set the graph params right before initializing the keys.
@@ -3061,3 +3062,14 @@ def _replace_gpu_model_runner_function_wrapper(target_module_name):
         yield
     finally:
         setattr(target_module, "graph_capture", graph_capture)  # noqa: B010
+
+
+# TODO: remove it when flash_comm1 is removed
+@contextmanager
+def update_pass_config(model_runner):
+    try:
+        original_pass_config_sp = model_runner.compilation_config.pass_config.enable_sp
+        model_runner.compilation_config.pass_config.enable_sp = enable_sp(model_runner.vllm_config)
+        yield
+    finally:
+        model_runner.compilation_config.pass_config.enable_sp = original_pass_config_sp
```
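
The `update_pass_config` helper above follows the standard `contextlib.contextmanager` save/override/restore pattern. A self-contained sketch of that pattern, using a stand-in config object rather than the real `CompilationConfig`:

```python
from contextlib import contextmanager


class FakePassConfig:
    """Stand-in for the real pass config; upstream forces enable_sp=False."""
    enable_sp = False


@contextmanager
def override_enable_sp(pass_config, value):
    original = pass_config.enable_sp
    pass_config.enable_sp = value
    try:
        yield
    finally:
        # Restore the flag even if the wrapped code raises.
        pass_config.enable_sp = original


cfg = FakePassConfig()
with override_enable_sp(cfg, True):
    assert cfg.enable_sp is True   # the cudagraph-shape check sees SP enabled
assert cfg.enable_sp is False      # the forced-off value is restored afterwards
```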