diff --git a/docs/source/user_guide/feature_guide/dynamic_batch.md b/docs/source/user_guide/feature_guide/dynamic_batch.md
new file mode 100644
index 00000000..c1e76354
--- /dev/null
+++ b/docs/source/user_guide/feature_guide/dynamic_batch.md
@@ -0,0 +1,51 @@
+# Dynamic Batch
+
+Dynamic batch is a technique that dynamically adjusts the chunk size at each inference iteration of the chunked prefill strategy, according to the available resources and SLO targets, thereby improving effective throughput and decreasing the time between tokens (TBT).
+
+Dynamic batch is controlled by the value of `SLO_limits_for_dynamic_batch`, passed via `--additional-config`.
+Notably, only the 910B3 NPU is supported so far, and only for decode token counts below 2048.
+The improvements are especially noticeable on Qwen and Llama models.
+We are working on further improvements, and this feature will support more XPUs in the future.
+
+## Getting started
+
+### Prerequisites
+
+1. Dynamic batch depends on an offline cost model, stored as a look-up table, to refine the token budget. The look-up table is a `.csv` file, which must first be downloaded from [here](https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/dynamic_batch_scheduler/A2-B3-BLK128.csv), renamed, and saved to the path `vllm_ascend/core/profile_table.csv`.
+
+2. `pandas` is needed to load the look-up table. Install it if it is not already available:
+
+    ```bash
+    pip install pandas
+    ```
+
+### Tuning Parameter
+`SLO_limits_for_dynamic_batch` is the tuning parameter (an integer) for the dynamic batch feature. Larger values relax the per-iteration latency constraint, allowing larger chunk sizes and therefore higher effective throughput. The parameter should be selected according to the specific model and service requirements.
+
+```
+SLO_limits_for_dynamic_batch = -1  # default value; dynamic batch disabled.
+SLO_limits_for_dynamic_batch = 0   # baseline value; dynamic batch disabled, but the FCFS, decode-first chunked prefill strategy is used.
+SLO_limits_for_dynamic_batch > 0   # user-defined value; dynamic batch enabled with the FCFS, decode-first chunked prefill strategy.
+```
+
+### Supported Models
+So far, dynamic batch performs best on dense models such as Qwen and Llama (from 8B to 32B) with `tensor_parallel_size=8`. Each model needs a suitable `SLO_limits_for_dynamic_batch` value; empirically, `35`, `50`, or `75` work well, so some additional tests are needed to select the best value for a given deployment.
+
+## Usage
+Dynamic batch is used in online inference.
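+
+To see how a chosen SLO limit translates into a per-step token budget, the sketch below mimics the scheduler's look-up logic (see `BudgetRefiner` in `vllm_ascend/core/scheduler_dynamic_batch.py`): for a given average decode context length and number of decode requests, it selects the largest profiled `chunk_size` whose `cost` stays within the SLO limit. The helper name and the example numbers are illustrative only; the column names match the released table.
+
+```python
+from typing import Optional
+
+import pandas as pd
+
+
+def max_chunk_size(table_path: str, slo_limit: int, ctx_len: int,
+                   d_num: int) -> Optional[int]:
+    """Largest chunk_size whose profiled cost fits within the SLO limit."""
+    df = pd.read_csv(table_path)  # columns: ctx_len, d_num, chunk_size, cost
+    # Align to the smallest profiled keys that cover the observed values.
+    ctx = next((k for k in sorted(df["ctx_len"].unique()) if k >= ctx_len),
+               None)
+    dnum = next((k for k in sorted(df["d_num"].unique()) if k >= d_num), None)
+    if ctx is None or dnum is None:
+        return None  # no profiled key covers the request; keep default budget
+    group = df[(df["ctx_len"] == ctx) & (df["d_num"] == dnum)]
+    valid = group[group["cost"] <= slo_limit]
+    return int(valid["chunk_size"].max()) if not valid.empty else None
+
+
+# e.g. an average decode context of 4096 tokens across 32 decode requests:
+# max_chunk_size("vllm_ascend/core/profile_table.csv", 50, 4096, 32)
+```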
+A fully executable example is as follows:
+
+```shell
+SLO_LIMIT=50
+vllm serve Qwen/Qwen2.5-14B-Instruct \
+    --additional-config '{"SLO_limits_for_dynamic_batch":'${SLO_LIMIT}'}' \
+    --max-num-seqs 256 \
+    --block-size 128 \
+    --tensor-parallel-size 8 \
+    --load-format dummy \
+    --max-num-batched-tokens 1024 \
+    --max-model-len 9000 \
+    --host localhost \
+    --port 12091 \
+    --gpu-memory-utilization 0.9 \
+    --trust-remote-code
+```
diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md
index 00f702a2..049e496f 100644
--- a/docs/source/user_guide/feature_guide/index.md
+++ b/docs/source/user_guide/feature_guide/index.md
@@ -11,4 +11,5 @@ sleep_mode
 structured_output
 lora
 eplb_swift_balancer
+dynamic_batch
 :::
diff --git a/pyproject.toml b/pyproject.toml
index 1a140ce8..68b66670 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,6 +10,8 @@ requires = [
     "pybind11",
     "pyyaml",
     "scipy",
+    "pandas",
+    "pandas-stubs",
    "setuptools>=64",
    "setuptools-scm>=8",
    "torch-npu==2.7.1.dev20250724",
diff --git a/requirements.txt b/requirements.txt
index 7808e852..c133c628 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,11 +8,13 @@ pip
 pybind11
 pyyaml
 scipy
+pandas
 setuptools>=64
 setuptools-scm>=8
 torch>=2.7.1
 torchvision
 wheel
+pandas-stubs
 
 # requirements for disaggregated prefill
 msgpack
diff --git a/tests/e2e/singlecard/test_ascend_scheduler.py b/tests/e2e/singlecard/test_ascend_scheduler.py
index 916db51c..39bba024 100644
--- a/tests/e2e/singlecard/test_ascend_scheduler.py
+++ b/tests/e2e/singlecard/test_ascend_scheduler.py
@@ -91,6 +91,43 @@ def test_chunked_prefill_with_ascend_scheduler(
     )
 
 
+@pytest.mark.parametrize("max_tokens",
+                         [4])  # cannot align results when max_tokens > 4
+@pytest.mark.parametrize("chunked_prefill_token_size", [16])
+def test_chunked_prefill_with_scheduler_dynamic_batch(
+        max_tokens: int, chunked_prefill_token_size: int) -> None:
+    example_prompts = [
+        "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs."
+    ]
+    max_num_seqs = chunked_prefill_token_size
+    max_num_batched_tokens = chunked_prefill_token_size
+    with VllmRunner(MODEL,
+                    additional_config={
+                        'SLO_limits_for_dynamic_batch': 0,
+                    },
+                    max_num_seqs=max_num_seqs,
+                    max_num_batched_tokens=max_num_batched_tokens,
+                    max_model_len=2048,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        dynamic_batch_output = vllm_model.generate_greedy(
+            example_prompts, max_tokens)
+
+    with VllmRunner(MODEL,
+                    additional_config={
+                        'SLO_limits_for_dynamic_batch': -1,
+                    },
+                    max_model_len=2048,
+                    gpu_memory_utilization=0.7) as vllm_model:
+        vllm_output = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_output,
+        outputs_1_lst=dynamic_batch_output,
+        name_0="vllm_output",
+        name_1="dynamic_batch_output",
+    )
+
+
 def test_async_scheduling() -> None:
     prompts = [
         "Hello, my name is",
diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 9892188b..13a06b09 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -21,6 +21,7 @@ from vllm.v1.structured_output import StructuredOutputManager
 
 from tests.ut.base import TestBase
 from vllm_ascend.core.scheduler import AscendScheduler
+from vllm_ascend.core.scheduler_dynamic_batch import SchedulerDynamicBatch
 
 EOS_TOKEN_ID = 50256
 MODEL = "Qwen3-0.6B"
@@ -805,3 +806,665 @@ class TestAscendScheduler(TestBase):
         model_runner_output = make_output(scheduler)
         scheduler.update_from_output(scheduler_output, model_runner_output)
         self.assertEqual(scheduler.phase, "decode")
+
+
+class TestSchedulerDynamicBatch(TestBase):
+
+    @patch("vllm.config.ModelConfig.__post_init__", MagicMock())
+    @patch("vllm.config.VllmConfig.__post_init__", MagicMock())
+    @patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
+    def create_scheduler(self, mock_compute_encoder_budget):
+        mock_compute_encoder_budget.return_value = [100, 100]
+        use_kv_connector = False
+        block_size = 16
+
+        scheduler_config = SchedulerConfig(
+            max_num_seqs=16,
+            max_model_len=MAX_NUM_BATCHED_TOKENS,
+            long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
+            disable_chunked_mm_input=False,
+            enable_chunked_prefill=True,
+            max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
+        )
+
+        scheduler_config.max_num_encoder_input_tokens = 10000
+        scheduler_config.encoder_cache_size = 10000
+        scheduler_config.chunked_prefill_enabled = True
+        scheduler_config.SLO_limits_for_dynamic_batch = 0
+
+        model_config = ModelConfig(
+            model=MODEL,
+            task="auto",
+            tokenizer=MODEL,
+            tokenizer_mode="auto",
+            trust_remote_code=True,
+            dtype="float16",
+            seed=42,
+            max_model_len=MAX_NUM_BATCHED_TOKENS,
+        )
+        model_config.pooler_config = MagicMock()
+        model_config.multimodal_config = MagicMock()
+        model_config.hf_config = MagicMock()
+        model_config.hf_config.is_encoder_decoder = False
+        # Cache config, optionally force APC
+        kwargs_cache: Dict[str,
+                           Any] = ({} if ENABLE_PREFIX_CACHING is None else {
+                               'enable_prefix_caching':
+                               ENABLE_PREFIX_CACHING
+                           })
+        cache_config = CacheConfig(
+            block_size=block_size,
+            gpu_memory_utilization=0.9,
+            swap_space=0,
+            cache_dtype="auto",
+            **kwargs_cache,
+        )
+
+        kv_transfer_config = KVTransferConfig(
+            kv_connector="SharedStorageConnector",
+            kv_role="kv_both",
+            kv_connector_extra_config={"shared_storage_path": "local_storage"},
+        ) if use_kv_connector else None
+
+        speculative_config: Optional[SpeculativeConfig] = None
+        if NUM_SPECULATIVE_TOKENS is not None:
+            speculative_config = SpeculativeConfig(
+                model="ngram",
num_speculative_tokens=NUM_SPECULATIVE_TOKENS) + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + speculative_config=speculative_config, + ) + + kv_cache_config = KVCacheConfig( + num_blocks=10000, # A large number of blocks to hold all requests + kv_cache_tensors=[], + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, + torch.float32, False)) + ], + ) + cache_config.num_gpu_blocks = 10000 + + scheduler = SchedulerDynamicBatch( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=MagicMock(spec=StructuredOutputManager), + ) + + should_advance = MagicMock() + should_advance.return_value = False + scheduler.structured_output_manager.should_advance = should_advance + + return scheduler + + def test_add_requests(self): + scheduler = self.create_scheduler() + requests = create_requests(num_requests=10) + + for i, request in enumerate(requests): + scheduler.add_request(request) + self.assertIn(request.request_id, scheduler.requests) + self.assertEqual(len(scheduler.waiting), i + 1) + + def test_finish_request(self): + scheduler = self.create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_ABORTED) + self.assertNotIn(request.request_id, scheduler.requests) + self.assertEqual(len(scheduler.waiting), 9 - i) + + def test_get_num_unfinished_requests(self): + scheduler = self.create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_STOPPED) + self.assertEqual(scheduler.get_num_unfinished_requests(), + len(requests) - i - 1) + + def test_schedule(self): + '''Test scheduling. + Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs + ''' + scheduler = self.create_scheduler() + scheduler.scheduler_config.chunked_prefill_enabled = True + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), len(requests)) + self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0) + self.assertEqual(len(output.finished_req_ids), 0) + # Verify all requests are scheduled. 
+        for req_id, num_tokens in output.num_scheduled_tokens.items():
+            self.assertEqual(num_tokens,
+                             len(requests[int(req_id)].prompt_token_ids))
+
+        # Verify requests moved from waiting to running
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), len(requests))
+        for i, request in enumerate(requests):
+            self.assertEqual(scheduler.running[i], request)
+
+    def test_schedule_multimodal_requests(self):
+        scheduler = self.create_scheduler()
+        scheduler.scheduler_config.chunked_prefill_enabled = True
+        mm_positions = [[PlaceholderRange(offset=i, length=10)]
+                        for i in range(10)]
+        requests = create_requests(
+            num_requests=10,
+            mm_positions=mm_positions,
+        )
+        for request in requests:
+            scheduler.add_request(request)
+
+        output = scheduler.schedule()
+        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
+        self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(output.finished_req_ids), 0)
+
+        # Verify all requests are scheduled.
+        for req_id, num_tokens in output.num_scheduled_tokens.items():
+            self.assertEqual(num_tokens,
+                             len(requests[int(req_id)].prompt_token_ids))
+        self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
+        for req_id, encoder_input in output.scheduled_encoder_inputs.items():
+            self.assertEqual(len(encoder_input), 1)
+
+        # Verify requests moved from waiting to running
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), len(requests))
+        for i, request in enumerate(requests):
+            self.assertEqual(scheduler.running[i], request)
+
+    def test_schedule_enable_prefix_caching(self):
+        '''Test scheduling.
+        Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
+        '''
+        global ENABLE_PREFIX_CACHING
+        ENABLE_PREFIX_CACHING = True
+        global PROMPT_LOGPROBS
+        PROMPT_LOGPROBS = 5
+        scheduler = self.create_scheduler()
+        scheduler.scheduler_config.chunked_prefill_enabled = False
+        requests = create_requests(num_requests=10)
+        for request in requests:
+            scheduler.add_request(request)
+
+        # Test initial scheduling
+        output = scheduler.schedule()
+        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
+        self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(output.finished_req_ids), 0)
+        # Verify all requests are scheduled.
+ for req_id, num_tokens in output.num_scheduled_tokens.items(): + self.assertEqual(num_tokens, + len(requests[int(req_id)].prompt_token_ids)) + + # Verify requests moved from waiting to running + self.assertEqual(len(scheduler.waiting), 0) + self.assertEqual(len(scheduler.running), len(requests)) + for i, request in enumerate(requests): + self.assertEqual(scheduler.running[i], request) + + def test_stop_via_update_from_output(self): + """Test stopping behavior through update_from_output""" + global NUM_SPECULATIVE_TOKENS + NUM_SPECULATIVE_TOKENS = 1 + scheduler = self.create_scheduler() + + # Test case 1: Stop on EOS token + requests = create_requests(num_requests=2, max_tokens=10) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + req.status = RequestStatus.RUNNING + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[EOS_TOKEN_ID], [10, 11] + ], # First request hits EOS, second continues + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped, second continues + self.assertEqual(len(scheduler.running), 1) + self.assertEqual(scheduler.running[0].request_id, + requests[1].request_id) + self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED) + self.assertIn(requests[0].request_id, scheduler.finished_req_ids) + self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID]) + self.assertEqual(list(requests[1].output_token_ids), [10, 11]) + + # Test case 2: Stop on custom stop token + NUM_SPECULATIVE_TOKENS = 2 + scheduler = self.create_scheduler() + requests = create_requests(num_requests=2, + max_tokens=10, + stop_token_ids=[42, 43]) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + req.status = RequestStatus.RUNNING + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: + [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[10, 42, 12], + [13, 14]], # First request hits stop token + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped on custom token + self.assertEqual(len(scheduler.running), 1) + 
self.assertEqual(scheduler.running[0].request_id, + requests[1].request_id) + self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED) + self.assertEqual(requests[0].stop_reason, 42) + self.assertIn(requests[0].request_id, scheduler.finished_req_ids) + self.assertEqual(list(requests[0].output_token_ids), [10, 42]) + self.assertEqual(list(requests[1].output_token_ids), [13, 14]) + + # Test case 3: Stop on max tokens + NUM_SPECULATIVE_TOKENS = 2 + scheduler = self.create_scheduler() + requests = create_requests(num_requests=2, max_tokens=2) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + req.status = RequestStatus.RUNNING + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: + [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(requests) + }, + sampled_token_ids=[[10, 11, 12], + [13]], # First request exceeds max_tokens + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped due to length + self.assertEqual(len(scheduler.running), 1) + self.assertEqual(scheduler.running[0].request_id, + requests[1].request_id) + self.assertEqual(requests[0].status, + RequestStatus.FINISHED_LENGTH_CAPPED) + self.assertIn(requests[0].request_id, scheduler.finished_req_ids) + self.assertEqual(list(requests[0].output_token_ids), [10, 11]) + self.assertEqual(list(requests[1].output_token_ids), [13]) + + # Test case 4: Ignore EOS flag + scheduler = self.create_scheduler() + requests = create_requests(num_requests=1, max_tokens=10) + requests[0].sampling_params.ignore_eos = True + requests[0].num_computed_tokens = requests[0].num_tokens + scheduler.requests[requests[0].request_id] = requests[0] + scheduler.running.append(requests[0]) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [EOS_TOKEN_ID, 10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_mm_hashes=[], + structured_output_request_ids={}, + grammar_bitmask=None) + model_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify request continues past EOS + self.assertEqual(len(scheduler.running), 1) + self.assertFalse(requests[0].is_finished()) + self.assertEqual(list(requests[0].output_token_ids), + [EOS_TOKEN_ID, 10, 11]) + + def test_schedule_concurrent_batches(self): + global MAX_NUM_BATCHED_TOKENS + global ENABLE_PREFIX_CACHING + global ENABLE_CHUNKED_PREFILL + global MAX_NUM_SEQS + global PROMPT_LOGPROBS + ENABLE_PREFIX_CACHING = None + 
MAX_NUM_BATCHED_TOKENS = 1024 + MAX_NUM_SEQS = 2 + ENABLE_CHUNKED_PREFILL = True + PROMPT_LOGPROBS = None + + enable_prefix_caching_list = [None, True] + prompt_logprobs_list = [None, 5] + + for i in range(len(enable_prefix_caching_list)): + ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i] + PROMPT_LOGPROBS = prompt_logprobs_list[i] + scheduler = self.create_scheduler() + requests = create_requests( + num_requests=2, + num_tokens=512, + ) + + # Schedule the first request. + scheduler.add_request(requests[0]) + scheduler_output0 = scheduler.schedule() + self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1) + self.assertEqual( + scheduler_output0.num_scheduled_tokens[requests[0].request_id], + 512) + + # The first request is still running, so only schedule the second request. + scheduler.add_request(requests[1]) + scheduler_output1 = scheduler.schedule() + self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1) + self.assertEqual( + scheduler_output1.num_scheduled_tokens[requests[1].request_id], + 512) + + # Model output of the first request. + model_runner_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + + scheduler.update_from_output(scheduler_output0, + model_runner_output) + + # Schedule the next step. + # The first request can be scheduled again while the second + # request is still running. + scheduler.schedule() + # Model output of the second request. + model_runner_output = ModelRunnerOutput( + req_ids=[requests[1].request_id], + req_id_to_index={requests[1].request_id: 0}, + sampled_token_ids=[[0]], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + + scheduler.update_from_output(scheduler_output1, + model_runner_output) + + def test_schedule_spec_decoding_stats(self): + """Test scheduling behavior with speculative decoding. + + This test verifies that: + 1. Speculated tokens get scheduled correctly + 2. 
Spec decoding stats properly count number of draft and accepted tokens + """ + spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]], + [[1, 2], [3]], [[1]], [[]], + [[1, 2, 3], [4, 5, 6]]] + output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]], + [[1, 2, 5], [3, 4]], + [[1, 2]], [[5]], + [[1, 2, 7], [4, 8]]] + expected_list: List[Tuple[int, int, + int, List[int]]] = [(1, 3, 3, [1, 1, 1]), + (1, 3, 1, [1, 0, 0]), + (2, 3, 3, [2, 1]), + (1, 1, 1, [1]), + (0, 0, 0, [0]), + (2, 6, 3, [2, 1, 0])] + + global NUM_SPECULATIVE_TOKENS + for idx in range(len(spec_tokens_list)): + spec_tokens = spec_tokens_list[idx] + output_tokens = output_tokens_list[idx] + expected = expected_list[idx] + num_spec_tokens = max(1, max(len(t) for t in spec_tokens)) + NUM_SPECULATIVE_TOKENS = num_spec_tokens + scheduler = self.create_scheduler() + requests = create_requests(num_requests=len(spec_tokens), + num_tokens=1) + req_ids = [] + req_to_index = {} + for i, request in enumerate(requests): + scheduler.add_request(request) + req_ids.append(request.request_id) + req_to_index[request.request_id] = i + + # Schedule a decode, which will also draft speculative tokens + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), len(requests)) + self.assertEqual(output.total_num_scheduled_tokens, len(requests)) + for i in range(len(requests)): + req_id = requests[i].request_id + self.assertEqual(output.num_scheduled_tokens[req_id], 1) + self.assertNotIn(req_id, output.scheduled_spec_decode_tokens) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=[[0] for _ in range(len(requests))], + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + draft_token_ids = DraftTokenIds(req_ids, spec_tokens) + + engine_core_outputs = scheduler.update_from_output( + output, model_runner_output) + scheduler.update_draft_token_ids(draft_token_ids) + + for i in range(len(requests)): + running_req = scheduler.running[i] + # The prompt token + self.assertEqual(running_req.num_computed_tokens, 1) + # The prompt token and the sampled token + self.assertEqual(running_req.num_tokens, 2) + # The prompt token, the sampled token, and the speculated tokens + self.assertEqual(running_req.num_tokens_with_spec, + 2 + len(spec_tokens[i])) + + # No draft or accepted tokens counted yet + self.assertTrue( + not engine_core_outputs + or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats + is None)) + + # Schedule the speculated tokens for validation + output = scheduler.schedule() + self.assertEqual(len(output.scheduled_new_reqs), 0) + # The sampled token and speculated tokens + self.assertEqual( + output.total_num_scheduled_tokens, + len(requests) + sum(len(ids) for ids in spec_tokens)) + for i in range(len(requests)): + req_id = requests[i].request_id + self.assertEqual(output.num_scheduled_tokens[req_id], + 1 + len(spec_tokens[i])) + if spec_tokens[i]: + self.assertEqual( + len(output.scheduled_spec_decode_tokens[req_id]), + len(spec_tokens[i])) + else: + self.assertNotIn(req_id, + output.scheduled_spec_decode_tokens) + + model_runner_output = ModelRunnerOutput( + req_ids=req_ids, + req_id_to_index=req_to_index, + sampled_token_ids=output_tokens, + logprobs=None, + prompt_logprobs_dict={}, + pooler_output=[]) + + engine_core_outputs = scheduler.update_from_output( + output, model_runner_output) + + scheduler_stats = engine_core_outputs[0].scheduler_stats \ + if engine_core_outputs else None + if expected[0] == 0: + 
self.assertIsNone(scheduler_stats.spec_decoding_stats)
+            else:
+                self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
+                stats = scheduler_stats.spec_decoding_stats
+                self.assertEqual(stats.num_drafts, expected[0])
+                self.assertEqual(stats.num_draft_tokens, expected[1])
+                self.assertEqual(stats.num_accepted_tokens, expected[2])
+                self.assertEqual(stats.num_accepted_tokens_per_pos,
+                                 expected[3])
+
+    def assert_scheduler_empty(self, scheduler):
+        """Confirm the scheduler is "empty" - i.e. no leaks."""
+        # Scheduler Metadata.
+        self.assertEqual(len(scheduler.requests), 0)
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), 0)
+        self.assertEqual(len(scheduler.finished_req_ids), 0)
+
+        # EncoderCacheManager.
+        self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
+        self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
+
+        # KVCache Manager.
+        self.assertEqual(
+            len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
+                req_to_blocks), 0)
+        self.assertEqual(
+            len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
+                num_cached_block), 0)
+        num_free_blocks = (scheduler.kv_cache_manager.block_pool.
+                           free_block_queue.num_free_blocks)
+        self.assertEqual(
+            num_free_blocks,
+            scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
+
+        # NOTE(rob): just the ref count on blocks will be 0. The hash
+        # value, etc will remain since we lazily evict for prefix cache.
+        for block in scheduler.kv_cache_manager.block_pool.blocks:
+            self.assertEqual(block.ref_cnt, 0)
+
+    def test_memory_leak(self):
+        """Test that we do not have a memory leak."""
+        scheduler = self.create_scheduler()
+        NUM_REQUESTS = 5
+        NUM_TOKENS = 10
+        MAX_TOKENS = 10
+        requests = create_requests(num_requests=NUM_REQUESTS,
+                                   num_tokens=NUM_TOKENS,
+                                   max_tokens=MAX_TOKENS)
+
+        # Add each request.
+        for request in requests:
+            scheduler.add_request(request)
+        scheduler_output = scheduler.schedule()
+        model_runner_output = make_output(scheduler)
+        scheduler.update_from_output(scheduler_output, model_runner_output)
+
+        # Iterate until done.
+        while True:
+            scheduler_output = scheduler.schedule()
+            if len(scheduler.running) == 0:
+                break
+            model_runner_output = make_output(scheduler)
+            scheduler.update_from_output(scheduler_output,
+                                         model_runner_output)
+
+        # Confirm no memory leak.
+        self.assert_scheduler_empty(scheduler)
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index b0973b15..3e9a8418 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -129,6 +129,8 @@ class AscendConfig:
         if self.pd_tp_ratio == 0:
             raise AssertionError(
                 "Only support P node tp size lagger then D node tp size")
+        self.SLO_limits_for_dynamic_batch = additional_config.get(
+            "SLO_limits_for_dynamic_batch", -1)
 
 
 class TorchairGraphConfig:
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 37365341..c117767e 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -105,4 +105,4 @@ class AscendSchedulerConfig(SchedulerConfig):
         if getattr(self, "scheduler_delay_factor", 0) > 0:
             raise NotImplementedError(
                 "currently AscendScheduler doesn't support scheduler_delay_factor."
-            )
+            )
\ No newline at end of file
diff --git a/vllm_ascend/core/scheduler_dynamic_batch.py b/vllm_ascend/core/scheduler_dynamic_batch.py
new file mode 100644
index 00000000..af062d62
--- /dev/null
+++ b/vllm_ascend/core/scheduler_dynamic_batch.py
@@ -0,0 +1,601 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+#
+import os
+import time
+
+import pandas as pd
+from vllm.config import VllmConfig
+from vllm.distributed.kv_events import KVEventBatch
+from vllm.logger import logger
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
+from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
+                                              create_request_queue)
+from vllm.v1.core.sched.scheduler import Scheduler
+from vllm.v1.engine import EngineCoreEventType
+from vllm.v1.kv_cache_interface import KVCacheConfig
+from vllm.v1.request import Request, RequestStatus
+from vllm.v1.structured_output import StructuredOutputManager
+
+
+class BudgetRefiner:
+    """Dynamically adjusts the token budget used by the chunked prefill
+    scheduling strategy."""
+
+    def __init__(self, default_budget, slo_limit=-1) -> None:
+        self.enabled = slo_limit > 0
+        if not self.enabled:
+            return
+        logger.info(
+            "Dynamic batch is enabled with SLO limit %s; chunked prefill "
+            "is forced on because dynamic batch relies on it.", slo_limit)
+        self.lookup: dict[tuple[int, int], int] = {}
+        # Kept sorted so that _align_key can find the smallest profiled
+        # key covering an observed value.
+        self.context_keys: list[int] = []
+        self.dnum_keys: list[int] = []
+        self.default_budget = default_budget
+        self._read_lookup_table(slo_limit)
+
+    def _read_lookup_table(self, slo_limit):
+        """Load the look-up table for the dynamic budget."""
+        base_dir = os.path.dirname(os.path.abspath(__file__))
+        table_file_path = os.path.join(base_dir, "profile_table.csv")
+        if not os.path.exists(table_file_path):
+            # Proceed without dynamic batch.
+            logger.error(
+                "The dynamic batching feature requires the lookup table "
+                "'profile_table.csv', but it was not found at '%s'. "
+                "Please download the corresponding table file.",
+                table_file_path)
+            self.enabled = False
+            return
+        df = pd.read_csv(table_file_path)
+        context_keys: set[int] = set()
+        dnum_keys: set[int] = set()
+        for (ctx_len, d_num), group in df.groupby(['ctx_len', 'd_num']):
+            valid = group[group['cost'] <= slo_limit]
+            if not valid.empty:
+                max_row = valid.loc[valid['chunk_size'].idxmax()]
+                self.lookup[(ctx_len, d_num)] = int(max_row['chunk_size'])
+                context_keys.add(ctx_len)
+                dnum_keys.add(d_num)
+        # Sort the keys: _align_key relies on ascending order to return
+        # the smallest key that covers the queried value.
+        self.context_keys = sorted(context_keys)
+        self.dnum_keys = sorted(dnum_keys)
+
+    def _align_key(self, value, valid_keys):
+        """Return the smallest key in (ascending) valid_keys that is
+        greater than or equal to value, or None if no key covers it."""
+        for k in valid_keys:
+            if k >= value:
+                return k
+        return None
+
+    def _get_max_budget(self, num_decode_tokens, num_decode):
+        """Get the maximum budget for the given average decode context
+        length and number of decode requests."""
+        aligned_ctx = self._align_key(num_decode_tokens, self.context_keys)
+        aligned_dnum = self._align_key(num_decode, self.dnum_keys)
+        if aligned_ctx is None or aligned_dnum is None:
+            return self.default_budget
+        budget = self.lookup.get((aligned_ctx, aligned_dnum), None)
+        if budget is None:
+            logger.warning("Lookup table miss for (ctx, dnum) = (%s, %s)",
+                           aligned_ctx, aligned_dnum)
+            budget = self.default_budget
+        return budget
+
+    def refine_budget(self, running_requests, budget):
+        """Dynamically refine the token budget according to the running
+        requests."""
+        if not self.enabled:
+            return budget
+        # Assume all running requests will be scheduled.
+        num_decode_token_lst = [
+            req.num_tokens_with_spec for req in running_requests
+            if req.num_computed_tokens >= req.num_prompt_tokens
+        ]
+        num_decode = len(num_decode_token_lst)
+        if num_decode <= 0:
+            return budget
+        avg_decode_tokens = sum(num_decode_token_lst) / num_decode
+        return self._get_max_budget(avg_decode_tokens, num_decode)
+
+
+class SchedulerDynamicBatch(Scheduler):
+    """This scheduler extends vllm's original v1 scheduler
+    with dynamic batch."""
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        kv_cache_config: KVCacheConfig,
+        structured_output_manager: StructuredOutputManager,
+        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
+        include_finished_set: bool = False,
+        log_stats: bool = False,
+    ) -> None:
+        super().__init__(vllm_config, kv_cache_config,
+                         structured_output_manager, mm_registry,
+                         include_finished_set, log_stats)
+        self.running: list[Request] = []
+        self.budget_refiner = BudgetRefiner(
+            default_budget=self.scheduler_config.max_num_batched_tokens,
+            slo_limit=self.scheduler_config.SLO_limits_for_dynamic_batch)
+
+    def schedule(self) -> SchedulerOutput:
+        # NOTE: This scheduling algorithm is based on `super().schedule()`,
+        # with dynamic batch support and a few modifications:
+        # 1. The token budget can be dynamically refined according to
+        #    self.running through the BudgetRefiner.
+        # 2. The algorithm follows a decode-first chunked prefill and FCFS
+        #    strategy, which differs slightly from `super().schedule()`.
+        # 3. As in `super().schedule()`, at each step the scheduler tries to
+        #    assign tokens to the requests so that each request's
+        #    num_computed_tokens can catch up with its num_tokens_with_spec.
+        # 4. So far, dynamic batch only supports the 910B3 NPU. Further work
+        #    will cover more devices and finer optimization strategies.
+
+        scheduled_new_reqs: list[Request] = []
+        scheduled_resumed_reqs: list[Request] = []
+        scheduled_running_reqs: list[Request] = []
+        preempted_reqs: list[Request] = []
+
+        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
+        num_scheduled_tokens: dict[str, int] = {}
+        token_budget = self.max_num_scheduled_tokens
+        token_budget = self.budget_refiner.refine_budget(
+            self.running, token_budget)
+
+        # NOTE: We move the prefill requests to the end of the self.running
+        # list, keeping the relative order unchanged. This rearrangement makes
+        # the scheduling algorithm a strict decode-first chunked prefill.
+        d_lst = [
+            req for req in self.running
+            if req.num_computed_tokens >= req.num_prompt_tokens
+        ]
+        p_lst = [
+            req for req in self.running
+            if req.num_computed_tokens < req.num_prompt_tokens
+        ]
+        self.running = d_lst + p_lst
+
+        # Encoder-related.
+        scheduled_encoder_inputs: dict[str, list[int]] = {}
+        encoder_compute_budget = self.max_num_encoder_input_tokens
+        # Spec decode-related.
+        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
+
+        # For logging.
+        scheduled_timestamp = time.monotonic()
+
+        # First, schedule the RUNNING requests.
+        req_index = 0
+        while req_index < len(self.running) and token_budget > 0:
+            request = self.running[req_index]
+
+            num_new_tokens = (request.num_tokens_with_spec +
+                              request.num_output_placeholders -
+                              request.num_computed_tokens)
+            if (0 < self.scheduler_config.long_prefill_token_threshold <
+                    num_new_tokens):
+                num_new_tokens = (
+                    self.scheduler_config.long_prefill_token_threshold)
+            num_new_tokens = min(num_new_tokens, token_budget)
+
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - 1 - request.num_computed_tokens)
+
+            # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_compute_budget = encoder_compute_budget
+            if request.has_encoder_inputs:
+                (encoder_inputs_to_schedule, num_new_tokens,
+                 new_encoder_compute_budget
+                 ) = self._try_schedule_encoder_inputs(
+                     request, request.num_computed_tokens, num_new_tokens,
+                     encoder_compute_budget)
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because of one of the
+                # following reasons:
+                # 1. No new tokens to schedule. This may happen when
+                #    (1) PP > 1 and we have already scheduled all prompt
+                #    tokens but they are not finished yet.
+                #    (2) Async scheduling and the request has reached either
+                #    its max_total_tokens or max_model_len.
+                # 2. The encoder budget is exhausted.
+                # 3. The encoder cache is exhausted.
+                # NOTE(woosuk): Here, by doing `break` instead of `continue`
+                # as in the v1 scheduler, we strictly follow the FCFS
+                # scheduling policy.
+                req_index += 1
+                break
+
+            while True:
+                new_blocks = self.kv_cache_manager.allocate_slots(
+                    request,
+                    num_new_tokens,
+                    num_lookahead_tokens=self.num_lookahead_tokens)
+                if new_blocks is None:
+                    # The request cannot be scheduled.
+                    # Preempt the lowest-priority request.
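+                    # Preemption policy: under PRIORITY scheduling, evict the
+                    # running request with the lowest priority (largest
+                    # priority value; latest arrival breaks ties). Under FCFS,
+                    # evict the most recently added running request.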
+ if self.policy == SchedulingPolicy.PRIORITY: + preempted_req = max( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + if preempted_req in scheduled_running_reqs: + scheduled_running_reqs.remove(preempted_req) + else: + preempted_req = self.running.pop() + + self.kv_cache_manager.free(preempted_req) + self.encoder_cache_manager.free(preempted_req) + preempted_req.status = RequestStatus.PREEMPTED + preempted_req.num_computed_tokens = 0 + if self.log_stats: + preempted_req.record_event( + EngineCoreEventType.PREEMPTED, scheduled_timestamp) + + self.waiting.prepend_request(preempted_req) + preempted_reqs.append(preempted_req) + if preempted_req == request: + # No more request to preempt. + can_schedule = False + break + else: + # The request can be scheduled. + can_schedule = True + break + if not can_schedule: + break + assert new_blocks is not None + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_blocks[request.request_id] = new_blocks + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + + # Speculative decode related. + if request.spec_token_ids: + num_scheduled_spec_tokens = (num_new_tokens + + request.num_computed_tokens - + request.num_tokens) + if num_scheduled_spec_tokens > 0: + # Trim spec_token_ids list to num_scheduled_spec_tokens. + del request.spec_token_ids[num_scheduled_spec_tokens:] + scheduled_spec_decode_tokens[request.request_id] = ( + request.spec_token_ids) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_compute_budget = new_encoder_compute_budget + + # Record the LoRAs in scheduled_running_reqs + scheduled_loras: set[int] = set() + if self.lora_config: + scheduled_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(scheduled_loras) <= self.lora_config.max_loras + + # Use a temporary RequestQueue to collect requests that need to be + # skipped and put back at the head of the waiting queue later + skipped_waiting_requests = create_request_queue(self.policy) + + # Next, schedule the WAITING requests. + if not preempted_reqs: + while self.waiting and token_budget > 0: + if len(self.running) == self.max_num_running_reqs: + break + + request = self.waiting.peek_request() + + # KVTransfer: skip request if still waiting for remote kvs. + if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: + is_ready = self._update_waiting_for_remote_kv(request) + if is_ready: + request.status = RequestStatus.WAITING + else: + logger.debug( + "%s is still in WAITING_FOR_REMOTE_KVS state.", + request.request_id) + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Skip request if the structured output request is still waiting + # for FSM compilation. + if request.status == RequestStatus.WAITING_FOR_FSM: + structured_output_req = request.structured_output_request + if structured_output_req and structured_output_req.grammar: + request.status = RequestStatus.WAITING + else: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Check that adding the request still respects the max_loras + # constraint. 
+ if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras)): + # Scheduling would exceed max_loras, skip. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_external_computed_tokens = 0 + load_kv_async = False + + # Get already-cached tokens. + if request.num_computed_tokens == 0: + # Get locally-cached tokens. + new_computed_blocks, num_new_local_computed_tokens = \ + self.kv_cache_manager.get_computed_blocks( + request) + + # Get externally-cached tokens if using a KVConnector. + if self.connector is not None: + num_external_computed_tokens, load_kv_async = ( + self.connector.get_num_new_matched_tokens( + request, num_new_local_computed_tokens)) + + if num_external_computed_tokens is None: + # The request cannot be scheduled because + # the KVConnector couldn't determine + # the number of matched tokens. + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + # Total computed tokens (local + external). + num_computed_tokens = (num_new_local_computed_tokens + + num_external_computed_tokens) + # KVTransfer: WAITING reqs have num_computed_tokens > 0 + # after async KV recvs are completed. + else: + new_computed_blocks = ( + self.kv_cache_manager.create_empty_block_list()) + num_new_local_computed_tokens = 0 + num_computed_tokens = request.num_computed_tokens + + encoder_inputs_to_schedule = None + new_encoder_compute_budget = encoder_compute_budget + + # KVTransfer: loading remote KV, do not allocate for new work. + if load_kv_async: + assert num_external_computed_tokens > 0 + num_new_tokens = 0 + # Number of tokens to be scheduled. + else: + # We use `request.num_tokens` instead of + # `request.num_prompt_tokens` to consider the resumed + # requests, which have output tokens. + num_new_tokens = request.num_tokens - num_computed_tokens + if (0 < self.scheduler_config.long_prefill_token_threshold + < num_new_tokens): + num_new_tokens = ( + self.scheduler_config.long_prefill_token_threshold) + + # chunked prefill has to be enabled explicitly to allow + # pooling requests to be chunked + if not self.scheduler_config.chunked_prefill_enabled and \ + num_new_tokens > token_budget: + self.waiting.pop_request() + skipped_waiting_requests.prepend_request(request) + continue + + num_new_tokens = min(num_new_tokens, token_budget) + assert num_new_tokens > 0 + + # Schedule encoder inputs. + if request.has_encoder_inputs: + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_compute_budget + ) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_compute_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + + # Handles an edge case when P/D Disaggregation + # is used with Spec Decoding where an + # extra block gets allocated which + # creates a mismatch between the number + # of local and remote blocks. + effective_lookahead_tokens = (0 if request.num_computed_tokens + == 0 else + self.num_lookahead_tokens) + + # Determine if we need to allocate cross-attention blocks. + if self.is_encoder_decoder and request.has_encoder_inputs: + # TODO(russellb): For Whisper, we know that the input is + # always padded to the maximum length. If we support other + # encoder-decoder models, this will need to be updated if we + # want to only allocate what is needed. 
+                    num_encoder_tokens =\
+                        self.scheduler_config.max_num_encoder_input_tokens
+                else:
+                    num_encoder_tokens = 0
+
+                new_blocks = self.kv_cache_manager.allocate_slots(
+                    request,
+                    num_new_tokens + num_external_computed_tokens,
+                    num_new_local_computed_tokens,
+                    new_computed_blocks,
+                    num_lookahead_tokens=effective_lookahead_tokens,
+                    delay_cache_blocks=load_kv_async,
+                    num_encoder_tokens=num_encoder_tokens,
+                )
+
+                if new_blocks is None:
+                    # The request cannot be scheduled.
+                    break
+
+                # KVTransfer: the connector uses this info to determine
+                # whether a KV load is needed for this request.
+                if self.connector is not None:
+                    self.connector.update_state_after_alloc(
+                        request,
+                        new_computed_blocks + new_blocks,
+                        num_external_computed_tokens,
+                    )
+
+                # Request was already popped from self.waiting
+                # unless it was re-added above due to new_blocks being None.
+                request = self.waiting.pop_request()
+                if load_kv_async:
+                    # If loading async, allocate memory and put request
+                    # into the WAITING_FOR_REMOTE_KV state.
+                    skipped_waiting_requests.prepend_request(request)
+                    request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                    continue
+
+                req_index += 1
+                self.running.append(request)
+                if self.log_stats:
+                    request.record_event(EngineCoreEventType.SCHEDULED,
+                                         scheduled_timestamp)
+                if request.status == RequestStatus.WAITING:
+                    scheduled_new_reqs.append(request)
+                elif request.status == RequestStatus.PREEMPTED:
+                    scheduled_resumed_reqs.append(request)
+                else:
+                    raise RuntimeError(
+                        f"Invalid request status: {request.status}")
+
+                if self.lora_config and request.lora_request:
+                    scheduled_loras.add(request.lora_request.lora_int_id)
+                req_to_new_blocks[request.request_id] = (
+                    self.kv_cache_manager.get_blocks(request.request_id))
+                num_scheduled_tokens[request.request_id] = num_new_tokens
+                token_budget -= num_new_tokens
+                request.status = RequestStatus.RUNNING
+                request.num_computed_tokens = num_computed_tokens
+                # Count the number of prefix cached tokens.
+                if request.num_cached_tokens < 0:
+                    request.num_cached_tokens = num_computed_tokens
+                # Encoder-related.
+                if encoder_inputs_to_schedule:
+                    scheduled_encoder_inputs[request.request_id] = (
+                        encoder_inputs_to_schedule)
+                    # Allocate the encoder cache.
+                    for i in encoder_inputs_to_schedule:
+                        self.encoder_cache_manager.allocate(request, i)
+                    encoder_compute_budget = new_encoder_compute_budget
+
+        # Put back any skipped requests at the head of the waiting queue
+        if skipped_waiting_requests:
+            self.waiting.prepend_requests(skipped_waiting_requests)
+
+        # Check if the scheduling constraints are satisfied.
+        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
+        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
+        assert token_budget >= 0
+        assert len(self.running) <= self.max_num_running_reqs
+        # Since some requests in the RUNNING queue may not be scheduled in
+        # this step, the total number of scheduled requests can be smaller
+        # than len(self.running).
+        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
+                len(scheduled_running_reqs) <= len(self.running))
+
+        # Get the longest common prefix among all requests in the running
+        # queue. This can be potentially used for cascade attention.
+ num_common_prefix_blocks = [0] * len( + self.kv_cache_config.kv_cache_groups) + if self.running: + any_request = self.running[0] + num_common_prefix_blocks = ( + self.kv_cache_manager.get_num_common_prefix_blocks( + any_request, len(self.running))) + + # Construct the scheduler output. + new_reqs_data = [ + NewRequestData.from_request( + req, req_to_new_blocks[req.request_id].get_block_ids()) + for req in scheduled_new_reqs + ] + cached_reqs_data = self._make_cached_request_data( + scheduled_running_reqs, + scheduled_resumed_reqs, + num_scheduled_tokens, + scheduled_spec_decode_tokens, + req_to_new_blocks, + ) + scheduled_requests = (scheduled_new_reqs + scheduled_running_reqs + + scheduled_resumed_reqs) + structured_output_request_ids, grammar_bitmask = ( + self.get_grammar_bitmask(scheduled_requests, + scheduled_spec_decode_tokens)) + scheduler_output = SchedulerOutput( + scheduled_new_reqs=new_reqs_data, + scheduled_cached_reqs=cached_reqs_data, + num_scheduled_tokens=num_scheduled_tokens, + total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_spec_decode_tokens=scheduled_spec_decode_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, + num_common_prefix_blocks=num_common_prefix_blocks, + # finished_req_ids is an existing state in the scheduler, + # instead of being newly scheduled in this step. + # It contains the request IDs that are finished in between + # the previous and the current steps. + finished_req_ids=self.finished_req_ids, + free_encoder_mm_hashes=self.encoder_cache_manager. + get_freed_mm_hashes(), + structured_output_request_ids=structured_output_request_ids, + grammar_bitmask=grammar_bitmask, + ) + + # NOTE(Kuntai): this function is designed for multiple purposes: + # 1. Plan the KV cache store + # 2. Wrap up all the KV cache load / save ops into an opaque object + # 3. Clear the internal states of the connector + if self.connector is not None: + meta = self.connector.build_connector_meta(scheduler_output) + scheduler_output.kv_connector_metadata = meta + + # collect KV cache events from KV cache manager + events = self.kv_cache_manager.take_events() + + # collect KV cache events from connector + if self.connector is not None: + connector_events = self.connector.take_events() + if connector_events: + if events is None: + events = list(connector_events) + else: + events.extend(connector_events) + + # publish collected KV cache events + if events: + batch = KVEventBatch(ts=time.time(), events=events) + self.kv_event_publisher.publish(batch) + + self._update_after_schedule(scheduler_output) + return scheduler_output diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 0b4b641d..d8cf5251 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -314,6 +314,14 @@ class NPUPlatform(Platform): vllm_config.scheduler_config) vllm_config.scheduler_config = recompute_scheduler_config + # Extend original scheduler_config to use SchedulerDynamicBatch. + if ascend_config.SLO_limits_for_dynamic_batch != -1: + vllm_config.scheduler_config.scheduler_cls = ( + "vllm_ascend.core.scheduler_dynamic_batch.SchedulerDynamicBatch" + ) + vllm_config.scheduler_config.chunked_prefill_enabled = True + vllm_config.scheduler_config.SLO_limits_for_dynamic_batch = ascend_config.SLO_limits_for_dynamic_batch + @classmethod def get_attn_backend_cls( cls,