[Feat] Dynamic Batch Feature (#3490)

[RFC](https://github.com/vllm-project/vllm-ascend/issues/3328) for more
details.
Add dynamic batch feature in chunked prefilling strategy, the token
budget can be refined to achieve better effective throughput and TPOT.

!!! NOTE: only 910B3 is supported till now, we are working on further
improvements.
Additional file for lookup table is required.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Cheng Wang <wangchengkyrie@outlook.com>
This commit is contained in:
KyrieWang
2025-10-22 14:13:32 +08:00
committed by GitHub
parent c18ca62a17
commit 60e2be1b36
10 changed files with 1368 additions and 1 deletions

View File

@@ -0,0 +1,51 @@
# Dynamic Batch
Dynamic batch is a technique that dynamically adjusts the chunksize during each inference iteration within the chunked prefilling strategy according to the resources and SLO targets, thereby improving the effective throughput and decreasing the TBT.
Dynamic batch is controlled by the value of the `--SLO_limits_for_dynamic_batch`.
Notably, only 910 B3 is supported with decode token numbers scales below 2048 so far.
Especially, the improvements are quite obvious on Qwen, Llama models.
We are working on further improvements and this feature will support more XPUs in the future.
## Getting started
### Prerequisites
1. Dynamic batch now depends on a offline cost model saved in a look-up table to refine the token budget. The lookup-table is saved in '.csv' file, which should be first downloaded from [here](https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/dynamic_batch_scheduler/A2-B3-BLK128.csv), renamed, and saved to the path `vllm_ascend/core/profile_table.csv`
2. `Pandas` is needed to load the look-up table, in case `pandas` is not installed.
```bash
pip install pandas
```
### Tuning Parameter
`--SLO_limits_for_dynamic_batch` is the tuning parameters (integer type) for the dynamic batch feature, greater values impose more constraints on the latency limitation, leading to higher effective throughput. The parameter can be selected according to the specific models or service requirements.
```python
--SLO_limits_for_dynamic_batch =-1 # default value, dynamic batch disabled.
--SLO_limits_for_dynamic_batch = 0 # baseline value for dynamic batch, dynamic batch disabled, FCFS and decode-first chunked prefilling strategy is used.
--SLO_limits_for_dynamic_batch > 0 # user-defined value for dynamic batch, dynamic batch enabled with FCFS and decode-first chunked prefilling strategy.
```
### Supported Models
So far, dynamic batch performs better on several dense models including Qwen and Llama (from 8B to 32B) with `tensor_parallel_size=8`. For different models, a proper `SLO_limits_for_dynamic_batch` parameter is needed. The empirical value of this parameter is generally `35, 50, or 75`. Therefore, some additional tests are needed to select the best parameter.
## Usage
Dynamic batch is used in the online inference. A fully executable example is as follows:
```shell
SLO_LITMIT=50
vllm serve Qwen/Qwen2.5-14B-Instruct\
--additional_config '{"SLO_limits_for_dynamic_batch":'${SLO_LITMIT}'}' \
--max-num-seqs 256 \
--block-size 128 \
--tensor_parallel_size 8 \
--load_format dummy \
--max_num_batched_tokens 1024 \
--max_model_len 9000 \
--host localhost \
--port 12091 \
--gpu-memory-utilization 0.9 \
--trust-remote-code
```

View File

@@ -11,4 +11,5 @@ sleep_mode
structured_output
lora
eplb_swift_balancer
dynamic_batch
:::

View File

@@ -10,6 +10,8 @@ requires = [
"pybind11",
"pyyaml",
"scipy",
"pandas",
"pandas-stubs",
"setuptools>=64",
"setuptools-scm>=8",
"torch-npu==2.7.1.dev20250724",

View File

@@ -8,11 +8,13 @@ pip
pybind11
pyyaml
scipy
pandas
setuptools>=64
setuptools-scm>=8
torch>=2.7.1
torchvision
wheel
pandas-stubs
# requirements for disaggregated prefill
msgpack

View File

@@ -91,6 +91,43 @@ def test_chunked_prefill_with_ascend_scheduler(
)
@pytest.mark.parametrize("max_tokens",
[4]) # cannot align results when max_tokens > 4
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
def test_chunked_prefill_with_scheduler_dynamic_batch(
max_tokens: int, chunked_prefill_token_size: int) -> None:
example_prompts = [
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
]
max_num_seqs = chunked_prefill_token_size
max_num_batched_tokens = chunked_prefill_token_size
with VllmRunner(MODEL,
additional_config={
'SLO_limits_for_dynamic_batch': 0,
},
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=2048,
gpu_memory_utilization=0.7) as vllm_model:
dynamic_batch_output = vllm_model.generate_greedy(
example_prompts, max_tokens)
with VllmRunner(MODEL,
additional_config={
'SLO_limits_for_dynamic_batch': -1,
},
max_model_len=2048,
gpu_memory_utilization=0.7) as vllm_model:
vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=vllm_output,
outputs_1_lst=dynamic_batch_output,
name_0="vllm_output",
name_1="chunked_prefill_output",
)
def test_async_scheduling() -> None:
prompts = [
"Hello, my name is",

View File

@@ -21,6 +21,7 @@ from vllm.v1.structured_output import StructuredOutputManager
from tests.ut.base import TestBase
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.core.scheduler_dynamic_batch import SchedulerDynamicBatch
EOS_TOKEN_ID = 50256
MODEL = "Qwen3-0.6B"
@@ -805,3 +806,665 @@ class TestAscendScheduler(TestBase):
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(scheduler.phase, "decode")
class TestSchedulerDynamicBatch(TestBase):
@patch("vllm.config.ModelConfig.__post_init__", MagicMock())
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget):
mock_compute_encoder_budget.return_value = [100, 100]
use_kv_connector = False
block_size = 16
scheduler_config = SchedulerConfig(
max_num_seqs=16,
max_model_len=MAX_NUM_BATCHED_TOKENS,
long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
disable_chunked_mm_input=False,
enable_chunked_prefill=True,
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
)
scheduler_config.max_num_encoder_input_tokens = 10000
scheduler_config.encoder_cache_size = 10000
scheduler_config.chunked_prefill_enabled = True
scheduler_config.SLO_limits_for_dynamic_batch = 0
model_config = ModelConfig(
model=MODEL,
task="auto",
tokenizer=MODEL,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
max_model_len=MAX_NUM_BATCHED_TOKENS,
)
model_config.pooler_config = MagicMock()
model_config.multimodal_config = MagicMock()
model_config.hf_config = MagicMock()
model_config.hf_config.is_encoder_decoder = False
# Cache config, optionally force APC
kwargs_cache: Dict[str,
Any] = ({} if ENABLE_PREFIX_CACHING is None else {
'enable_prefix_caching':
ENABLE_PREFIX_CACHING
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if NUM_SPECULATIVE_TOKENS is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1,
torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
scheduler = SchedulerDynamicBatch(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=MagicMock(spec=StructuredOutputManager),
)
should_advance = MagicMock()
should_advance.return_value = False
scheduler.structured_output_manager.should_advance = should_advance
return scheduler
def test_add_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
self.assertIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), i + 1)
def test_finish_request(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
self.assertNotIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), 9 - i)
def test_get_num_unfinished_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
self.assertEqual(scheduler.get_num_unfinished_requests(),
len(requests) - i - 1)
def test_schedule(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = True
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_multimodal_requests(self):
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = True
mm_positions = [[PlaceholderRange(offset=i, length=10)]
for i in range(10)]
requests = create_requests(
num_requests=10,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
assert len(encoder_input) == 1
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_enable_prefix_caching(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
global ENABLE_PREFIX_CACHING
ENABLE_PREFIX_CACHING = True
global PROMPT_LOGPROBS
PROMPT_LOGPROBS = 5
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_stop_via_update_from_output(self):
"""Test stopping behavior through update_from_output"""
global NUM_SPECULATIVE_TOKENS
NUM_SPECULATIVE_TOKENS = 1
scheduler = self.create_scheduler()
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID])
self.assertEqual(list(requests[1].output_token_ids), [10, 11])
# Test case 2: Stop on custom stop token
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id:
[10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertEqual(requests[0].stop_reason, 42)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 42])
self.assertEqual(list(requests[1].output_token_ids), [13, 14])
# Test case 3: Stop on max tokens
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id:
[10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status,
RequestStatus.FINISHED_LENGTH_CAPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 11])
self.assertEqual(list(requests[1].output_token_ids), [13])
# Test case 4: Ignore EOS flag
scheduler = self.create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
self.assertEqual(len(scheduler.running), 1)
self.assertFalse(requests[0].is_finished())
self.assertEqual(list(requests[0].output_token_ids),
[EOS_TOKEN_ID, 10, 11])
def test_schedule_concurrent_batches(self):
global MAX_NUM_BATCHED_TOKENS
global ENABLE_PREFIX_CACHING
global ENABLE_CHUNKED_PREFILL
global MAX_NUM_SEQS
global PROMPT_LOGPROBS
ENABLE_PREFIX_CACHING = None
MAX_NUM_BATCHED_TOKENS = 1024
MAX_NUM_SEQS = 2
ENABLE_CHUNKED_PREFILL = True
PROMPT_LOGPROBS = None
enable_prefix_caching_list = [None, True]
prompt_logprobs_list = [None, 5]
for i in range(len(enable_prefix_caching_list)):
ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i]
PROMPT_LOGPROBS = prompt_logprobs_list[i]
scheduler = self.create_scheduler()
requests = create_requests(
num_requests=2,
num_tokens=512,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output0.num_scheduled_tokens[requests[0].request_id],
512)
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output1.num_scheduled_tokens[requests[1].request_id],
512)
# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output0,
model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler.schedule()
# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output1,
model_runner_output)
def test_schedule_spec_decoding_stats(self):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
[[1, 2], [3]], [[1]], [[]],
[[1, 2, 3], [4, 5, 6]]]
output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
[[1, 2, 5], [3, 4]],
[[1, 2]], [[5]],
[[1, 2, 7], [4, 8]]]
expected_list: List[Tuple[int, int,
int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
(1, 3, 1, [1, 0, 0]),
(2, 3, 3, [2, 1]),
(1, 1, 1, [1]),
(0, 0, 0, [0]),
(2, 6, 3, [2, 1, 0])]
global NUM_SPECULATIVE_TOKENS
for idx in range(len(spec_tokens_list)):
spec_tokens = spec_tokens_list[idx]
output_tokens = output_tokens_list[idx]
expected = expected_list[idx]
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
NUM_SPECULATIVE_TOKENS = num_spec_tokens
scheduler = self.create_scheduler()
requests = create_requests(num_requests=len(spec_tokens),
num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.total_num_scheduled_tokens, len(requests))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
self.assertEqual(running_req.num_computed_tokens, 1)
# The prompt token and the sampled token
self.assertEqual(running_req.num_tokens, 2)
# The prompt token, the sampled token, and the speculated tokens
self.assertEqual(running_req.num_tokens_with_spec,
2 + len(spec_tokens[i]))
# No draft or accepted tokens counted yet
self.assertTrue(
not engine_core_outputs
or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
is None))
# Schedule the speculated tokens for validation
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), 0)
# The sampled token and speculated tokens
self.assertEqual(
output.total_num_scheduled_tokens,
len(requests) + sum(len(ids) for ids in spec_tokens))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id],
1 + len(spec_tokens[i]))
if spec_tokens[i]:
self.assertEqual(
len(output.scheduled_spec_decode_tokens[req_id]),
len(spec_tokens[i]))
else:
self.assertNotIn(req_id,
output.scheduled_spec_decode_tokens)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
self.assertIsNone(scheduler_stats.spec_decoding_stats)
else:
self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
stats = scheduler_stats.spec_decoding_stats
self.assertEqual(stats.num_drafts, expected[0])
self.assertEqual(stats.num_draft_tokens, expected[1])
self.assertEqual(stats.num_accepted_tokens, expected[2])
self.assertEqual(stats.num_accepted_tokens_per_pos,
expected[3])
def assert_scheduler_empty(self, scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
scheduler = self.create_scheduler()
self.assertEqual(len(scheduler.requests), 0)
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), 0)
self.assertEqual(len(scheduler.finished_req_ids), 0)
# EncoderCacheManager.
self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
# KVCache Manager.
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks), 0)
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block), 0)
num_free_blocks = (scheduler.kv_cache_manager.block_pool.
free_block_queue.num_free_blocks)
self.assertEqual(
num_free_blocks,
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
self.assertEqual(block.ref_cnt, 0)
def test_memory_leak(self):
"""Test that we do not have a memory leak."""
scheduler = self.create_scheduler()
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)

