diff --git a/tests/ut/core/test_schedule_config.py b/tests/ut/core/test_schedule_config.py
index b135370..17c162f 100644
--- a/tests/ut/core/test_schedule_config.py
+++ b/tests/ut/core/test_schedule_config.py
@@ -75,13 +75,6 @@ class TestAscendSchedulerConfig(TestBase):
             str(context.exception),
         )
 
-    def test_not_implemented_multimodal(self):
-        with self.assertRaises(NotImplementedError) as context:
-            AscendSchedulerConfig.initialize_from_config(
-                SchedulerConfig(is_multimodal_model=True), {})
-        self.assertIn("currently AscendScheduler only supports LLM models",
-                      str(context.exception))
-
     def test_not_implemented_send_delta_data(self):
         with self.assertRaises(NotImplementedError) as context:
             AscendSchedulerConfig.initialize_from_config(
@@ -118,6 +111,11 @@ class TestAscendSchedulerConfig(TestBase):
         self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
         self.assertEqual(ascend_config.encoder_cache_size, 8192)
 
+    def test_valid_config_with_multimodal(self):
+        config = AscendSchedulerConfig.initialize_from_config(
+            SchedulerConfig(is_multimodal_model=True), {})
+        self.assertTrue(config.is_multimodal_model)
+
     def test_valid_config_with_chunked_prefill(self):
         ascend_config = AscendSchedulerConfig.initialize_from_config(
             self.basic_scheduler_config,
diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index c2e21c0..095f392 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -6,7 +6,8 @@ from unittest.mock import MagicMock, patch
 import torch
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
-from vllm.multimodal.inputs import PlaceholderRange
+from vllm.multimodal.inputs import (MultiModalFeatureSpec,
+                                    MultiModalKwargsItem, PlaceholderRange)
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
 from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
@@ -49,11 +50,23 @@ def create_requests(
         prompt_logprobs=prompt_logprobs)
     requests = []
     for i in range(num_requests):
+        mm_features = []
+        if mm_positions is not None:
+            mm_position = mm_positions[i]
+            for j, position in enumerate(mm_position):
+                identifier = f"hash{i}_{j}"
+                mm_feature = MultiModalFeatureSpec(
+                    data=MultiModalKwargsItem.dummy("dummy_m"),
+                    mm_position=position,
+                    identifier=identifier,
+                    modality="image")
+                mm_features.append(mm_feature)
         request = Request(request_id=f"{i}",
                           prompt_token_ids=[i] * num_tokens,
                           sampling_params=sampling_params,
                           eos_token_id=EOS_TOKEN_ID,
                           pooling_params=None,
+                          mm_features=mm_features if mm_features else None,
                           block_hasher=get_request_block_hasher(
                               block_size, hash_fn))
         requests.append(request)
@@ -86,7 +99,7 @@ class TestAscendScheduler(TestBase):
     @patch("vllm.config.VllmConfig.__post_init__", MagicMock())
     @patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
     def create_scheduler(self, mock_compute_encoder_budget):
-        mock_compute_encoder_budget.return_value = [10, 20]
+        mock_compute_encoder_budget.return_value = [100, 100]
         use_kv_connector = False
         block_size = 16
 
@@ -233,6 +246,37 @@ class TestAscendScheduler(TestBase):
         for i, request in enumerate(requests):
             self.assertEqual(scheduler.running[i], request)
 
+    def test_schedule_multimodal_requests(self):
+        scheduler = self.create_scheduler()
+        scheduler.scheduler_config.chunked_prefill_enabled = False
+        mm_positions = [[PlaceholderRange(offset=i, length=10)]
+                        for i in range(10)]
+        requests = create_requests(
+            num_requests=10,
+            mm_positions=mm_positions,
+        )
+        for request in requests:
+            scheduler.add_request(request)
+
+        output = scheduler.schedule()
+        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
+        self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
+        self.assertEqual(len(output.finished_req_ids), 0)
+
+        # Verify all requests are scheduled.
+        for req_id, num_tokens in output.num_scheduled_tokens.items():
+            self.assertEqual(num_tokens,
+                             len(requests[int(req_id)].prompt_token_ids))
+        self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
+        for req_id, encoder_input in output.scheduled_encoder_inputs.items():
+            self.assertEqual(len(encoder_input), 1)
+
+        # Verify requests moved from waiting to running.
+        self.assertEqual(len(scheduler.waiting), 0)
+        self.assertEqual(len(scheduler.running), len(requests))
+        for i, request in enumerate(requests):
+            self.assertEqual(scheduler.running[i], request)
+
     def test_schedule_enable_prefix_caching(self):
         '''Test scheduling. Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
index 257657a..83d0675 100644
--- a/vllm_ascend/core/schedule_config.py
+++ b/vllm_ascend/core/schedule_config.py
@@ -71,9 +71,6 @@ class AscendSchedulerConfig(SchedulerConfig):
             raise NotImplementedError(
                 f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
             )
-        if self.is_multimodal_model:
-            raise NotImplementedError(
-                "currently AscendScheduler only supports LLM models.")
         if self.send_delta_data:
             raise NotImplementedError(
                 "currently AscendScheduler doesn't support send_delta_data.")
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index c3ac722..cc6822f 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -72,6 +72,11 @@ class AscendScheduler(Scheduler):
         req_to_new_blocks: dict[str, KVCacheBlocks] = {}
         num_scheduled_tokens: dict[str, int] = {}
         token_budget = self.max_num_scheduled_tokens
+
+        # Encoder-related.
+        scheduled_encoder_inputs: dict[str, list[int]] = {}
+        encoder_budget = self.max_num_encoder_input_tokens
+
         # Spec decode-related.
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
 
@@ -155,6 +160,9 @@
             num_new_local_computed_tokens = 0
             num_computed_tokens = request.num_computed_tokens
 
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget
+
             # P/D: loading remote KV, do not allocate for new work.
             if load_kv_async:
                 assert num_external_computed_tokens > 0
@@ -192,6 +200,16 @@
                 assert num_new_tokens > 0
                 blocks = new_computed_blocks.blocks[0]
 
+                # Schedule encoder inputs.
+                if request.has_encoder_inputs:
+                    (encoder_inputs_to_schedule, num_new_tokens,
+                     new_encoder_budget) = self._try_schedule_encoder_inputs(
+                         request, num_computed_tokens, num_new_tokens,
+                         encoder_budget)
+                    if num_new_tokens == 0:
+                        # The request cannot be scheduled.
+                        break
+
                 watermark = getattr(self.scheduler_config, "watermark", 0.01)
                 if not self._check_watermark_for_prefill(request, num_new_tokens,
                                                          blocks, watermark):
@@ -256,6 +274,15 @@
             if request.num_cached_tokens < 0:
                 request.num_cached_tokens = num_computed_tokens
 
+            # Encoder-related.
+            if encoder_inputs_to_schedule:
+                scheduled_encoder_inputs[request.request_id] = (
+                    encoder_inputs_to_schedule)
+                # Allocate the encoder cache.
+                for i in encoder_inputs_to_schedule:
+                    self.encoder_cache_manager.allocate(request, i)
+                encoder_budget = new_encoder_budget
+
         # Put back any skipped requests at the head of the waiting queue
         if skipped_waiting_requests:
             self.waiting.extendleft(skipped_waiting_requests)
@@ -287,6 +314,16 @@
             num_new_tokens = min(
                 num_new_tokens,
                 self.max_model_len - request.num_computed_tokens)
+
+            # Schedule encoder inputs.
+            encoder_inputs_to_schedule = None
+            new_encoder_budget = encoder_budget
+            if request.has_encoder_inputs:
+                (encoder_inputs_to_schedule, num_new_tokens,
+                 new_encoder_budget) = self._try_schedule_encoder_inputs(
+                     request, request.num_computed_tokens, num_new_tokens,
+                     encoder_budget)
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if self.lora_config and request.lora_request and (
@@ -358,6 +395,15 @@
                 scheduled_spec_decode_tokens[request.request_id] = (
                     request.spec_token_ids)
 
+            # Encoder-related.
+            if encoder_inputs_to_schedule:
+                scheduled_encoder_inputs[request.request_id] = (
+                    encoder_inputs_to_schedule)
+                # Allocate the encoder cache.
+                for i in encoder_inputs_to_schedule:
+                    self.encoder_cache_manager.allocate(request, i)
+                encoder_budget = new_encoder_budget
+
             # Record scheduled LoRA requests.
             if self.lora_config and request.lora_request:
                 scheduled_loras.add(request.lora_request.lora_int_id)
@@ -401,7 +447,7 @@
             num_scheduled_tokens=num_scheduled_tokens,
             total_num_scheduled_tokens=total_num_scheduled_tokens,
             scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
-            scheduled_encoder_inputs={},
+            scheduled_encoder_inputs=scheduled_encoder_inputs,
             num_common_prefix_blocks=num_common_prefix_blocks,
             # finished_req_ids is an existing state in the scheduler,
             # instead of being newly scheduled in this step.
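
Reviewer note: `_try_schedule_encoder_inputs` is not added by this patch; it is inherited from vLLM's base `Scheduler`, and the diff only threads its three results (`encoder_inputs_to_schedule`, a possibly truncated `num_new_tokens`, and the updated `encoder_budget`) through both the prefill and running paths. For intuition about that handshake, here is a minimal, self-contained sketch of the budget gating, using a simplified `Placeholder` stand-in for `PlaceholderRange`; the helper body below is an illustrative assumption, not the actual vLLM implementation.

```python
from dataclasses import dataclass


@dataclass
class Placeholder:
    offset: int  # first decoder token position covered by this encoder input
    length: int  # number of decoder token positions it covers


def try_schedule_encoder_inputs(placeholders, num_computed_tokens,
                                num_new_tokens, encoder_budget):
    """Pick encoder inputs whose placeholders overlap the token window
    [num_computed_tokens, num_computed_tokens + num_new_tokens) while
    staying within the encoder budget.

    Returns (indices_to_schedule, num_new_tokens, remaining_budget). If an
    input overlaps the window but exceeds the budget, the window is
    truncated so that input is deferred to a later scheduling step.
    """
    scheduled = []
    window_end = num_computed_tokens + num_new_tokens
    for i, ph in enumerate(placeholders):
        start, end = ph.offset, ph.offset + ph.length
        if start >= window_end:  # not reached in this step
            break
        if end <= num_computed_tokens:  # already computed previously
            continue
        if ph.length > encoder_budget:
            # Out of encoder budget: shrink this step's token window so
            # the input starts in a future step. The window may drop to
            # 0 tokens, which is the case the patch handles with
            # `if num_new_tokens == 0: break`.
            num_new_tokens = max(0, start - num_computed_tokens)
            break
        encoder_budget -= ph.length
        scheduled.append(i)
    return scheduled, num_new_tokens, encoder_budget


# One image placeholder at offset 3; it fits, so the whole prompt chunk
# stays schedulable and the budget shrinks by the placeholder length.
inputs, tokens, budget = try_schedule_encoder_inputs(
    [Placeholder(offset=3, length=10)],
    num_computed_tokens=0, num_new_tokens=32, encoder_budget=8192)
assert (inputs, tokens, budget) == ([0], 32, 8182)
```

Under these assumptions, the `if num_new_tokens == 0: break` in the prefill path corresponds to the truncated-window case: the first pending encoder input of the request no longer fits this step's budget, so the whole request is deferred rather than partially scheduled.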