[main] add pd transfer for ascend scheduler (#2753)
### What this PR does / why we need it?
For offline scenarios, adjust the scheduling process to prioritize the
prefill phase of all requests, then process the decode phase of all
requests.
### How was this patch tested?
```
max_num_seqs=24,
additional_config={
"ascend_scheduler_config":{
"enabled": True,
"enable_pd_transfer": True,
"decode_max_num_seqs": 24,
"enable_chunked_prefill": False
}
},
```
| input | output | num prompts | max_num_seqs | dp | tp | scheduler |
tps |
| ------ | ------ | ---------- | ---------------- | ---- | ---- |
---------------- | --------------- |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | v1 | 234.06 |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | pd transfer | 239.59(+2.4%) |
| dapo-math-17K| 2K | 384 | 24 | 4 | 1 | v1 | 222.85 |
| dapo-math-17K| 2K | 384 | 24 | 4 | 1 | pd transfer | 225.81(+1.3%) |
- vLLM version: v0.10.1.1
- vLLM main:
6fb2788163
---------
Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
@@ -58,6 +58,8 @@ The details of each config option are as follows:
|
|||||||
| Name | Type | Default | Description |
|
| Name | Type | Default | Description |
|
||||||
| ---- | ---- | ------- | ----------- |
|
| ---- | ---- | ------- | ----------- |
|
||||||
| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
|
| `enabled` | bool | `False` | Whether to enable ascend scheduler for V1 engine|
|
||||||
|
| `enable_pd_transfer` | bool | `False` | Whether to enable pd transfer. When using it, decode is started only when prefill of all requests is done. This option only takes effects on offline inference. |
|
||||||
|
| `decode_max_num_seqs` | int | `0` | Whether to change max_num_seqs of decode phase when enable pd transfer. This option only takes effects when enable_pd_transfer is True. |
|
||||||
|
|
||||||
ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
|
ascend_scheduler_config also support the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
|
||||||
|
|
||||||
|
|||||||
@@ -165,3 +165,16 @@ class TestAscendSchedulerConfig(TestBase):
|
|||||||
)
|
)
|
||||||
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
|
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
|
||||||
self.assertIn("max_model_len (4096)", str(context.exception))
|
self.assertIn("max_model_len (4096)", str(context.exception))
|
||||||
|
|
||||||
|
def test_initialize_from_config_with_pd_transfer(self):
|
||||||
|
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
||||||
|
self.basic_scheduler_config,
|
||||||
|
AscendSchedulerConfig(
|
||||||
|
enable_pd_transfer=True,
|
||||||
|
decode_max_num_seqs=48,
|
||||||
|
max_num_batched_tokens=4096,
|
||||||
|
max_model_len=4096,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
self.assertEqual(ascend_config.enable_pd_transfer, True)
|
||||||
|
self.assertEqual(ascend_config.decode_max_num_seqs, 48)
|
||||||
|
|||||||
@@ -705,3 +705,34 @@ class TestAscendScheduler(TestBase):
|
|||||||
|
|
||||||
# Confirm no memory leak.
|
# Confirm no memory leak.
|
||||||
self.assert_scheduler_empty(scheduler)
|
self.assert_scheduler_empty(scheduler)
|
||||||
|
|
||||||
|
def test_scheduler_with_pd_transfer(self):
|
||||||
|
scheduler = self.create_scheduler()
|
||||||
|
scheduler.phase = "prefill"
|
||||||
|
requests = create_requests(num_requests=32)
|
||||||
|
for request in requests:
|
||||||
|
scheduler.add_request(request)
|
||||||
|
|
||||||
|
# 1st iteration, move 16 requests from waiting to running for prefill
|
||||||
|
scheduler_output = scheduler.schedule()
|
||||||
|
model_runner_output = make_output(scheduler)
|
||||||
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
first_iter_prefilled_req_num = len(scheduler.running)
|
||||||
|
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
|
||||||
|
scheduler.max_num_running_reqs)
|
||||||
|
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
|
||||||
|
self.assertEqual(len(scheduler_output.finished_req_ids), 0)
|
||||||
|
|
||||||
|
# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
|
||||||
|
# and move 16 requests from waiting to running for prefill
|
||||||
|
scheduler_output = scheduler.schedule()
|
||||||
|
model_runner_output = make_output(scheduler)
|
||||||
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
self.assertEqual(len(scheduler.finished_prefill_reqs),
|
||||||
|
first_iter_prefilled_req_num)
|
||||||
|
|
||||||
|
# 3rd iteration, all requests prefilled, change scheduler phase to decode
|
||||||
|
scheduler_output = scheduler.schedule()
|
||||||
|
model_runner_output = make_output(scheduler)
|
||||||
|
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||||
|
self.assertEqual(scheduler.phase, "decode")
|
||||||
|
|||||||
40
tests/ut/sample/logits_processor/test_builtin.py
Normal file
40
tests/ut/sample/logits_processor/test_builtin.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import torch
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
from vllm.config import SchedulerConfig, VllmConfig
|
||||||
|
|
||||||
|
from tests.ut.base import PytestBase
|
||||||
|
from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class TestMinPLogitsProcessorInitFunc(PytestBase):
|
||||||
|
|
||||||
|
def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
|
||||||
|
device_cpu = torch.device("cpu")
|
||||||
|
device_npu = torch.device("npu")
|
||||||
|
is_pin_memory = False
|
||||||
|
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
|
||||||
|
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
|
||||||
|
mock_scheduler_config.decode_max_num_seqs = 0
|
||||||
|
mock_scheduler_config.max_num_seqs = 128
|
||||||
|
mock_vllm_config.scheduler_config = mock_scheduler_config
|
||||||
|
# torch.zeros/torch.empty returns error on online ut machine, so mock it
|
||||||
|
mock_tensor = torch.zeros((256, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
pin_memory=False)
|
||||||
|
mocker.patch("torch.zeros", return_value=mock_tensor)
|
||||||
|
mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
|
||||||
|
mocker.patch("torch.empty", return_value=mock_empty_tensor)
|
||||||
|
|
||||||
|
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
|
||||||
|
is_pin_memory)
|
||||||
|
|
||||||
|
assert processor_cpu.min_p is not None
|
||||||
|
assert processor_cpu.use_double_tensor is False
|
||||||
|
assert processor_cpu.min_p_cpu.shape[0] == 256
|
||||||
|
|
||||||
|
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
|
||||||
|
is_pin_memory)
|
||||||
|
|
||||||
|
assert processor_cpu.min_p is not None
|
||||||
|
assert processor_cpu.use_double_tensor is True
|
||||||
|
assert processor_cpu.min_p_cpu.shape[0] == 256
|
||||||
@@ -28,6 +28,8 @@ class AscendSchedulerConfig(SchedulerConfig):
|
|||||||
num_scheduler_steps: int = 1
|
num_scheduler_steps: int = 1
|
||||||
scheduler_cls: Union[str, Type[object]] = (
|
scheduler_cls: Union[str, Type[object]] = (
|
||||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
|
enable_pd_transfer: bool = False
|
||||||
|
decode_max_num_seqs: int = 0
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def initialize_from_config(
|
def initialize_from_config(
|
||||||
@@ -45,6 +47,8 @@ class AscendSchedulerConfig(SchedulerConfig):
|
|||||||
scheduler_config["num_scheduler_steps"] = 1
|
scheduler_config["num_scheduler_steps"] = 1
|
||||||
scheduler_config["scheduler_cls"] = (
|
scheduler_config["scheduler_cls"] = (
|
||||||
"vllm_ascend.core.scheduler.AscendScheduler")
|
"vllm_ascend.core.scheduler.AscendScheduler")
|
||||||
|
scheduler_config["enable_pd_transfer"] = False
|
||||||
|
scheduler_config["decode_max_num_seqs"] = 0
|
||||||
# Override params in original SchedulerConfig with params in ascend_scheduler_config
|
# Override params in original SchedulerConfig with params in ascend_scheduler_config
|
||||||
for k, _ in scheduler_config.items():
|
for k, _ in scheduler_config.items():
|
||||||
if hasattr(ascend_scheduler_config, k):
|
if hasattr(ascend_scheduler_config, k):
|
||||||
|
|||||||
@@ -52,6 +52,15 @@ class AscendScheduler(Scheduler):
|
|||||||
self.scheduled_req_ids: set[str] = set()
|
self.scheduled_req_ids: set[str] = set()
|
||||||
self.running: list[Request] = []
|
self.running: list[Request] = []
|
||||||
|
|
||||||
|
self.finished_prefill_reqs: deque[Request] = deque()
|
||||||
|
enable_pd_transfer = getattr(self.scheduler_config,
|
||||||
|
'enable_pd_transfer', False)
|
||||||
|
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||||
|
'decode_max_num_seqs', 0)
|
||||||
|
self.phase = "" if not enable_pd_transfer else "prefill"
|
||||||
|
self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
|
||||||
|
decode_max_num_seqs)
|
||||||
|
|
||||||
def schedule(self) -> SchedulerOutput:
|
def schedule(self) -> SchedulerOutput:
|
||||||
if self.scheduler_config.chunked_prefill_enabled:
|
if self.scheduler_config.chunked_prefill_enabled:
|
||||||
return super().schedule()
|
return super().schedule()
|
||||||
@@ -76,9 +85,25 @@ class AscendScheduler(Scheduler):
|
|||||||
# and put back at the head of the waiting queue later
|
# and put back at the head of the waiting queue later
|
||||||
skipped_waiting_requests: deque[Request] = deque()
|
skipped_waiting_requests: deque[Request] = deque()
|
||||||
|
|
||||||
|
if self.phase == "prefill":
|
||||||
|
remaining_running_reqs = []
|
||||||
|
for request in self.running:
|
||||||
|
# move request has finished prefill to finished_prefill_reqs
|
||||||
|
if request.num_tokens > request.num_prompt_tokens:
|
||||||
|
self.finished_prefill_reqs.append(request)
|
||||||
|
else:
|
||||||
|
remaining_running_reqs.append(request)
|
||||||
|
self.running = remaining_running_reqs
|
||||||
|
# all request prefilled, change phase to decode
|
||||||
|
if not self.waiting and not self.running:
|
||||||
|
self.phase = "decode"
|
||||||
|
|
||||||
# Schedule prefill requests first.
|
# Schedule prefill requests first.
|
||||||
while self.waiting and token_budget > 0:
|
while self.waiting and token_budget > 0:
|
||||||
if len(self.running) == self.max_num_running_reqs:
|
if len(self.running) == (self.decode_max_num_running_reqs
|
||||||
|
if self.phase == "decode" else
|
||||||
|
self.max_num_running_reqs):
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
request = self.waiting[0]
|
request = self.waiting[0]
|
||||||
@@ -235,6 +260,13 @@ class AscendScheduler(Scheduler):
|
|||||||
if skipped_waiting_requests:
|
if skipped_waiting_requests:
|
||||||
self.waiting.extendleft(skipped_waiting_requests)
|
self.waiting.extendleft(skipped_waiting_requests)
|
||||||
|
|
||||||
|
if self.phase == "decode":
|
||||||
|
while len(
|
||||||
|
self.running
|
||||||
|
) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
|
||||||
|
request = self.finished_prefill_reqs.popleft()
|
||||||
|
self.running.append(request)
|
||||||
|
|
||||||
# If no prefill requests are scheduled,
|
# If no prefill requests are scheduled,
|
||||||
# Schedule decode requests next.
|
# Schedule decode requests next.
|
||||||
if len(self.scheduled_req_ids) == 0:
|
if len(self.scheduled_req_ids) == 0:
|
||||||
@@ -334,7 +366,9 @@ class AscendScheduler(Scheduler):
|
|||||||
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
|
||||||
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
|
||||||
assert token_budget >= 0
|
assert token_budget >= 0
|
||||||
assert len(self.running) <= self.max_num_running_reqs
|
assert len(
|
||||||
|
self.running
|
||||||
|
) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
|
||||||
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
|
||||||
scheduled_running_reqs) <= len(self.running)
|
scheduled_running_reqs) <= len(self.running)
|
||||||
|
|
||||||
|
|||||||
50
vllm_ascend/sample/logits_processor/__init__.py
Normal file
50
vllm_ascend/sample/logits_processor/__init__.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
import itertools
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.v1.sample import logits_processor
|
||||||
|
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
|
||||||
|
MinTokensLogitsProcessor)
|
||||||
|
from vllm.v1.sample.logits_processor.interface import LogitsProcessor
|
||||||
|
from vllm.v1.sample.logits_processor.state import LogitsProcessors
|
||||||
|
|
||||||
|
from vllm_ascend.sample.logits_processor.builtin import \
|
||||||
|
AscendMinPLogitsProcessor
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# Error message when the user tries to initialize vLLM with a pooling model
|
||||||
|
# and custom logitsproces
|
||||||
|
STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom"
|
||||||
|
" logits processors.")
|
||||||
|
|
||||||
|
BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
|
||||||
|
MinTokensLogitsProcessor,
|
||||||
|
LogitBiasLogitsProcessor,
|
||||||
|
AscendMinPLogitsProcessor,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_logitsprocs(
|
||||||
|
vllm_config: "VllmConfig",
|
||||||
|
device: torch.device,
|
||||||
|
is_pin_memory: bool,
|
||||||
|
is_pooling_model: bool,
|
||||||
|
custom_logitsprocs: Sequence[Union[str, type[LogitsProcessor]]] = (),
|
||||||
|
) -> LogitsProcessors:
|
||||||
|
if is_pooling_model:
|
||||||
|
if custom_logitsprocs:
|
||||||
|
raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS)
|
||||||
|
logger.debug("Skipping logits processor loading because pooling models"
|
||||||
|
" do not support logits processors.")
|
||||||
|
return LogitsProcessors()
|
||||||
|
custom_logitsprocs_classes = logits_processor._load_custom_logitsprocs(
|
||||||
|
custom_logitsprocs)
|
||||||
|
return LogitsProcessors(
|
||||||
|
ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain(
|
||||||
|
BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes))
|
||||||
35
vllm_ascend/sample/logits_processor/builtin.py
Normal file
35
vllm_ascend/sample/logits_processor/builtin.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
import torch
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.v1.sample.logits_processor import MinPLogitsProcessor
|
||||||
|
|
||||||
|
|
||||||
|
class AscendMinPLogitsProcessor(MinPLogitsProcessor):
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
|
||||||
|
is_pin_memory: bool):
|
||||||
|
super().__init__(vllm_config, device, is_pin_memory)
|
||||||
|
|
||||||
|
decode_max_num_seqs = getattr(vllm_config.scheduler_config,
|
||||||
|
'decode_max_num_seqs', 0)
|
||||||
|
if decode_max_num_seqs != 0:
|
||||||
|
max_num_reqs = max(vllm_config.scheduler_config.max_num_seqs,
|
||||||
|
decode_max_num_seqs)
|
||||||
|
|
||||||
|
self.min_p_count: int = 0
|
||||||
|
|
||||||
|
self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device="cpu",
|
||||||
|
pin_memory=is_pin_memory)
|
||||||
|
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||||
|
|
||||||
|
self.use_double_tensor = torch.device(device).type != "cpu"
|
||||||
|
|
||||||
|
if self.use_double_tensor:
|
||||||
|
# Pre-allocated device tensor
|
||||||
|
self.min_p_device: torch.Tensor = torch.empty(
|
||||||
|
(max_num_reqs, ), dtype=torch.float32, device=device)
|
||||||
|
else:
|
||||||
|
self.min_p_device = self.min_p_cpu_tensor
|
||||||
|
# Current slice of the device tensor
|
||||||
|
self.min_p: torch.Tensor = self.min_p_device[:0]
|
||||||
@@ -66,7 +66,6 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
|||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
||||||
LogprobsTensors, ModelRunnerOutput)
|
LogprobsTensors, ModelRunnerOutput)
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.logits_processor import build_logitsprocs
|
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
@@ -86,6 +85,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
|||||||
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
||||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
|
from vllm_ascend.sample.logits_processor import build_logitsprocs
|
||||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
@@ -173,7 +173,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
self.max_num_blocks_per_req = cdiv(self.model_config.max_model_len,
|
||||||
self.block_size)
|
self.block_size)
|
||||||
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
|
||||||
self.max_num_reqs = self.scheduler_config.max_num_seqs
|
decode_max_num_seqs = getattr(self.scheduler_config,
|
||||||
|
'decode_max_num_seqs', 0)
|
||||||
|
self.max_num_reqs = max(self.scheduler_config.max_num_seqs,
|
||||||
|
decode_max_num_seqs)
|
||||||
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
self.dp_size = vllm_config.parallel_config.data_parallel_size
|
||||||
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|||||||
Reference in New Issue
Block a user