[main] add pd transfer for ascend scheduler (#2753)
### What this PR does / why we need it?
For offline scenarios, adjust the scheduling process to prioritize the
prefill phase of all requests, then process the decode phase of all
requests.
### How was this patch tested?
```
max_num_seqs=24,
additional_config={
"ascend_scheduler_config":{
"enabled": True,
"enable_pd_transfer": True,
"decode_max_num_seqs": 24,
"enable_chunked_prefill": False
}
},
```
| input | output | num prompts | max_num_seqs | dp | tp | scheduler |
tps |
| ------ | ------ | ---------- | ---------------- | ---- | ---- |
---------------- | --------------- |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | v1 | 234.06 |
| dapo-math-17K | 2K | 384 | 24 | 2 | 1 | pd transfer | 239.59(+2.4%) |
| dapo-math-17K| 2K | 384 | 24 | 4 | 1 | v1 | 222.85 |
| dapo-math-17K| 2K | 384 | 24 | 4 | 1 | pd transfer | 225.81(+1.3%) |
- vLLM version: v0.10.1.1
- vLLM main:
6fb2788163
---------
Signed-off-by: CaranLic <740821011@qq.com>
This commit is contained in:
@@ -165,3 +165,16 @@ class TestAscendSchedulerConfig(TestBase):
|
||||
)
|
||||
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
|
||||
self.assertIn("max_model_len (4096)", str(context.exception))
|
||||
|
||||
def test_initialize_from_config_with_pd_transfer(self):
|
||||
ascend_config = AscendSchedulerConfig.initialize_from_config(
|
||||
self.basic_scheduler_config,
|
||||
AscendSchedulerConfig(
|
||||
enable_pd_transfer=True,
|
||||
decode_max_num_seqs=48,
|
||||
max_num_batched_tokens=4096,
|
||||
max_model_len=4096,
|
||||
),
|
||||
)
|
||||
self.assertEqual(ascend_config.enable_pd_transfer, True)
|
||||
self.assertEqual(ascend_config.decode_max_num_seqs, 48)
|
||||
|
||||
@@ -705,3 +705,34 @@ class TestAscendScheduler(TestBase):
|
||||
|
||||
# Confirm no memory leak.
|
||||
self.assert_scheduler_empty(scheduler)
|
||||
|
||||
def test_scheduler_with_pd_transfer(self):
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.phase = "prefill"
|
||||
requests = create_requests(num_requests=32)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# 1st iteration, move 16 requests from waiting to running for prefill
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
first_iter_prefilled_req_num = len(scheduler.running)
|
||||
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
|
||||
scheduler.max_num_running_reqs)
|
||||
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(scheduler_output.finished_req_ids), 0)
|
||||
|
||||
# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
|
||||
# and move 16 requests from waiting to running for prefill
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
self.assertEqual(len(scheduler.finished_prefill_reqs),
|
||||
first_iter_prefilled_req_num)
|
||||
|
||||
# 3rd iteration, all requests prefilled, change scheduler phase to decode
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
self.assertEqual(scheduler.phase, "decode")
|
||||
|
||||
40
tests/ut/sample/logits_processor/test_builtin.py
Normal file
40
tests/ut/sample/logits_processor/test_builtin.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.config import SchedulerConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.sample.logits_processor import AscendMinPLogitsProcessor
|
||||
|
||||
|
||||
class TestMinPLogitsProcessorInitFunc(PytestBase):
|
||||
|
||||
def test_init_func_with_decode_max_num_seqs(self, mocker: MockerFixture):
|
||||
device_cpu = torch.device("cpu")
|
||||
device_npu = torch.device("npu")
|
||||
is_pin_memory = False
|
||||
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
|
||||
mock_scheduler_config = mocker.MagicMock(spec=SchedulerConfig)
|
||||
mock_scheduler_config.decode_max_num_seqs = 0
|
||||
mock_scheduler_config.max_num_seqs = 128
|
||||
mock_vllm_config.scheduler_config = mock_scheduler_config
|
||||
# torch.zeros/torch.empty returns error on online ut machine, so mock it
|
||||
mock_tensor = torch.zeros((256, ),
|
||||
dtype=torch.float32,
|
||||
pin_memory=False)
|
||||
mocker.patch("torch.zeros", return_value=mock_tensor)
|
||||
mock_empty_tensor = torch.empty((256, ), dtype=torch.float32)
|
||||
mocker.patch("torch.empty", return_value=mock_empty_tensor)
|
||||
|
||||
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_cpu,
|
||||
is_pin_memory)
|
||||
|
||||
assert processor_cpu.min_p is not None
|
||||
assert processor_cpu.use_double_tensor is False
|
||||
assert processor_cpu.min_p_cpu.shape[0] == 256
|
||||
|
||||
processor_cpu = AscendMinPLogitsProcessor(mock_vllm_config, device_npu,
|
||||
is_pin_memory)
|
||||
|
||||
assert processor_cpu.min_p is not None
|
||||
assert processor_cpu.use_double_tensor is True
|
||||
assert processor_cpu.min_p_cpu.shape[0] == 256
|
||||
Reference in New Issue
Block a user