diff --git a/tests/ut/core/test_scheduler.py b/tests/ut/core/test_scheduler.py
index 6680a25..635e08d 100644
--- a/tests/ut/core/test_scheduler.py
+++ b/tests/ut/core/test_scheduler.py
@@ -295,24 +295,25 @@ class TestAscendScheduler(TestBase):
             scheduler.running.append(req)
             req.status = RequestStatus.RUNNING
 
-        scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                           scheduled_cached_reqs=[],
-                                           num_scheduled_tokens={
-                                               requests[0].request_id: 1,
-                                               requests[1].request_id: 2
-                                           },
-                                           total_num_scheduled_tokens=3,
-                                           scheduled_encoder_inputs={},
-                                           scheduled_spec_decode_tokens={
-                                               requests[0].request_id: [],
-                                               requests[1].request_id: [10]
-                                           },
-                                           num_common_prefix_blocks=0,
-                                           finished_req_ids=set(),
-                                           free_encoder_input_ids=[],
-                                           structured_output_request_ids={},
-                                           grammar_bitmask=None)
         if vllm_version_is("0.10.1.1"):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={
+                    requests[0].request_id: 1,
+                    requests[1].request_id: 2
+                },
+                total_num_scheduled_tokens=3,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [],
+                    requests[1].request_id: [10]
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[req.request_id for req in requests],
                 req_id_to_index={
@@ -327,6 +328,24 @@ class TestAscendScheduler(TestBase):
                 prompt_logprobs_dict={},
                 pooler_output=[])
         else:
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={
+                    requests[0].request_id: 1,
+                    requests[1].request_id: 2
+                },
+                total_num_scheduled_tokens=3,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [],
+                    requests[1].request_id: [10]
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_mm_hashes=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[req.request_id for req in requests],
                 req_id_to_index={
@@ -363,25 +382,25 @@ class TestAscendScheduler(TestBase):
             scheduler.running.append(req)
             req.status = RequestStatus.RUNNING
 
-        scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                           scheduled_cached_reqs=[],
-                                           num_scheduled_tokens={
-                                               requests[0].request_id: 3,
-                                               requests[1].request_id: 2
-                                           },
-                                           total_num_scheduled_tokens=5,
-                                           scheduled_encoder_inputs={},
-                                           scheduled_spec_decode_tokens={
-                                               requests[0].request_id:
-                                               [10, 42],
-                                               requests[1].request_id: [13]
-                                           },
-                                           num_common_prefix_blocks=0,
-                                           finished_req_ids=set(),
-                                           free_encoder_input_ids=[],
-                                           structured_output_request_ids={},
-                                           grammar_bitmask=None)
         if vllm_version_is("0.10.1.1"):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={
+                    requests[0].request_id: 3,
+                    requests[1].request_id: 2
+                },
+                total_num_scheduled_tokens=5,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [10, 42],
+                    requests[1].request_id: [13]
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[req.request_id for req in requests],
                 req_id_to_index={
@@ -395,6 +414,24 @@ class TestAscendScheduler(TestBase):
                 prompt_logprobs_dict={},
                 pooler_output=[])
         else:
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={
+                    requests[0].request_id: 3,
+                    requests[1].request_id: 2
+                },
+                total_num_scheduled_tokens=5,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [10, 42],
+                    requests[1].request_id: [13]
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_mm_hashes=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[req.request_id for req in requests],
                 req_id_to_index={
@@ -429,26 +466,25 @@ class TestAscendScheduler(TestBase):
             scheduler.running.append(req)
             req.status = RequestStatus.RUNNING
 
-        scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
-                                           scheduled_cached_reqs=[],
-                                           num_scheduled_tokens={
-                                               requests[0].request_id: 3,
-                                               requests[1].request_id: 1
-                                           },
-                                           total_num_scheduled_tokens=4,
-                                           scheduled_encoder_inputs={},
-                                           scheduled_spec_decode_tokens={
-                                               requests[0].request_id:
-                                               [10, 11],
-                                               requests[1].request_id: []
-                                           },
-                                           num_common_prefix_blocks=0,
-                                           finished_req_ids=set(),
-                                           free_encoder_input_ids=[],
-                                           structured_output_request_ids={},
-                                           grammar_bitmask=None)
-
         if vllm_version_is("0.10.1.1"):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={
+                    requests[0].request_id: 3,
+                    requests[1].request_id: 1
+                },
+                total_num_scheduled_tokens=4,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [10, 11],
+                    requests[1].request_id: []
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[req.request_id for req in requests],
                 req_id_to_index={
@@ -462,6 +498,24 @@ class TestAscendScheduler(TestBase):
                 prompt_logprobs_dict={},
                 pooler_output=[])
         else:
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={
+                    requests[0].request_id: 3,
+                    requests[1].request_id: 1
+                },
+                total_num_scheduled_tokens=4,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [10, 11],
+                    requests[1].request_id: []
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_mm_hashes=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[req.request_id for req in requests],
                 req_id_to_index={
@@ -493,22 +547,21 @@ class TestAscendScheduler(TestBase):
         scheduler.requests[requests[0].request_id] = requests[0]
         scheduler.running.append(requests[0])
 
-        scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=[],
-            scheduled_cached_reqs=[],
-            num_scheduled_tokens={requests[0].request_id: 3},
-            total_num_scheduled_tokens=3,
-            scheduled_encoder_inputs={},
-            scheduled_spec_decode_tokens={
-                requests[0].request_id: [EOS_TOKEN_ID, 10]
-            },
-            num_common_prefix_blocks=0,
-            finished_req_ids=set(),
-            free_encoder_input_ids=[],
-            structured_output_request_ids={},
-            grammar_bitmask=None)
-
         if vllm_version_is("0.10.1.1"):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={requests[0].request_id: 3},
+                total_num_scheduled_tokens=3,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [EOS_TOKEN_ID, 10]
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_input_ids=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[requests[0].request_id],
                 req_id_to_index={requests[0].request_id: 0},
@@ -519,6 +572,20 @@ class TestAscendScheduler(TestBase):
                 pooler_output=[])
         else:
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=[],
+                scheduled_cached_reqs=[],
+                num_scheduled_tokens={requests[0].request_id: 3},
+                total_num_scheduled_tokens=3,
+                scheduled_encoder_inputs={},
+                scheduled_spec_decode_tokens={
+                    requests[0].request_id: [EOS_TOKEN_ID, 10]
+                },
+                num_common_prefix_blocks=0,
+                finished_req_ids=set(),
+                free_encoder_mm_hashes=[],
+                structured_output_request_ids={},
+                grammar_bitmask=None)
             model_output = ModelRunnerOutput(
                 req_ids=[requests[0].request_id],
                 req_id_to_index={requests[0].request_id: 0},
diff --git a/tests/ut/worker/test_input_batch.py b/tests/ut/worker/test_input_batch.py
index 320725c..a72dbdc 100644
--- a/tests/ut/worker/test_input_batch.py
+++ b/tests/ut/worker/test_input_batch.py
@@ -215,6 +215,7 @@ def _construct_cached_request_state(req_id_suffix: int):
         generator=None,
         num_computed_tokens=len(output_token_ids),
         output_token_ids=output_token_ids,
+        mm_hashes=None,
     )
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
index 627d5ea..c9b9a64 100644
--- a/vllm_ascend/core/scheduler.py
+++ b/vllm_ascend/core/scheduler.py
@@ -385,23 +385,44 @@ class AscendScheduler(Scheduler):
                 req_to_new_blocks)
             scheduled_cached_reqs = cached_reqs_data
 
-        scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=scheduled_cached_reqs,
-            num_scheduled_tokens=num_scheduled_tokens,
-            total_num_scheduled_tokens=total_num_scheduled_tokens,
-            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
-            scheduled_encoder_inputs={},
-            num_common_prefix_blocks=num_common_prefix_blocks,
-            # finished_req_ids is an existing state in the scheduler,
-            # instead of being newly scheduled in this step.
-            # It contains the request IDs that are finished in between
-            # the previous and the current steps.
-            finished_req_ids=self.finished_req_ids,  # type: ignore
-            free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(),
-            structured_output_request_ids={},
-            grammar_bitmask=None,
-        )
+        if vllm_version_is("0.10.1.1"):
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=new_reqs_data,
+                scheduled_cached_reqs=scheduled_cached_reqs,
+                num_scheduled_tokens=num_scheduled_tokens,
+                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=num_common_prefix_blocks,
+                # finished_req_ids is an existing state in the scheduler,
+                # instead of being newly scheduled in this step.
+                # It contains the request IDs that are finished in between
+                # the previous and the current steps.
+                finished_req_ids=self.finished_req_ids,  # type: ignore
+                free_encoder_input_ids=self.encoder_cache_manager.
+                get_freed_ids(),
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+            )
+        else:
+            scheduler_output = SchedulerOutput(
+                scheduled_new_reqs=new_reqs_data,
+                scheduled_cached_reqs=scheduled_cached_reqs,
+                num_scheduled_tokens=num_scheduled_tokens,
+                total_num_scheduled_tokens=total_num_scheduled_tokens,
+                scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
+                scheduled_encoder_inputs={},
+                num_common_prefix_blocks=num_common_prefix_blocks,
+                # finished_req_ids is an existing state in the scheduler,
+                # instead of being newly scheduled in this step.
+                # It contains the request IDs that are finished in between
+                # the previous and the current steps.
+                finished_req_ids=self.finished_req_ids,  # type: ignore
+                free_encoder_mm_hashes=self.encoder_cache_manager.
+                get_freed_mm_hashes(),
+                structured_output_request_ids={},
+                grammar_bitmask=None,
+            )
 
         # NOTE(Kuntai): this function is designed for multiple purposes:
         # 1. Plan the KV cache store
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 3df2613..03486b0 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -193,7 +193,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Lazy initialization, these will be set after __init__
         self.kv_caches: List[torch.Tensor] = []
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
+        # TODO: remove Dict[str, Dict[int, torch.Tensor]] type after 0.10.1.1
+        self.encoder_cache: Union[Dict[str, Dict[int, torch.Tensor]],
+                                  Dict[str, torch.Tensor]] = {}
         self.attn_mask = None
         self.attn_state = None
         self.requests: Dict[str, CachedRequestState] = {}
@@ -381,7 +383,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)
+            if vllm_version_is("0.10.1.1"):
+                self.encoder_cache.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -390,15 +393,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # and handling the second as a new request.
         for req_id in scheduler_output.finished_req_ids:
             self.input_batch.remove_request(req_id)
-
-        # Free the cached encoder outputs.
-        for req_id, input_id in scheduler_output.free_encoder_input_ids:
-            encoder_outputs = self.encoder_cache.get(req_id)
-            if encoder_outputs is not None:
-                encoder_outputs.pop(input_id, None)
-                if not encoder_outputs:
-                    self.encoder_cache.pop(req_id, None)
-
+        if vllm_version_is("0.10.1.1"):
+            # Free the cached encoder outputs.
+            for req_id, input_id in scheduler_output.free_encoder_input_ids:
+                encoder_outputs = self.encoder_cache.get(req_id)
+                if encoder_outputs is not None:
+                    encoder_outputs.pop(input_id, None)
+                    if not encoder_outputs:
+                        self.encoder_cache.pop(req_id, None)
+        else:
+            for mm_hash in scheduler_output.free_encoder_mm_hashes:
+                self.encoder_cache.pop(mm_hash, None)
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
@@ -447,6 +452,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_computed_tokens=new_req_data.num_computed_tokens,
                 output_token_ids=[],
                 lora_request=new_req_data.lora_request,
+                **({
+                    "mm_hashes": new_req_data.mm_hashes
+                } if not vllm_version_is("0.10.1.1") else {
+                    "mm_hashes": None
+                }),
             )
 
             # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -882,15 +892,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Batch the multi-modal inputs.
         mm_kwargs = list[MultiModalKwargsItem]()
-        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+        if vllm_version_is("0.10.1.1"):
+            req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+        else:
+            mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
-
-            for mm_input_id in encoder_input_ids:
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                req_ids_pos.append(
-                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
-
+            if vllm_version_is("0.10.1.1"):
+                for mm_input_id in encoder_input_ids:
+                    mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
+                    req_ids_pos.append((req_id, mm_input_id,
+                                        req_state.mm_positions[mm_input_id]))
+            else:
+                for mm_input_id in encoder_input_ids:
+                    # TODO remove this assert after 0.10.1.1
+                    assert req_state.mm_hashes is not None
+                    mm_hash = req_state.mm_hashes[mm_input_id]
+                    mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
+                    mm_hashes_pos.append(
+                        (mm_hash, req_state.mm_positions[mm_input_id]))
         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
         # we process it separately to preserve item order.
@@ -921,19 +941,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             for output in curr_group_outputs:
                 encoder_outputs.append(output)
 
+        if vllm_version_is("0.10.1.1"):
+            # Cache the encoder outputs.
+            for (req_id, input_id, pos_info), output in zip(
+                    req_ids_pos,
+                    encoder_outputs,
+            ):
+                if req_id not in self.encoder_cache:
+                    self.encoder_cache[req_id] = {}
-        # Cache the encoder outputs.
-        for (req_id, input_id, pos_info), output in zip(
-                req_ids_pos,
-                encoder_outputs,
-        ):
-            if req_id not in self.encoder_cache:
-                self.encoder_cache[req_id] = {}
-
-            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
-                output,
-                is_embed=pos_info.is_embed,
-            )
+                self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
+                    output,
+                    is_embed=pos_info.is_embed,
+                )
+        else:
+            for (mm_hash, pos_info), output in zip(mm_hashes_pos,
+                                                   encoder_outputs):
+                self.encoder_cache[mm_hash] = scatter_mm_placeholders(
+                    output,
+                    is_embed=pos_info.is_embed,
+                )
 
     def _gather_mm_embeddings(
         self,
@@ -946,6 +973,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             req_state = self.requests[req_id]
             num_computed_tokens = req_state.num_computed_tokens
             mm_positions = req_state.mm_positions
+            if not vllm_version_is("0.10.1.1"):
+                mm_hashes = req_state.mm_hashes
             for i, pos_info in enumerate(mm_positions):
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
@@ -963,13 +992,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     continue
 
                 start_idx = max(num_computed_tokens - start_pos, 0)
-                end_idx = min(
-                    num_computed_tokens - start_pos + num_scheduled_tokens,
-                    num_encoder_tokens)
-                assert start_idx < end_idx
-                assert req_id in self.encoder_cache
-                assert i in self.encoder_cache[req_id]
-                encoder_output = self.encoder_cache[req_id][i]
+                if vllm_version_is("0.10.1.1"):
+                    end_idx = min(
+                        num_computed_tokens - start_pos + num_scheduled_tokens,
+                        num_encoder_tokens)
+                    assert start_idx < end_idx
+                    assert req_id in self.encoder_cache
+                    assert i in self.encoder_cache[req_id]
+                    encoder_output = self.encoder_cache[req_id][i]
+                else:
+                    end_idx = min(
+                        num_computed_tokens - start_pos + num_scheduled_tokens,
+                        num_encoder_tokens,
+                    )
+                    assert start_idx < end_idx
+                    # TODO remove this assert after 0.10.1.1
+                    assert mm_hashes is not None
+                    mm_hash = mm_hashes[i]
+                    encoder_output = self.encoder_cache.get(mm_hash, None)
+                    assert encoder_output is not None,\
+                        f"Encoder cache miss for {mm_hash}."
                 if (is_embed := pos_info.is_embed) is not None:
                     is_embed = is_embed[start_idx:end_idx]
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index 380dde4..d4a6298 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -47,6 +47,8 @@ class CachedRequestState:
     prompt_token_ids: list[int]
     mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
+    # TODO: remove Optional after 0.10.1.1
+    mm_hashes: Optional[list[str]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     generator: Optional[torch.Generator]
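
For orientation, a minimal self-contained sketch (not part of the patch) of the two encoder-cache layouts the changes above switch between: on vLLM 0.10.1.1 the cache is keyed per request (req_id -> input_id -> tensor) and freed through free_encoder_input_ids, while on newer vLLM it is a flat mapping keyed by mm_hash and freed through free_encoder_mm_hashes. The helper free_encoder_entries and its keyword arguments are hypothetical names used only for illustration; they mirror the _update_states logic in model_runner_v1.py rather than reproducing it.

# Illustration only -- not part of the patch; names here are hypothetical.
from typing import Dict, Iterable, Tuple, Union

import torch

# 0.10.1.1 layout: req_id -> {input_id -> tensor}; newer layout: mm_hash -> tensor.
EncoderCache = Union[Dict[str, Dict[int, torch.Tensor]], Dict[str, torch.Tensor]]


def free_encoder_entries(
    encoder_cache: EncoderCache,
    *,
    is_v0_10_1_1: bool,
    free_encoder_input_ids: Iterable[Tuple[str, int]] = (),
    free_encoder_mm_hashes: Iterable[str] = (),
) -> None:
    """Drop freed encoder outputs from the cache (hypothetical helper)."""
    if is_v0_10_1_1:
        # Old layout: pop the input_id entry, then drop the per-request dict
        # once its last entry is gone.
        for req_id, input_id in free_encoder_input_ids:
            per_req = encoder_cache.get(req_id)
            if per_req is not None:
                per_req.pop(input_id, None)
                if not per_req:
                    encoder_cache.pop(req_id, None)
    else:
        # New layout: entries are keyed by mm_hash and shared across requests.
        for mm_hash in free_encoder_mm_hashes:
            encoder_cache.pop(mm_hash, None)


# Usage under the newer layout:
cache: Dict[str, torch.Tensor] = {"hash-a": torch.zeros(4, 8)}
free_encoder_entries(cache, is_v0_10_1_1=False, free_encoder_mm_hashes=["hash-a"])
assert "hash-a" not in cache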