diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ead030c2..ee5dd81b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -2958,7 +2958,8 @@ class NPUModelRunner(GPUModelRunner): attention_backends: list[set[type[AttentionBackend]]], kv_cache_groups: list[KVCacheGroupSpec], ) -> None: - super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups) + with update_pass_config(self): + super()._check_and_update_cudagraph_mode(attention_backends, kv_cache_groups) # NOTE: Since aclgraph_batch_sizes cannot be determined until here, # we set the graph params right before initializing the keys. @@ -3061,3 +3062,14 @@ def _replace_gpu_model_runner_function_wrapper(target_module_name): yield finally: setattr(target_module, "graph_capture", graph_capture) # noqa: B010 + + +# TODO: remove it when flash_comm1 is removed +@contextmanager +def update_pass_config(model_runner): + try: + original_pass_config_sp = model_runner.compilation_config.pass_config.enable_sp + model_runner.compilation_config.pass_config.enable_sp = enable_sp(model_runner.vllm_config) + yield + finally: + model_runner.compilation_config.pass_config.enable_sp = original_pass_config_sp