From 585a494baa4bdbce5a71ef6808033466ed9f90f3 Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Fri, 12 Sep 2025 23:17:09 +0800 Subject: [PATCH] [Core] Disable the chunked prefill feature in Non-MLA LLMs (#2894) ### What this PR does / why we need it? This PR enforces the forcible disabling of the chunked prefill feature in Non-MLA models, as the performance of operators supporting this functionality is currently suboptimal. The feature remains disabled unless the user has explicitly enabled chunked prefill in the ascend_scheduler_config, in which case we allow it. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with new added/existing test. Related: https://github.com/vllm-project/vllm-ascend/pull/2659 - vLLM version: main - vLLM main: https://github.com/vllm-project/vllm/commit/d21a36f5f9569949dd6c313deed609d77e393850 Signed-off-by: rjg-lyh <1318825571@qq.com> --- tests/ut/core/test_schedule_config.py | 18 ----------------- vllm_ascend/core/schedule_config.py | 5 ----- vllm_ascend/platform.py | 29 +++++++++++++++++++++++++++ 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py index 8f1422f..b135370 100644 --- a/tests/ut/core/test_schedule_config.py +++ b/tests/ut/core/test_schedule_config.py @@ -36,7 +36,6 @@ class TestAscendSchedulerConfig(TestBase): self.basic_scheduler_config, {}) self.assertEqual(ascend_config.enable_chunked_prefill, False) self.assertEqual(ascend_config.policy, "fcfs") - self.assertEqual(ascend_config.num_scheduler_steps, 1) self.assertEqual(ascend_config.scheduler_cls, "vllm_ascend.core.scheduler.AscendScheduler") self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192) @@ -49,7 +48,6 @@ class TestAscendSchedulerConfig(TestBase): AscendSchedulerConfig( enable_chunked_prefill=False, policy="fcfs", - num_scheduler_steps=1, scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler", 
max_num_batched_tokens=2048, max_model_len=2048, @@ -57,7 +55,6 @@ class TestAscendSchedulerConfig(TestBase): ) self.assertEqual(ascend_config.enable_chunked_prefill, False) self.assertEqual(ascend_config.policy, "fcfs") - self.assertEqual(ascend_config.num_scheduler_steps, 1) self.assertEqual(ascend_config.scheduler_cls, "vllm_ascend.core.scheduler.AscendScheduler") self.assertEqual(ascend_config.max_num_batched_tokens, 2048) @@ -85,21 +82,6 @@ class TestAscendSchedulerConfig(TestBase): self.assertIn("currently AscendScheduler only supports LLM models", str(context.exception)) - def test_not_implemented_multi_step(self): - with self.assertRaises(NotImplementedError) as context: - AscendSchedulerConfig.initialize_from_config( - self.basic_scheduler_config, - AscendSchedulerConfig( - num_scheduler_steps=2, - max_num_batched_tokens=2048, - max_model_len=2048, - ), - ) - self.assertIn( - "currently AscendScheduler doesn't support multi-step", - str(context.exception), - ) - def test_not_implemented_send_delta_data(self): with self.assertRaises(NotImplementedError) as context: AscendSchedulerConfig.initialize_from_config( diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index 422ca9a..257657a 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -25,7 +25,6 @@ from vllm.config import SchedulerConfig class AscendSchedulerConfig(SchedulerConfig): enable_chunked_prefill: bool = False policy: str = "fcfs" - num_scheduler_steps: int = 1 scheduler_cls: Union[str, Type[object]] = ( "vllm_ascend.core.scheduler.AscendScheduler") enable_pd_transfer: bool = False @@ -44,7 +43,6 @@ class AscendSchedulerConfig(SchedulerConfig): # Override default values into original SchedulerConfig scheduler_config["enable_chunked_prefill"] = False scheduler_config["policy"] = "fcfs" - scheduler_config["num_scheduler_steps"] = 1 scheduler_config["scheduler_cls"] = ( "vllm_ascend.core.scheduler.AscendScheduler") 
scheduler_config["enable_pd_transfer"] = False @@ -76,9 +74,6 @@ class AscendSchedulerConfig(SchedulerConfig): if self.is_multimodal_model: raise NotImplementedError( "currently AscendScheduler only supports LLM models.") - if self.num_scheduler_steps > 1: - raise NotImplementedError( - "currently AscendScheduler doesn't support multi-step.") if self.send_delta_data: raise NotImplementedError( "currently AscendScheduler doesn't support send_delta_data.") diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index ee80c7d..e91f2cd 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -128,6 +128,35 @@ class NPUPlatform(Platform): model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config cache_config = vllm_config.cache_config + decoding_config = vllm_config.decoding_config + scheduler_config = vllm_config.scheduler_config + ascend_scheduler_config = ascend_config.ascend_scheduler_config + + if model_config is not None and not model_config.use_mla: + logger.info( + "Non-MLA LLMs forcibly disable the chunked prefill feature," + "as the performance of operators supporting this feature " + "functionality is currently suboptimal.") + if not model_config.is_multimodal_model and \ + decoding_config.backend == "auto" and \ + not scheduler_config.delay_factor > 0 and \ + not scheduler_config.send_delta_data and \ + scheduler_config.policy == "fcfs": + ascend_scheduler_config.enabled = True + chunked_prefill_enabled_in_ascend_scheduler = getattr( + ascend_scheduler_config, "enable_chunked_prefill", False) + if chunked_prefill_enabled_in_ascend_scheduler: + logger.warning( + "Chunked prefill feature is enabled in ascend_scheduler," + "but note that the operator supporting this feature " + "would lead to performance degradation.") + # In this situation, max_num_batched_tokens would have been rewritten. + # So we must make sure max_num_batched_tokens is not smaller than max_model_len. 
+ if (scheduler_config.max_num_batched_tokens + < scheduler_config.max_model_len + and not chunked_prefill_enabled_in_ascend_scheduler): + scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len + kv_cache_dtype = vllm_config.additional_config.get( "kv_cache_dtype", None) if kv_cache_dtype is not None: