[CORE] concurrent partial prefills (#2372)

# What this PR does / why we need it?

When processing a mix of large and small requests, this change significantly
reduces response TTFT. Please refer to
https://github.com/vllm-project/vllm/pull/10235, which achieves the same
effect by simply limiting the number of long requests that may be partially
prefilled concurrently. The solution applies to both the AscendScheduler (V0)
and the vLLM Scheduler (V1), and tests show that TTFT improves significantly
when handling such mixed workloads. However, this capability is currently
missing when the AscendScheduler is enabled.
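
For reference, a minimal sketch of how the new knobs could be supplied through
vllm-ascend's `additional_config` (the keys mirror the `AscendSchedulerConfig`
fields added in this PR; the model name and the threshold/limit values are
illustrative assumptions, not part of the PR):

```python
from vllm import LLM

# Illustrative values only: prompts above 10K tokens count as "long",
# and at most one long prompt may be partially prefilled per step.
llm = LLM(
    model="Qwen/Qwen3-8B",
    max_num_batched_tokens=22000,
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "max_long_partial_prefills": 1,
            "long_prefill_token_threshold": 10000,
        },
    },
)
```

With a setup like this, at most one prompt longer than the threshold is
partially prefilled per scheduling step, so short prompts are not starved of
the token budget.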

The benchmark used the Qwen3-8B model with a 128K context length, running on a
single card.

For the dataset, sharegpt_clean is used, with its content concatenated and
then cropped to fixed lengths. Small requests of 50 tokens and medium requests
of 10240 tokens were constructed (large requests of 102400 tokens were also
built, but they were ignored because, under the prefill-first scheduling
strategy, max_num_batched_tokens would never be set that high). When loading
vLLM, max_num_batched_tokens is set to 22000. This budget can hold two medium
requests plus a few short ones, reflecting an extreme scenario where the
budget is almost entirely occupied by longer requests.

Next, 990 small requests and 100 medium requests are mixed into one load
scenario (hereinafter referred to as 10%), and load scenarios with 5% and 1%
medium requests are generated in the same way.
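
The request construction described above can be sketched roughly as follows
(a hypothetical helper, assuming the sharegpt_clean content has already been
concatenated into a single text file; the file name and sampling details are
illustrative):

```python
import random

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Hypothetical input: sharegpt_clean content already concatenated into one file.
corpus_ids = tokenizer(open("sharegpt_clean_concat.txt").read())["input_ids"]

def make_prompt(num_tokens: int) -> str:
    # Crop a random window of the corpus to an exact token length.
    start = random.randrange(0, len(corpus_ids) - num_tokens)
    return tokenizer.decode(corpus_ids[start:start + num_tokens])

# 10% scenario: 990 small (50-token) and 100 medium (10240-token) requests.
requests = [make_prompt(50) for _ in range(990)] + \
           [make_prompt(10240) for _ in range(100)]
random.shuffle(requests)
```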

Performance tests were then run separately with the vLLM Scheduler, the
AscendScheduler, and the AscendScheduler with long-prompt concurrency limited
to 1.
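
The three configurations can be summarized as follows (a sketch of the
intended settings, using the same `additional_config` keys as above; not an
exact reproduction of the benchmark scripts):

```python
# The three scheduler setups compared in the benchmark (illustrative sketch).
scheduler_variants = {
    # 1) Stock vLLM Scheduler (V1): Ascend scheduler disabled.
    "vllm_scheduler": {"ascend_scheduler_config": {"enabled": False}},
    # 2) AscendScheduler with defaults: no limit on concurrent long prefills.
    "ascend_scheduler": {"ascend_scheduler_config": {"enabled": True}},
    # 3) AscendScheduler with long-prompt concurrency set to 1. The default
    #    threshold (4% of max_model_len, ~5242 tokens at 128K) already
    #    classifies the 10240-token medium requests as long.
    "ascend_scheduler_limit_1": {
        "ascend_scheduler_config": {
            "enabled": True,
            "max_long_partial_prefills": 1,
        },
    },
}
```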

- vLLM version: v0.10.2
- vLLM main: 1dfea5f4a9

---------

Signed-off-by: Csrayz <jover@cmbchina.com>
Author: Csrayz
Committed: 2025-09-24 17:12:55 +08:00 (by GitHub)
Commit: 80524f5711 (parent: 2d885869c5)
5 changed files with 73 additions and 1 deletion


@@ -20,10 +20,14 @@ from typing import Type, Union

```python
from vllm.config import SchedulerConfig

MAX_INT = 2147483647


@dataclass
class AscendSchedulerConfig(SchedulerConfig):
    enable_chunked_prefill: bool = False
    max_long_partial_prefills: int = MAX_INT
    long_prefill_token_threshold: int = MAX_INT
    policy: str = "fcfs"
    scheduler_cls: Union[str, Type[object]] = (
        "vllm_ascend.core.scheduler.AscendScheduler")
```

@@ -42,6 +46,8 @@ class AscendSchedulerConfig(SchedulerConfig):

```python
        }
        # Override default values into original SchedulerConfig
        scheduler_config["enable_chunked_prefill"] = False
        scheduler_config["max_long_partial_prefills"] = None
        scheduler_config["long_prefill_token_threshold"] = None
        scheduler_config["policy"] = "fcfs"
        scheduler_config["scheduler_cls"] = (
            "vllm_ascend.core.scheduler.AscendScheduler")
```

@@ -67,6 +73,28 @@ class AscendSchedulerConfig(SchedulerConfig):

```python
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        # Concurrent partial prefills. Default is effectively unlimited (MAX_INT).
        if self.max_long_partial_prefills is None:
            self.max_long_partial_prefills = MAX_INT
            self.long_prefill_token_threshold = MAX_INT
        if self.long_prefill_token_threshold is None or \
                self.long_prefill_token_threshold <= 0:
            if self.max_model_len is None:
                self.long_prefill_token_threshold = MAX_INT
            else:
                self.long_prefill_token_threshold = \
                    max(1, int(self.max_model_len * 0.04))
        if self.max_long_partial_prefills < 0:
            raise ValueError(
                f"max_long_partial_prefills must be non-negative, but got "
                f"{self.max_long_partial_prefills}")
        if self.long_prefill_token_threshold < 0:
            raise ValueError(
                f"long_prefill_token_threshold must be non-negative, but got "
                f"{self.long_prefill_token_threshold}")
        if self.policy != "fcfs":
            raise NotImplementedError(
                f"currently AscendScheduler only supports fcfs policy, got {self.policy}")
```


@@ -102,6 +102,14 @@ class AscendScheduler(Scheduler):

```python
            # all requests prefilled, change phase to decode
            if not self.waiting and not self.running:
                self.phase = "decode"

        # Skip long prompt requests in prefill stage.
        # long_prefill_budget is float('inf') if not used.
        if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
            long_prefill_budget = float('inf')
            long_prefill_token_threshold = float('inf')
        else:
            long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
            long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold

        # Schedule prefill requests first.
        while self.waiting and token_budget > 0:
```

@@ -217,6 +225,11 @@ class AscendScheduler(Scheduler):

```python
                skip_cur_request()
                continue

            if num_new_tokens > long_prefill_token_threshold \
                    and long_prefill_budget <= 0:
                skip_cur_request()
                continue

            new_blocks = self.kv_cache_manager.allocate_slots(
                request,
                num_new_tokens + num_external_computed_tokens,
```

@@ -268,6 +281,8 @@ class AscendScheduler(Scheduler):

```python
            # Update request info.
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            if num_new_tokens > long_prefill_token_threshold:
                long_prefill_budget -= 1
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens
            # Count the number of prefix cached tokens.
```
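
Taken together, the effect of the budget accounting above can be illustrated
with a small standalone sketch (toy numbers taken from the benchmark setup;
this is not scheduler code):

```python
# Toy illustration: token budget 22000, long-prompt threshold 5242,
# at most one concurrent long partial prefill (max_long_partial_prefills=1).
token_budget = 22000
long_prefill_budget = 1
long_prefill_token_threshold = 5242

waiting = [10240, 10240, 50, 50, 50]   # prompt lengths in tokens
scheduled = []
for num_new_tokens in waiting:
    if num_new_tokens > token_budget:
        continue  # does not fit this step
    if num_new_tokens > long_prefill_token_threshold and long_prefill_budget <= 0:
        continue  # second long prompt is deferred to a later step
    token_budget -= num_new_tokens
    if num_new_tokens > long_prefill_token_threshold:
        long_prefill_budget -= 1
    scheduled.append(num_new_tokens)

print(scheduled)  # [10240, 50, 50, 50]: one long prefill, short requests still fit
```

Without the limit, both 10240-token prompts would be prefilled first (20480 of
the 22000-token budget), leaving almost no room for the small requests in that
step, which is exactly the TTFT degradation this PR targets.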