[main] add pd transfer for ascend scheduler (#2753)

### What this PR does / why we need it?
For offline scenarios, this PR adjusts the scheduling process so that the
prefill phase of all requests is completed first, and only then the decode
phase of all requests is processed.

### How was this patch tested?

```
max_num_seqs=24,
additional_config={
    "ascend_scheduler_config":{
        "enabled": True,
        "enable_pd_transfer": True,
        "decode_max_num_seqs": 24,
        "enable_chunked_prefill": False
    }
},
```
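
For reference, a minimal offline run exercising this config could look like the sketch below; the model name and prompt are placeholders, not part of this PR:

```
from vllm import LLM, SamplingParams

# Hypothetical model; any offline workload exercises the same path.
llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",
    max_num_seqs=24,
    additional_config={
        "ascend_scheduler_config": {
            "enabled": True,
            "enable_pd_transfer": True,
            "decode_max_num_seqs": 24,
            "enable_chunked_prefill": False,
        },
    },
)
outputs = llm.generate(["An example prompt"],
                       SamplingParams(max_tokens=128))
```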
| input | output | num prompts | max_num_seqs | dp | tp | scheduler | tps |
| ----- | ------ | ----------- | ------------ | -- | -- | --------- | --- |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | v1 | 234.06 |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | pd transfer | 239.59 (+2.4%) |
| dapo-math-17K | 2K | 384 | 24 | 4 | 1 | v1 | 222.85 |
| dapo-math-17K | 2K | 384 | 24 | 4 | 1 | pd transfer | 225.81 (+1.3%) |


- vLLM version: v0.10.1.1
- vLLM main:
6fb2788163

---------

Signed-off-by: CaranLic <740821011@qq.com>
CaranLic
2025-09-10 08:46:39 +08:00
committed by GitHub
parent edf1f600ad
commit 168ad600b5
9 changed files with 216 additions and 4 deletions


@@ -58,6 +58,8 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable the ascend scheduler for the V1 engine. |
| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When enabled, decode starts only after prefill of all requests has finished. This option only takes effect for offline inference. |
| `decode_max_num_seqs` | int | `0` | The max_num_seqs to use for the decode phase when pd transfer is enabled (`0` keeps `max_num_seqs` unchanged). This option only takes effect when `enable_pd_transfer` is `True`. |
ascend_scheduler_config also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
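
For instance, passing a vllm scheduler option straight through might look like this sketch (only `enabled` plus one inherited option shown):

```
additional_config = {
    "ascend_scheduler_config": {
        "enabled": True,
        # A vllm SchedulerConfig option passed through unchanged:
        "enable_chunked_prefill": True,
    },
}
```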


@@ -165,3 +165,16 @@ class TestAscendSchedulerConfig(TestBase):
        )
        self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
        self.assertIn("max_model_len (4096)", str(context.exception))

    def test_initialize_from_config_with_pd_transfer(self):
        ascend_config = AscendSchedulerConfig.initialize_from_config(
            self.basic_scheduler_config,
            AscendSchedulerConfig(
                enable_pd_transfer=True,
                decode_max_num_seqs=48,
                max_num_batched_tokens=4096,
                max_model_len=4096,
            ),
        )
        self.assertEqual(ascend_config.enable_pd_transfer, True)
        self.assertEqual(ascend_config.decode_max_num_seqs, 48)


@@ -705,3 +705,34 @@ class TestAscendScheduler(TestBase):
        # Confirm no memory leak.
        self.assert_scheduler_empty(scheduler)

    def test_scheduler_with_pd_transfer(self):
        scheduler = self.create_scheduler()
        scheduler.phase = "prefill"
        requests = create_requests(num_requests=32)
        for request in requests:
            scheduler.add_request(request)
        # 1st iteration: move 16 requests from waiting to running for prefill.
        scheduler_output = scheduler.schedule()
        model_runner_output = make_output(scheduler)
        scheduler.update_from_output(scheduler_output, model_runner_output)
        first_iter_prefilled_req_num = len(scheduler.running)
        self.assertEqual(len(scheduler_output.scheduled_new_reqs),
                         scheduler.max_num_running_reqs)
        self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
        self.assertEqual(len(scheduler_output.finished_req_ids), 0)
        # 2nd iteration: move the 16 prefilled requests to
        # finished_prefill_reqs and move 16 more requests from waiting to
        # running for prefill.
        scheduler_output = scheduler.schedule()
        model_runner_output = make_output(scheduler)
        scheduler.update_from_output(scheduler_output, model_runner_output)
        self.assertEqual(len(scheduler.finished_prefill_reqs),
                         first_iter_prefilled_req_num)
        # 3rd iteration: all requests prefilled, so the scheduler phase
        # switches to decode.
        scheduler_output = scheduler.schedule()
        model_runner_output = make_output(scheduler)
        scheduler.update_from_output(scheduler_output, model_runner_output)
        self.assertEqual(scheduler.phase, "decode")


@@ -0,0 +1,40 @@
import torch
from pytest_mock import MockerFixture
from vllm.config import SchedulerConfig, VllmConfig

from tests.ut.base import PytestBase
from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor


class TestMinPLogitsProcessorInitFunc(PytestBase):

    def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
        device_cpu = torch.device("cpu")
        device_npu = torch.device("npu")
        is_pin_memory = False

        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
        mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
        mock_scheduler_config.decode_max_num_seqs = 0
        mock_scheduler_config.max_num_seqs = 128
        mock_vllm_config.scheduler_config = mock_scheduler_config

        # torch.zeros/torch.empty raise errors on the online UT machine,
        # so mock them.
        mock_tensor = torch.zeros((256, ),
                                  dtype=torch.float32,
                                  pin_memory=False)
        mocker.patch("torch.zeros", return_value=mock_tensor)
        mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
        mocker.patch("torch.empty", return_value=mock_empty_tensor)

        processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
                                                  is_pin_memory)
        assert processor_cpu.min_p is not None
        assert processor_cpu.use_double_tensor is False
        assert processor_cpu.min_p_cpu.shape[0] == 256

        processor_npu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
                                                  is_pin_memory)
        assert processor_npu.min_p is not None
        assert processor_npu.use_double_tensor is True
        assert processor_npu.min_p_cpu.shape[0] == 256


@@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
    num_scheduler_steps: int = 1
    scheduler_cls: Union[str, Type[object]] = (
        "vllm_ascend.core.scheduler.AscendScheduler")
    enable_pd_transfer: bool = False
    decode_max_num_seqs: int = 0

    @classmethod
    def initialize_from_config(
@@ -45,6 +47,8 @@ class AscendSchedulerConfig(SchedulerConfig):
        scheduler_config["num_scheduler_steps"] = 1
        scheduler_config["scheduler_cls"] = (
            "vllm_ascend.core.scheduler.AscendScheduler")
        scheduler_config["enable_pd_transfer"] = False
        scheduler_config["decode_max_num_seqs"] = 0
        # Override params in the original SchedulerConfig with params in
        # ascend_scheduler_config.
        for k, _ in scheduler_config.items():
            if hasattr(ascend_scheduler_config, k):
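
In effect, the loop above lets any attribute set on the user's ascend_scheduler_config win over the corresponding vllm default. A simplified paraphrase with plain dicts (`base_scheduler_config` and `user_ascend_config` are illustrative names, not code from this PR):

```
# Simplified paraphrase of the override semantics above.
merged = dict(base_scheduler_config)           # values from vllm defaults
for key in merged:
    if hasattr(user_ascend_config, key):       # user-supplied value wins
        merged[key] = getattr(user_ascend_config, key)
```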


@@ -52,6 +52,15 @@ class AscendScheduler(Scheduler):
        self.scheduled_req_ids: set[str] = set()
        self.running: list[Request] = []
        self.finished_prefill_reqs: deque[Request] = deque()
        enable_pd_transfer = getattr(self.scheduler_config,
                                     'enable_pd_transfer', False)
        decode_max_num_seqs = getattr(self.scheduler_config,
                                      'decode_max_num_seqs', 0)
        self.phase = "prefill" if enable_pd_transfer else ""
        self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
                                               decode_max_num_seqs)

    def schedule(self) -> SchedulerOutput:
        if self.scheduler_config.chunked_prefill_enabled:
            return super().schedule()
@@ -76,9 +85,25 @@ class AscendScheduler(Scheduler):
        # and put back at the head of the waiting queue later
        skipped_waiting_requests: deque[Request] = deque()

        if self.phase == "prefill":
            remaining_running_reqs = []
            for request in self.running:
                # Move requests that have finished prefill to
                # finished_prefill_reqs.
                if request.num_tokens > request.num_prompt_tokens:
                    self.finished_prefill_reqs.append(request)
                else:
                    remaining_running_reqs.append(request)
            self.running = remaining_running_reqs
            # All requests prefilled: switch the phase to decode.
            if not self.waiting and not self.running:
                self.phase = "decode"

        # Schedule prefill requests first.
        while self.waiting and token_budget > 0:
-            if len(self.running) == self.max_num_running_reqs:
+            if len(self.running) == (self.decode_max_num_running_reqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs):
                break
            request = self.waiting[0]
@@ -235,6 +260,13 @@ class AscendScheduler(Scheduler):
        if skipped_waiting_requests:
            self.waiting.extendleft(skipped_waiting_requests)

        if self.phase == "decode":
            # Admit parked prefilled requests up to the decode batch limit.
            while (len(self.running) < self.decode_max_num_running_reqs
                   and self.finished_prefill_reqs):
                request = self.finished_prefill_reqs.popleft()
                self.running.append(request)

        # If no prefill requests were scheduled,
        # schedule decode requests next.
        if len(self.scheduled_req_ids) == 0:
@@ -334,7 +366,9 @@ class AscendScheduler(Scheduler):
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
-        assert len(self.running) <= self.max_num_running_reqs
+        assert len(self.running) <= (self.decode_max_num_running_reqs
+                                     if self.phase == "decode" else
+                                     self.max_num_running_reqs)
        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
            scheduled_running_reqs) <= len(self.running)
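
Taken together, the pd-transfer changes above reduce to a small two-phase flow. The sketch below is a simplified paraphrase for readability (names mirror the diff; it is not the literal implementation):

```
# Simplified paraphrase of the AscendScheduler pd-transfer flow.
def pd_transfer_step(sched):
    if sched.phase == "prefill":
        # Park requests whose prefill is complete; keep the rest running.
        still_prefilling = []
        for req in sched.running:
            if req.num_tokens > req.num_prompt_tokens:
                sched.finished_prefill_reqs.append(req)
            else:
                still_prefilling.append(req)
        sched.running = still_prefilling
        # Once nothing is waiting or running, every prefill is done.
        if not sched.waiting and not sched.running:
            sched.phase = "decode"
    if sched.phase == "decode":
        # Admit parked requests up to the (possibly larger) decode limit.
        while (len(sched.running) < sched.decode_max_num_running_reqs
               and sched.finished_prefill_reqs):
            sched.running.append(sched.finished_prefill_reqs.popleft())
```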


@@ -0,0 +1,50 @@
import itertools
from collections.abc import Sequence
from typing import TYPE_CHECKING, Union

import torch
from vllm.logger import init_logger
from vllm.v1.sample import logits_processor
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
                                                     MinTokensLogitsProcessor)
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
from vllm.v1.sample.logits_processor.state import LogitsProcessors

from vllm_ascend.sample.logits_processor.builtin import \
    AscendMinPLogitsProcessor

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)

# Error message when the user tries to initialize vLLM with a pooling model
# and custom logits processors.
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
                                   " logits processors.")

BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
    MinTokensLogitsProcessor,
    LogitBiasLogitsProcessor,
    AscendMinPLogitsProcessor,
]


def build_logitsprocs(
    vllm_config: "VllmConfig",
    device: torch.device,
    is_pin_memory: bool,
    is_pooling_model: bool,
    custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
) -> LogitsProcessors:
    if is_pooling_model:
        if custom_logitsprocs:
            raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
        logger.debug("Skipping logits processor loading because pooling models"
                     " do not support logits processors.")
        return LogitsProcessors()
    custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
        custom_logitsprocs)
    return LogitsProcessors(
        ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
            BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
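
For orientation, a hypothetical call site might look like the following sketch (the `make_processors` wrapper and the argument values are illustrative, not part of this PR):

```
import torch
from vllm.config import VllmConfig

# Hypothetical call site; in practice the engine supplies vllm_config.
def make_processors(vllm_config: VllmConfig):
    return build_logitsprocs(
        vllm_config=vllm_config,
        device=torch.device("npu"),
        is_pin_memory=False,
        is_pooling_model=False,
    )
```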


@@ -0,0 +1,35 @@
import torch
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import MinPLogitsProcessor


class AscendMinPLogitsProcessor(MinPLogitsProcessor):

    def __init__(self, vllm_config: "VllmConfig", device: torch.device,
                 is_pin_memory: bool):
        super().__init__(vllm_config, device, is_pin_memory)

        decode_max_num_seqs = getattr(vllm_config.scheduler_config,
                                      'decode_max_num_seqs', 0)
        if decode_max_num_seqs != 0:
            # Re-allocate the persistent buffers so they can hold the
            # larger decode-phase batch.
            max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
                               decode_max_num_seqs)
            self.min_p_count: int = 0
            self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
                                                dtype=torch.float32,
                                                device="cpu",
                                                pin_memory=is_pin_memory)
            self.min_p_cpu = self.min_p_cpu_tensor.numpy()
            self.use_double_tensor = torch.device(device).type != "cpu"
            if self.use_double_tensor:
                # Pre-allocated device tensor
                self.min_p_device: torch.Tensor = torch.empty(
                    (max_num_reqs, ), dtype=torch.float32, device=device)
            else:
                self.min_p_device = self.min_p_cpu_tensor
            # Current slice of the device tensor
            self.min_p: torch.Tensor = self.min_p_device[:0]
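
The point of this override is buffer sizing: the base `MinPLogitsProcessor` allocates its persistent tensors from `max_num_seqs`, which would be too small once the decode phase admits up to `decode_max_num_seqs` requests. A minimal statement of the intended invariant, assuming a constructed `processor` and its `vllm_config` are in scope (illustrative, not code from the PR):

```
# Illustrative invariant check: the persistent min_p buffers must cover
# the largest batch either phase can produce.
sched = vllm_config.scheduler_config
expected = max(sched.max_num_seqs,
               getattr(sched, 'decode_max_num_seqs', 0))
assert processor.min_p_cpu_tensor.shape[0] >= expected
```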


@@ -66,7 +66,6 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
                             LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
-from vllm.v1.sample.logits_processor import build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -86,6 +85,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
+from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
from vllm_ascend.spec_decode import get_spec_decode_method
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
@@ -173,7 +173,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
                                           self.block_size)
        self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = self.scheduler_config.max_num_seqs
+        decode_max_num_seqs = getattr(self.scheduler_config,
+                                      'decode_max_num_seqs', 0)
+        self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
+                                decode_max_num_seqs)
        self.dp_size = vllm_config.parallel_config.data_parallel_size
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.device = device