diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 4a6bf06f..206f40df 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -3298,12 +3298,19 @@ class NPUModelRunner(GPUModelRunner): with update_pass_config(self): super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups) + capture_descs = self.cudagraph_dispatcher.get_capture_descs() + capture_sizes = sorted({ + desc.num_tokens + for _, descs in capture_descs + for desc in descs + }) + # NOTE: Since aclgraph_batch_sizes cannot be determined until here, # we set the graph params right before initializing the keys. if self.use_aclgraph: - set_graph_params(self.cudagraph_batch_sizes) + set_graph_params(capture_sizes) if self.speculative_config: - set_draft_graph_params(self.cudagraph_batch_sizes) + set_draft_graph_params(capture_sizes) def capture_model(self) -> None: gpu_model_runner_cls = next((cls for cls in self.__class__.__mro__ if cls.__name__ == "GPUModelRunner"), None)