diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index a20331e..076f7f2 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -60,6 +60,8 @@ The details of each config option are as follows: | `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine| | `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When using it, decode is started only when prefill of all requests is done. This option only takes effects on offline inference. | | `decode_max_num_seqs` | int | `0` | Whether to change max_num_seqs of decode phase when enable pd transfer. This option only takes effects when enable_pd_transfer is True. | +| `max_long_partial_prefills` | Union[int, float] | `float('inf')` | The maximum number of prompts longer than long_prefill_token_threshold that will be prefilled concurrently. | +| `long_prefill_token_threshold` | Union[int, float] | `float('inf')` | A request is considered long if the prompt is longer than this number of tokens. | ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well. 
@@ -79,6 +81,8 @@ An example of additional configuration is as follows: "ascend_scheduler_config": { "enabled": True, "enable_chunked_prefill": True, + "max_long_partial_prefills": 1, + "long_prefill_token_threshold": 4096, }, "multistream_overlap_shared_expert": True, "refresh": False, diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py index b0942da..84fd643 100644 --- a/tests/ut/core/test_schedule_config.py +++ b/tests/ut/core/test_schedule_config.py @@ -50,6 +50,8 @@ class TestAscendSchedulerConfig(TestBase): scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler", max_num_batched_tokens=2048, max_model_len=2048, + max_long_partial_prefills=1, + long_prefill_token_threshold=512, ), ) self.assertEqual(ascend_config.enable_chunked_prefill, False) @@ -58,6 +60,8 @@ class TestAscendSchedulerConfig(TestBase): "vllm_ascend.core.scheduler.AscendScheduler") self.assertEqual(ascend_config.max_num_batched_tokens, 2048) self.assertEqual(ascend_config.encoder_cache_size, 2048) + self.assertEqual(ascend_config.max_long_partial_prefills, 1) + self.assertEqual(ascend_config.long_prefill_token_threshold, 512) def test_not_implemented_policy(self): with self.assertRaises(NotImplementedError) as context: diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py index 095f392..d723e0a 100644 --- a/tests/ut/core/test_scheduler.py +++ b/tests/ut/core/test_scheduler.py @@ -221,7 +221,7 @@ class TestAscendScheduler(TestBase): len(requests) - i - 1) def test_schedule(self): - '''Test scheduling. + '''Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs ''' scheduler = self.create_scheduler() @@ -279,6 +279,27 @@ class TestAscendScheduler(TestBase): for i, request in enumerate(requests): self.assertEqual(scheduler.running[i], request) + def test_concurrent_partial_prefills_schedule(self): + '''Test concurrent partial prefills scheduling. + total requests = 10, every request has 20 tokens. 
+ while long_prefill_token_threshold is set to 1, the scheduler can + only schedule at most max_long_partial_prefills long requests. + ''' + scheduler = self.create_scheduler() + scheduler.scheduler_config.chunked_prefill_enabled = False + scheduler.scheduler_config.max_long_partial_prefills = 2 + scheduler.scheduler_config.long_prefill_token_threshold = 1 + requests = create_requests(num_requests=10, num_tokens=20) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), + scheduler.scheduler_config.max_long_partial_prefills) + self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0) + self.assertEqual(len(output.finished_req_ids), 0) + def test_schedule_enable_prefix_caching(self): '''Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index dcd5d05..3736534 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -20,10 +20,14 @@ from typing import Type, Union from vllm.config import SchedulerConfig +MAX_INT = 2147483647 + @dataclass class AscendSchedulerConfig(SchedulerConfig): enable_chunked_prefill: bool = False + max_long_partial_prefills: int = MAX_INT + long_prefill_token_threshold: int = MAX_INT policy: str = "fcfs" scheduler_cls: Union[str, Type[object]] = ( "vllm_ascend.core.scheduler.AscendScheduler") @@ -42,6 +46,8 @@ class AscendSchedulerConfig(SchedulerConfig): } # Override default values into original SchedulerConfig scheduler_config["enable_chunked_prefill"] = False + scheduler_config["max_long_partial_prefills"] = None + scheduler_config["long_prefill_token_threshold"] = None scheduler_config["policy"] = "fcfs" scheduler_config["scheduler_cls"] = ( "vllm_ascend.core.scheduler.AscendScheduler") @@ -67,6 +73,28 @@ class AscendSchedulerConfig(SchedulerConfig): "max_num_batched_tokens and makes 
vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " "decrease max_model_len.") + # concurrent partial prefills. Default is inf + if self.max_long_partial_prefills is None: + self.max_long_partial_prefills = MAX_INT + self.long_prefill_token_threshold = MAX_INT + + if self.long_prefill_token_threshold is None or \ + self.long_prefill_token_threshold <= 0: + if self.max_model_len is None: + self.long_prefill_token_threshold = MAX_INT + else: + self.long_prefill_token_threshold = \ + max(1, int(self.max_model_len * 0.04)) + + if self.max_long_partial_prefills < 0: + raise ValueError( + f"max_long_partial_prefills must be non-negative, but got " + f"{self.max_long_partial_prefills}") + if self.long_prefill_token_threshold < 0: + raise ValueError( + f"long_prefill_token_threshold must be non-negative, but got " + f"{self.long_prefill_token_threshold}") + if self.policy != "fcfs": raise NotImplementedError( f"currently AscendScheduler only supports fcfs policy, got {self.policy}" diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index cc6822f..d77465d 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -102,6 +102,14 @@ class AscendScheduler(Scheduler): # all request prefilled, change phase to decode if not self.waiting and not self.running: self.phase = "decode" + # Skip long prompt requests in prefill stage. + # long_prefill_budget is float('inf') if not used. + if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0: + long_prefill_budget = float('inf') + long_prefill_token_threshold = float('inf') + else: + long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills + long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold # Schedule prefill requests first. 
while self.waiting and token_budget > 0: @@ -217,6 +225,11 @@ class AscendScheduler(Scheduler): skip_cur_request() continue + if num_new_tokens > long_prefill_token_threshold \ + and long_prefill_budget <= 0: + skip_cur_request() + continue + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens + num_external_computed_tokens, @@ -268,6 +281,8 @@ class AscendScheduler(Scheduler): # Update request info. num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens + if num_new_tokens > long_prefill_token_threshold: + long_prefill_budget -= 1 request.status = RequestStatus.RUNNING request.num_computed_tokens = num_computed_tokens # Count the number of prefix cached tokens.