diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 324b693a..0034a653 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -189,12 +189,6 @@ class NPUPlatform(Platform): return min(max_num_seqs * decode_query_len, 512) - @classmethod - def apply_config_platform_defaults(cls, vllm_config: VllmConfig) -> None: - default_max_cg_capture_size = cls._get_default_max_cudagraph_capture_size(vllm_config) - if default_max_cg_capture_size is not None: - vllm_config.compilation_config.max_cudagraph_capture_size = default_max_cg_capture_size - @classmethod def get_device_capability(cls, device_id: int = 0): return None @@ -209,6 +203,10 @@ class NPUPlatform(Platform): pass_config.sp_min_token_num = get_sp_min_token_num(vllm_config) logger.info(f"set sp_min_token_num to {pass_config.sp_min_token_num}") + default_max_cg_capture_size = cls._get_default_max_cudagraph_capture_size(vllm_config) + if default_max_cg_capture_size is not None: + vllm_config.compilation_config.max_cudagraph_capture_size = default_max_cg_capture_size + @classmethod def get_device_name(cls, device_id: int = 0) -> str: return torch.npu.get_device_name(device_id)