[Core] Disable the chunked prefill feature in Non-MLA LLMs (#2894)
### What this PR does / why we need it?
This PR forcibly disables the chunked prefill feature for Non-MLA models,
because the operators that support this feature currently perform
suboptimally. Chunked prefill remains available for these models only
when the user has explicitly enabled it in the ascend_scheduler_config.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
CI passed with newly added and existing tests.
Related: https://github.com/vllm-project/vllm-ascend/pull/2659
- vLLM version: main
- vLLM main:
d21a36f5f9
Signed-off-by: rjg-lyh <1318825571@qq.com>
This commit is contained in:
@@ -36,7 +36,6 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
self.basic_scheduler_config, {})
|
self.basic_scheduler_config, {})
|
||||||
self.assertEqual(ascend_config.enable_chunked_prefill, False)
|
self.assertEqual(ascend_config.enable_chunked_prefill, False)
|
||||||
self.assertEqual(ascend_config.policy, "fcfs")
|
self.assertEqual(ascend_config.policy, "fcfs")
|
||||||
self.assertEqual(ascend_config.num_scheduler_steps, 1)
|
|
||||||
self.assertEqual(ascend_config.scheduler_cls,
|
self.assertEqual(ascend_config.scheduler_cls,
|
||||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
|
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
|
||||||
@@ -49,7 +48,6 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
AscendSchedulerConfig(
|
AscendSchedulerConfig(
|
||||||
enable_chunked_prefill=False,
|
enable_chunked_prefill=False,
|
||||||
policy="fcfs",
|
policy="fcfs",
|
||||||
num_scheduler_steps=1,
|
|
||||||
scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
|
scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
|
||||||
max_num_batched_tokens=2048,
|
max_num_batched_tokens=2048,
|
||||||
max_model_len=2048,
|
max_model_len=2048,
|
||||||
@@ -57,7 +55,6 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(ascend_config.enable_chunked_prefill, False)
|
self.assertEqual(ascend_config.enable_chunked_prefill, False)
|
||||||
self.assertEqual(ascend_config.policy, "fcfs")
|
self.assertEqual(ascend_config.policy, "fcfs")
|
||||||
self.assertEqual(ascend_config.num_scheduler_steps, 1)
|
|
||||||
self.assertEqual(ascend_config.scheduler_cls,
|
self.assertEqual(ascend_config.scheduler_cls,
|
||||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
|
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
|
||||||
@@ -85,21 +82,6 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
self.assertIn("currently AscendScheduler only supports LLM models",
|
self.assertIn("currently AscendScheduler only supports LLM models",
|
||||||
str(context.exception))
|
str(context.exception))
|
||||||
|
|
||||||
def test_not_implemented_multi_step(self):
|
|
||||||
with self.assertRaises(NotImplementedError) as context:
|
|
||||||
AscendSchedulerConfig.initialize_from_config(
|
|
||||||
self.basic_scheduler_config,
|
|
||||||
AscendSchedulerConfig(
|
|
||||||
num_scheduler_steps=2,
|
|
||||||
max_num_batched_tokens=2048,
|
|
||||||
max_model_len=2048,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
self.assertIn(
|
|
||||||
"currently AscendScheduler doesn't support multi-step",
|
|
||||||
str(context.exception),
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_not_implemented_send_delta_data(self):
|
def test_not_implemented_send_delta_data(self):
|
||||||
with self.assertRaises(NotImplementedError) as context:
|
with self.assertRaises(NotImplementedError) as context:
|
||||||
AscendSchedulerConfig.initialize_from_config(
|
AscendSchedulerConfig.initialize_from_config(
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from vllm.config import SchedulerConfig
|
|||||||
class AscendSchedulerConfig(SchedulerConfig):
|
class AscendSchedulerConfig(SchedulerConfig):
|
||||||
enable_chunked_prefill: bool = False
|
enable_chunked_prefill: bool = False
|
||||||
policy: str = "fcfs"
|
policy: str = "fcfs"
|
||||||
num_scheduler_steps: int = 1
|
|
||||||
scheduler_cls: Union[str, Type[object]] = (
|
scheduler_cls: Union[str, Type[object]] = (
|
||||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
enable_pd_transfer: bool = False
|
enable_pd_transfer: bool = False
|
||||||
@@ -44,7 +43,6 @@ class AscendSchedulerConfig(SchedulerConfig):
|
|||||||
# Override default values into original SchedulerConfig
|
# Override default values into original SchedulerConfig
|
||||||
scheduler_config["enable_chunked_prefill"] = False
|
scheduler_config["enable_chunked_prefill"] = False
|
||||||
scheduler_config["policy"] = "fcfs"
|
scheduler_config["policy"] = "fcfs"
|
||||||
scheduler_config["num_scheduler_steps"] = 1
|
|
||||||
scheduler_config["scheduler_cls"] = (
|
scheduler_config["scheduler_cls"] = (
|
||||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
scheduler_config["enable_pd_transfer"] = False
|
scheduler_config["enable_pd_transfer"] = False
|
||||||
@@ -76,9 +74,6 @@ class AscendSchedulerConfig(SchedulerConfig):
|
|||||||
if self.is_multimodal_model:
|
if self.is_multimodal_model:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"currently AscendScheduler only supports LLM models.")
|
"currently AscendScheduler only supports LLM models.")
|
||||||
if self.num_scheduler_steps > 1:
|
|
||||||
raise NotImplementedError(
|
|
||||||
"currently AscendScheduler doesn't support multi-step.")
|
|
||||||
if self.send_delta_data:
|
if self.send_delta_data:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"currently AscendScheduler doesn't support send_delta_data.")
|
"currently AscendScheduler doesn't support send_delta_data.")
|
||||||
|
|||||||
@@ -128,6 +128,35 @@ class NPUPlatform(Platform):
|
|||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
|
decoding_config = vllm_config.decoding_config
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
ascend_scheduler_config = ascend_config.ascend_scheduler_config
|
||||||
|
|
||||||
|
if model_config is not None and not model_config.use_mla:
|
||||||
|
logger.info(
|
||||||
|
"Non-MLA LLMs forcibly disable the chunked prefill feature,"
|
||||||
|
"as the performance of operators supporting this feature "
|
||||||
|
"functionality is currently suboptimal.")
|
||||||
|
if not model_config.is_multimodal_model and \
|
||||||
|
decoding_config.backend == "auto" and \
|
||||||
|
not scheduler_config.delay_factor > 0 and \
|
||||||
|
not scheduler_config.send_delta_data and \
|
||||||
|
scheduler_config.policy == "fcfs":
|
||||||
|
ascend_scheduler_config.enabled = True
|
||||||
|
chunked_prefill_enabled_in_ascend_scheduler = getattr(
|
||||||
|
ascend_scheduler_config, "enable_chunked_prefill", False)
|
||||||
|
if chunked_prefill_enabled_in_ascend_scheduler:
|
||||||
|
logger.warning(
|
||||||
|
"Chunked prefill feature is enabled in ascend_scheduler,"
|
||||||
|
"but note that the operator supporting this feature "
|
||||||
|
"would lead to performance degradation.")
|
||||||
|
# In this situation, max_num_batched_tokens would have been rewritten.
|
||||||
|
# So we must make sure max_num_batched_tokens is not smaller than max_model_len.
|
||||||
|
if (scheduler_config.max_num_batched_tokens
|
||||||
|
< scheduler_config.max_model_len
|
||||||
|
and not chunked_prefill_enabled_in_ascend_scheduler):
|
||||||
|
scheduler_config.max_num_batched_tokens = scheduler_config.max_model_len
|
||||||
|
|
||||||
kv_cache_dtype = vllm_config.additional_config.get(
|
kv_cache_dtype = vllm_config.additional_config.get(
|
||||||
"kv_cache_dtype", None)
|
"kv_cache_dtype", None)
|
||||||
if kv_cache_dtype is not None:
|
if kv_cache_dtype is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user