View File

@@ -129,6 +129,8 @@ class AscendConfig:
if self.pd_tp_ratio == 0:
raise AssertionError(
"Only support P node tp size lagger then D node tp size")
self.SLO_limits_for_dynamic_batch = additional_config.get(
"SLO_limits_for_dynamic_batch", -1)
class TorchairGraphConfig:

View File

@@ -105,4 +105,4 @@ class AscendSchedulerConfig(SchedulerConfig):
if getattr(self, "scheduler_delay_factor", 0) > 0:
raise NotImplementedError(
"currently AscendScheduler doesn't support scheduler_delay_factor."
)
)

View File

@@ -0,0 +1,601 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
import time
import pandas as pd
from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
create_request_queue)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreEventType
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
class BudgetRefiner:
"""This budget refiner can make dynamic adjustment to the token budget
in the chunked prefill scheduling strategy."""
def __init__(self, default_budget, slo_limit=-1) -> None:
self.enabled = slo_limit > 0
if not self.enabled:
return
logger.info(
"Dynamic batch is enabled with SLO limit: {}, and chunked prefill is forced to be activated because dynamic batch relies on it"
.format(str(slo_limit)))
self.lookup: dict[tuple[int, int], int] = {}
self.context_keys: set[int] = set()
self.dnum_keys: set[int] = set()
self.default_budget = default_budget
self._read_lookup_table(slo_limit)
def _read_lookup_table(self, slo_limit):
"""Load the lookup table for dynamic budget."""
base_dir = os.path.dirname(os.path.abspath(__file__))
table_file_path = os.path.join(base_dir, "profile_table.csv")
if not os.path.exists(table_file_path):
# proceed without dynamic batch
logger.error(
"The dynamic batching feature requires the lookup table "
"'profile_table.csv', but it was not found at '%s'. "
"Please download the corresponding table file.",
table_file_path)
self.enabled = False
return
else:
df = pd.read_csv(table_file_path)
grouped = df.groupby(['ctx_len', 'd_num'])
for (ctx_len, d_num), group in grouped:
valid = group[group['cost'] <= slo_limit]
if not valid.empty:
max_row = valid.loc[valid['chunk_size'].idxmax()]
self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size'])
self.context_keys.add(ctx_len)
self.dnum_keys.add(d_num)
self.context_keys = set(sorted(self.context_keys))
self.dnum_keys = set(sorted(self.dnum_keys))
def _align_key(self, value, valid_keys):
"""Align the minimum value within the valid_keys that is greater than the value."""
for k in valid_keys:
if k >= value:
return k
return None
def _get_max_budget(self, num_deocde_tokens, num_decode):
"""Get the maximum budget according to the number of decoding tokens and the decoding requests."""
aligned_ctx = self._align_key(num_deocde_tokens, self.context_keys)
aligned_dnum = self._align_key(num_decode, self.dnum_keys)
if aligned_ctx is None or aligned_dnum is None:
return self.default_budget
budget = self.lookup.get((aligned_ctx, aligned_dnum), None)
if budget is None:
logger.warn(f"Table miss for ctx,dnum{aligned_ctx, aligned_dnum}")
budget = self.default_budget
# For debug.
# logger.info(f"budget {budget}, ctx,dnum {aligned_ctx, aligned_dnum}, raw ctx,dnum {num_deocde_tokens, num_decode}")
return budget
def refine_budget(self, running_request, budget):
"""Dynamically refine the token budget according to the running request."""
if not self.enabled:
return budget
# assume all running request will be scheduled.
num_decode_token_lst = [
req.num_tokens_with_spec \
for req in running_request \
if req.num_computed_tokens >= req.num_prompt_tokens ]
num_decode = len(num_decode_token_lst)
if num_decode <= 0:
return budget
num_deocde_tokens = sum(num_decode_token_lst) / num_decode
return self._get_max_budget(num_deocde_tokens, num_decode)
class SchedulerDynamicBatch(Scheduler):
"""This Scheduler extends vllm's original v1 scheduler
with dynamic batch."""
def __init__(
self,
vllm_config: VllmConfig,
kv_cache_config: KVCacheConfig,
structured_output_manager: StructuredOutputManager,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
super().__init__(vllm_config, kv_cache_config,
structured_output_manager, mm_registry,
include_finished_set, log_stats)
self.running: list[Request] = []
self.budget_refiner = BudgetRefiner(
default_budget=self.scheduler_config.max_num_batched_tokens,
slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch)
def schedule(self) -> SchedulerOutput:
# NOTE: This scheduling algorithm is developed based on the "super.schedule()"
# with the implementations of the dynamic batch and some modifications:
# 1. Token budget can be dynamically refined according to the self.running
# through the BudgetRefiner;
# 2. This scheduling algorithm follows decode-first chunked prefills and FCFS
# strategy, which is slightly different to the "super.schedule()"
# 3. Similar to the "super.schedule()", at each step, the scheduler tries to
# assign tokens to the requests so that each request's num_computed_tokens can
# catch up its num_tokens_with_spec.
# 4. So far, the dynamic batch only supports 910B3 NPU. Further work will include
# more devices and finer optimization strategy.
scheduled_new_reqs: list[Request] = []
scheduled_resumed_reqs: list[Request] = []
scheduled_running_reqs: list[Request] = []
preempted_reqs: list[Request] = []
req_to_new_blocks: dict[str, KVCacheBlocks] = {}
num_scheduled_tokens: dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
token_budget = self.budget_refiner.refine_budget(
self.running, token_budget)
# NOTE: We move the prefill requests to the end of the self.running
# list and keep the relative order unchanged. This rearrangement makes this
# scheduling algorithm a strict decode-first chunked prefills.
d_lst = [
req for req in self.running
if req.num_computed_tokens >= req.num_prompt_tokens
]
p_lst = [
req for req in self.running
if req.num_computed_tokens < req.num_prompt_tokens
]
self.running = d_lst + p_lst
# Encoder-related.
scheduled_encoder_inputs: dict[str, list[int]] = {}
encoder_compute_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# First, schedule the RUNNING requests.
req_index = 0
while req_index < len(self.running) and token_budget > 0:
request = self.running[req_index]
num_new_tokens = (request.num_tokens_with_spec +
request.num_output_placeholders -
request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget)
# Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - 1 - request.num_computed_tokens)
# Schedule encoder inputs.
encoder_inputs_to_schedule = None
new_encoder_compute_budget = encoder_compute_budget
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, request.num_computed_tokens, num_new_tokens,
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reasons:
# 1. No new tokens to schedule. This may happen when
# (1) PP>1 and we have already scheduled all prompt tokens
# but they are not finished yet.
# (2) Async scheduling and the request has reached to either
# its max_total_tokens or max_model_len.
# 2. The encoder budget is exhausted.
# 3. The encoder cache is exhausted.
# NOTE(woosuk): Here, by doing `break` instead of `continue` as
# in v1 scheduler, we strictly follow the FCFS scheduling policy.
req_index += 1
break
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None:
# The request cannot be scheduled.
# Preempt the lowest-priority request.
if self.policy == SchedulingPolicy.PRIORITY:
preempted_req = max(
self.running,
key=lambda r: (r.priority, r.arrival_time),
)
self.running.remove(preempted_req)
if preempted_req in scheduled_running_reqs:
scheduled_running_reqs.remove(preempted_req)
else:
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
self.encoder_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
self.waiting.prepend_request(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
# The request can be scheduled.
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
# Schedule the request.
scheduled_running_reqs.append(request)
req_to_new_blocks[request.request_id] = new_blocks
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
req_index += 1
# Speculative decode related.
if request.spec_token_ids:
num_scheduled_spec_tokens = (num_new_tokens +
request.num_computed_tokens -
request.num_tokens)
if num_scheduled_spec_tokens > 0:
# Trim spec_token_ids list to num_scheduled_spec_tokens.
del request.spec_token_ids[num_scheduled_spec_tokens:]
scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids)
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Record the LoRAs in scheduled_running_reqs
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in scheduled_running_reqs
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# Use a temporary RequestQueue to collect requests that need to be
# skipped and put back at the head of the waiting queue later
skipped_waiting_requests = create_request_queue(self.policy)
# Next, schedule the WAITING requests.
if not preempted_reqs:
while self.waiting and token_budget > 0:
if len(self.running) == self.max_num_running_reqs:
break
request = self.waiting.peek_request()
# KVTransfer: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
else:
logger.debug(
"%s is still in WAITING_FOR_REMOTE_KVS state.",
request.request_id)
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Skip request if the structured output request is still waiting
# for FSM compilation.
if request.status == RequestStatus.WAITING_FOR_FSM:
structured_output_req = request.structured_output_request
if structured_output_req and structured_output_req.grammar:
request.status = RequestStatus.WAITING
else:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Check that adding the request still respects the max_loras
# constraint.
if (self.lora_config and request.lora_request and
(len(scheduled_loras) == self.lora_config.max_loras and
request.lora_request.lora_int_id not in scheduled_loras)):
# Scheduling would exceed max_loras, skip.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if request.num_computed_tokens == 0:
# Get locally-cached tokens.
new_computed_blocks, num_new_local_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_new_local_computed_tokens))
if num_external_computed_tokens is None:
# The request cannot be scheduled because
# the KVConnector couldn't determine
# the number of matched tokens.
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
# Total computed tokens (local + external).
num_computed_tokens = (num_new_local_computed_tokens +
num_external_computed_tokens)
# KVTransfer: WAITING reqs have num_computed_tokens > 0
# after async KV recvs are completed.
else:
new_computed_blocks = (
self.kv_cache_manager.create_empty_block_list())
num_new_local_computed_tokens = 0
num_computed_tokens = request.num_computed_tokens
encoder_inputs_to_schedule = None
new_encoder_compute_budget = encoder_compute_budget
# KVTransfer: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
# Number of tokens to be scheduled.
else:
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold
< num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
# chunked prefill has to be enabled explicitly to allow
# pooling requests to be chunked
if not self.scheduler_config.chunked_prefill_enabled and \
num_new_tokens > token_budget:
self.waiting.pop_request()
skipped_waiting_requests.prepend_request(request)
continue
num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens > 0
# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_compute_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_compute_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break
# Handles an edge case when P/D Disaggregation
# is used with Spec Decoding where an
# extra block gets allocated which
# creates a mismatch between the number
# of local and remote blocks.
effective_lookahead_tokens = (0 if request.num_computed_tokens
== 0 else
self.num_lookahead_tokens)
# Determine if we need to allocate cross-attention blocks.
if self.is_encoder_decoder and request.has_encoder_inputs:
# TODO(russellb): For Whisper, we know that the input is
# always padded to the maximum length. If we support other
# encoder-decoder models, this will need to be updated if we
# want to only allocate what is needed.
num_encoder_tokens =\
self.scheduler_config.max_num_encoder_input_tokens
else:
num_encoder_tokens = 0
new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=effective_lookahead_tokens,
delay_cache_blocks=load_kv_async,
num_encoder_tokens=num_encoder_tokens,
)
if new_blocks is None:
# The request cannot be scheduled.
break
# KVTransfer: the connector uses this info to determine
# if a load is needed. Note that
# This information is used to determine if a load is
# needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request = self.waiting.pop_request()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.prepend_request(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(
f"Invalid request status: {request.status}")
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
req_to_new_blocks[request.request_id] = (
self.kv_cache_manager.get_blocks(request.request_id))
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
scheduled_encoder_inputs[request.request_id] = (
encoder_inputs_to_schedule)
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_compute_budget = new_encoder_compute_budget
# Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests:
self.waiting.prepend_requests(skipped_waiting_requests)
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
# Since some requests in the RUNNING queue may not be scheduled in
# this step, the total number of scheduled requests can be smaller than
# len(self.running).
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) <= len(self.running))
# Get the longest common prefix among all requests in the running queue.
# This can be potentially used for cascade attention.
num_common_prefix_blocks = [0] * len(
self.kv_cache_config.kv_cache_groups)
if self.running:
any_request = self.running[0]
num_common_prefix_blocks = (
self.kv_cache_manager.get_num_common_prefix_blocks(
any_request, len(self.running)))
# Construct the scheduler output.
new_reqs_data = [
NewRequestData.from_request(
req, req_to_new_blocks[req.request_id].get_block_ids())
for req in scheduled_new_reqs
]
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs,
scheduled_resumed_reqs,
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs +
scheduled_resumed_reqs)
structured_output_request_ids, grammar_bitmask = (
self.get_grammar_bitmask(scheduled_requests,
scheduled_spec_decode_tokens))
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=cached_reqs_data,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
scheduled_encoder_inputs=scheduled_encoder_inputs,
num_common_prefix_blocks=num_common_prefix_blocks,
# finished_req_ids is an existing state in the scheduler,
# instead of being newly scheduled in this step.
# It contains the request IDs that are finished in between
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.
get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
# 1. Plan the KV cache store
# 2. Wrap up all the KV cache load / save ops into an opaque object
# 3. Clear the internal states of the connector
if self.connector is not None:
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
# collect KV cache events from KV cache manager
events = self.kv_cache_manager.take_events()
# collect KV cache events from connector
if self.connector is not None:
connector_events = self.connector.take_events()
if connector_events:
if events is None:
events = list(connector_events)
else:
events.extend(connector_events)
# publish collected KV cache events
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
self._update_after_schedule(scheduler_output)
return scheduler_output

View File

@@ -314,6 +314,14 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config)
vllm_config.scheduler_config = recompute_scheduler_config
# Extend original scheduler_config to use SchedulerDynamicBatch.
if ascend_config.SLO_limits_for_dynamic_batch != -1:
vllm_config.scheduler_config.scheduler_cls = (
"vllm_ascend.core.scheduler_dynamic_batch.SchedulerDynamicBatch"
)
vllm_config.scheduler_config.chunked_prefill_enabled = True
vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch
@classmethod
def get_attn_backend_cls(
cls,