diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index ffb7018..21e4e94 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -754,7 +754,7 @@ class TestNPUPlatform(TestBase): self.platform.check_and_update_config(VllmConfig) self.assertTrue( "PIECEWISE compilation enabled on NPU. use_inductor not supported - " - "using only ACL Graph mode" in cm.output[1]) + "using only ACL Graph mode" in cm.output[0]) self.assertEqual( VllmConfig.compilation_config.level, CompilationLevel.PIECEWISE, diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 449c3b0..9da4aa6 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -129,36 +129,7 @@ class NPUPlatform(Platform): model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config - scheduler_config = vllm_config.scheduler_config ascend_scheduler_config = ascend_config.ascend_scheduler_config - structured_outputs_config = vllm_config.structured_outputs_config - - if (model_config is not None and not model_config.use_mla - and not scheduler_config.async_scheduling - and model_config.runner_type != "pooling"): - logger.info( - "Non-MLA LLMs forcibly disable the chunked prefill feature," - "as the performance of operators supporting this feature " - "functionality is currently suboptimal.") - if not model_config.is_multimodal_model and \ - structured_outputs_config.backend == "auto" and \ - not getattr(scheduler_config, "scheduler_delay_factor", 0) > 0 and \ - not scheduler_config.send_delta_data and \ - scheduler_config.policy == "fcfs": - ascend_scheduler_config.enabled = True - chunked_prefill_enabled_in_ascend_scheduler = getattr( - ascend_scheduler_config, "enable_chunked_prefill", False) - if chunked_prefill_enabled_in_ascend_scheduler: - logger.warning( - "Chunked prefill feature is enabled in ascend_scheduler," - "but note that the operator supporting this feature " - "would lead to performance degradation.") - # In this situation, max_num_batched_tokens would have been rewritten. - # So we must make sure max_num_batched_tokens is not smaller than max_model_len. - if (scheduler_config.max_num_batched_tokens - < scheduler_config.max_model_len - and not chunked_prefill_enabled_in_ascend_scheduler): - scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len kv_cache_dtype = vllm_config.additional_config.get( "kv_cache_dtype", None) @@ -293,11 +264,20 @@ class NPUPlatform(Platform): if cache_config.block_size is None: cache_config.block_size = 128 - if cache_config.enable_prefix_caching and cache_config.block_size != 128: + if cache_config.enable_prefix_caching or \ + not ascend_scheduler_config.enabled or \ + getattr(ascend_scheduler_config, "enable_chunked_prefill", False): logger.warning( - "If prefix caching is enabled, block size must be set to 128." + "If chunked prefill or prefix caching is enabled, block size must be set to 128." ) + origin_block_size = cache_config.block_size cache_config.block_size = 128 + # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups. + if model_config and model_config.hf_config.model_type == "qwen3_next": + logger.warning( + "When running qwen3-next model, block_size needs to be restored to its original value." + ) + cache_config.block_size = origin_block_size # Activate custom ops for v1, except on 310P if not is_310p():