[bugfix] fix capture shape in sp_eagle_fullgraph (#6846)
### What this PR does / why we need it?

This was meant to be merged in #6536, but I accidentally restored a commit. You can find the relevant discussion [here](https://github.com/vllm-project/vllm-ascend/pull/6536#issuecomment-3882883471).

Since `self.pass_config.enable_sp` is forcibly set to `False` in the [source code](f176443446/vllm/config/compilation.py (L1066)), that section no longer verifies whether the generated cudagraph shapes are multiples of both `uniform_decode_query_len` (`num_speculative_tokens + 1`) and `tensor_parallel_size`. This PR performs the `num_speculative_tokens + 1` and `tensor_parallel_size` checks upfront, so the `cudagraph_size` is no longer silently rounded up, which used to surface as a cryptic error for the user. A typical example of this cryptic error looks like:

```
ValueError: could not broadcast input array from shape (196,) into shape (14,)
```

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

All tests have passed.

- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Signed-off-by: guozr <guozr1997@hotmail.com>
Co-authored-by: lilinsiman <lilinsiman@gmail.com>
Co-authored-by: drslark <slarksblood@qq.com>
Co-authored-by: guozr <guozr1997@hotmail.com>
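For illustration, a minimal sketch of the kind of upfront validation described above (the function name, signature, and error messages are hypothetical, not the actual vllm-ascend code):

```python
# Hypothetical sketch of the upfront capture-shape check described above;
# validate_capture_sizes and its parameters are illustrative names only.

def validate_capture_sizes(
    capture_sizes: list[int],
    num_speculative_tokens: int,
    tensor_parallel_size: int,
) -> None:
    # Each decode step of a speculative model handles
    # num_speculative_tokens + 1 query tokens per request.
    uniform_decode_query_len = num_speculative_tokens + 1
    for size in capture_sizes:
        if size % uniform_decode_query_len != 0:
            raise ValueError(
                f"cudagraph capture size {size} must be a multiple of "
                f"num_speculative_tokens + 1 ({uniform_decode_query_len})")
        if size % tensor_parallel_size != 0:
            raise ValueError(
                f"cudagraph capture size {size} must be a multiple of "
                f"tensor_parallel_size ({tensor_parallel_size})")

# e.g. with num_speculative_tokens=13 (uniform_decode_query_len=14), a capture
# size of 196 passes, while 100 would be rejected with a clear message instead
# of failing later with the broadcast error shown above.
validate_capture_sizes([196], num_speculative_tokens=13, tensor_parallel_size=2)
```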
@@ -2958,7 +2958,8 @@ class NPUModelRunner(GPUModelRunner):
         attention_backends: list[set[type[AttentionBackend]]],
         kv_cache_groups: list[KVCacheGroupSpec],
     ) -> None:
-        super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups)
+        with update_pass_config(self):
+            super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups)
 
         # NOTE: Since aclgraph_batch_sizes cannot be determined until here,
         # we set the graph params right before initializing the keys.
@@ -3061,3 +3062,14 @@ def _replace_gpu_model_runner_function_wrapper(target_module_name):
         yield
     finally:
         setattr(target_module, "graph_capture", graph_capture) # noqa: B010
+
+
+# TODO: remove it when flash_comm1 is removed
+@contextmanager
+def update_pass_config(model_runner):
+    try:
+        original_pass_config_sp = model_runner.compilation_config.pass_config.enable_sp
+        model_runner.compilation_config.pass_config.enable_sp = enable_sp(model_runner.vllm_config)
+        yield
+    finally:
+        model_runner.compilation_config.pass_config.enable_sp = original_pass_config_sp
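For reference, `update_pass_config` above follows the standard save-and-restore context-manager idiom. A self-contained sketch of the same pattern, with a hypothetical `PassConfig` standing in for `compilation_config.pass_config`:

```python
from contextlib import contextmanager
from dataclasses import dataclass

@dataclass
class PassConfig:
    """Hypothetical stand-in for compilation_config.pass_config."""
    enable_sp: bool = False

@contextmanager
def override_enable_sp(pass_config: PassConfig, value: bool):
    # Save the original flag, apply the override for the duration of the
    # block, and restore it on exit even if the body raises; this is the
    # same shape as update_pass_config in the diff above.
    original = pass_config.enable_sp
    pass_config.enable_sp = value
    try:
        yield
    finally:
        pass_config.enable_sp = original

cfg = PassConfig()
with override_enable_sp(cfg, True):
    assert cfg.enable_sp is True   # override visible inside the block
assert cfg.enable_sp is False      # original value restored afterwards
```

In the diff, wrapping the `super()._check_and_update_cudagraph_mode(...)` call this way lets the parent check run with the `enable_sp` value computed from `vllm_config` instead of the forced `False`, and the original value is restored once the call returns.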