[main] add pd transfer for ascend scheduler (#2753)

### What this PR does / why we need it?
For offline scenarios, adjust the scheduling process to prioritize the
prefill phase of all requests, then process the decode phase of all
requests.

### How was this patch tested?

```
max_num_seqs=24,
additional_config={
    "ascend_scheduler_config":{
        "enabled": True,
        "enable_pd_transfer": True,
        "decode_max_num_seqs": 24,
        "enable_chunked_prefill": False
    }
},
```
| input | output | num prompts | max_num_seqs | dp | tp | scheduler |
tps |
| ------ | ------ | ---------- | ---------------- | ---- | ---- |
---------------- | --------------- |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | v1 | 234.06 |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | pd transfer | 239.59(+2.4%) |
| dapo-math-17K| 2K | 384 | 24 | 4 | 1 | v1 | 222.85 |
| dapo-math-17K| 2K | 384 | 24 | 4 | 1 | pd transfer | 225.81(+1.3%) |


- vLLM version: v0.10.1.1
- vLLM main:
6fb2788163

---------

Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
CaranLic
2025-09-10 08:46:39 +08:00
committed by GitHub
parent edf1f600ad
commit 168ad600b5
9 changed files with 216 additions and 4 deletions

View File

@@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
num_scheduler_steps: int = 1
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.scheduler.AscendScheduler")
enable_pd_transfer: bool = False
decode_max_num_seqs: int = 0
@classmethod
def initialize_from_config(
@@ -45,6 +47,8 @@ class AscendSchedulerConfig(SchedulerConfig):
scheduler_config["num_scheduler_steps"] = 1
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.scheduler.AscendScheduler")
scheduler_config["enable_pd_transfer"] = False
scheduler_config["decode_max_num_seqs"] = 0
# Override params in original SchedulerConfig with params in ascend_scheduler_config
for k, _ in scheduler_config.items():
if hasattr(ascend_scheduler_config, k):

View File

@@ -52,6 +52,15 @@ class AscendScheduler(Scheduler):
self.scheduled_req_ids: set[str] = set()
self.running: list[Request] = []
self.finished_prefill_reqs: deque[Request] = deque()
enable_pd_transfer = getattr(self.scheduler_config,
'enable_pd_transfer', False)
decode_max_num_seqs = getattr(self.scheduler_config,
'decode_max_num_seqs', 0)
self.phase = "" if not enable_pd_transfer else "prefill"
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
decode_max_num_seqs)
def schedule(self) -> SchedulerOutput:
if self.scheduler_config.chunked_prefill_enabled:
return super().schedule()
@@ -76,9 +85,25 @@ class AscendScheduler(Scheduler):
# and put back at the head of the waiting queue later
skipped_waiting_requests: deque[Request] = deque()
if self.phase == "prefill":
remaining_running_reqs = []
for request in self.running:
# move request has finished prefill to finished_prefill_reqs
if request.num_tokens > request.num_prompt_tokens:
self.finished_prefill_reqs.append(request)
else:
remaining_running_reqs.append(request)
self.running = remaining_running_reqs
# all request prefilled, change phase to decode
if not self.waiting and not self.running:
self.phase = "decode"
# Schedule prefill requests first.
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
if len(self.running) == (self.decode_max_num_running_reqs
if self.phase == "decode" else
self.max_num_running_reqs):
break
request = self.waiting[0]
@@ -235,6 +260,13 @@ class AscendScheduler(Scheduler):
if skipped_waiting_requests:
self.waiting.extendleft(skipped_waiting_requests)
if self.phase == "decode":
while len(
self.running
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
request = self.finished_prefill_reqs.popleft()
self.running.append(request)
# If no prefill requests are scheduled,
# Schedule decode requests next.
if len(self.scheduled_req_ids) == 0:
@@ -334,7 +366,9 @@ class AscendScheduler(Scheduler):
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert len(
self.running
) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
scheduled_running_reqs) <= len(self.running)