diff --git a/.github/workflows/vllm_ascend_test_pr_light.yaml b/.github/workflows/vllm_ascend_test_pr_light.yaml
index 4e40c013..386de82c 100644
--- a/.github/workflows/vllm_ascend_test_pr_light.yaml
+++ b/.github/workflows/vllm_ascend_test_pr_light.yaml
@@ -139,7 +139,6 @@ jobs:
             --ignore tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
             --ignore tests/ut/kv_connector/test_remote_decode_lifecycle.py \
             --ignore tests/ut/kv_connector/test_llmdatadist_connector.py \
-            --ignore tests/ut/ops/test_linear.py \
             --ignore tests/ut/core/test_scheduler_dynamic_batch.py

       - name: Upload coverage to Codecov
diff --git a/tests/ut/ops/test_linear.py b/tests/ut/ops/test_linear.py
index 1b3a7268..c31033e6 100644
--- a/tests/ut/ops/test_linear.py
+++ b/tests/ut/ops/test_linear.py
@@ -99,7 +99,7 @@ class TestAscendRowParallelLinear(BaseLinearTest):
         ascend_config._ASCEND_CONFIG = MagicMock()
         ascend_config._ASCEND_CONFIG.oproj_tensor_parallel_size = 2
-        ascend_config._ASCEND_CONFIG.ascend_scheduler_config.enabled = False
+        ascend_config._ASCEND_CONFIG.recompute_scheduler_enable = False

         linear = AscendRowParallelLinear(
             input_size=16,
diff --git a/tests/ut/ops/test_vocab_parallel_embedding.py b/tests/ut/ops/test_vocab_parallel_embedding.py
index 37ea1af1..531df281 100644
--- a/tests/ut/ops/test_vocab_parallel_embedding.py
+++ b/tests/ut/ops/test_vocab_parallel_embedding.py
@@ -209,12 +209,7 @@ class TestAscendLogitsProcessor(unittest.TestCase):
                   return_value=torch.randn(1, self.vocab_size)),
             patch(
                 "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_gather",
-                return_value=torch.randn(1, self.vocab_size)),
-            patch(
-                "vllm_ascend.core.schedule_config.AscendSchedulerConfig.initialize_from_config",
-                return_value=MagicMock(max_num_batched_tokens=1000,
-                                       max_model_len=512,
-                                       enable_chunked_prefill=False))
+                return_value=torch.randn(1, self.vocab_size))
         ]

         for p in self.patches:
diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py
index 6463ed0b..d51981dc 100644
--- a/tests/ut/test_platform.py
+++ b/tests/ut/test_platform.py
@@ -32,7 +32,6 @@ class TestNPUPlatform(TestBase):
     def mock_vllm_ascend_config():
         mock_ascend_config = MagicMock()
         mock_ascend_config.torchair_graph_config.enabled = False
-        mock_ascend_config.ascend_scheduler_config.enabled = False
         mock_ascend_config.enable_shared_expert_dp = False
         return mock_ascend_config

diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index e5bb7e00..31bf2190 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -45,11 +45,6 @@ class AscendConfig:
         self.ascend_compilation_config = AscendCompilationConfig(
             **ascend_compilation_config)

-        ascend_scheduler_config = additional_config.get(
-            "ascend_scheduler_config", {})
-        self.ascend_scheduler_config = AscendSchedulerConfig(
-            ascend_scheduler_config)
-
         # Dump / PrecisionDebugger configuration
         dump_config_path = additional_config.get("dump_config", None)
         self.dump_config = DumpConfig(dump_config_path)
@@ -255,20 +250,6 @@ class TorchairGraphConfig:
         )


-class AscendSchedulerConfig:
-    """
-    Configuration Object for ascend_scheduler_config from additional_config
-    """
-
-    def __init__(self, ascend_scheduler_config: dict):
-        self.enabled = ascend_scheduler_config.get("enabled", False)
-        # Ascend scheduler is based on vllm v0 scheduler, so we should support
-        # all vllm v0 scheduler configs as well.
-        for k, v in ascend_scheduler_config.items():
-            if not hasattr(self, k):
-                setattr(self, k, v)
-
-
 class DumpConfig:
     """
     Configuration object for dump/PrecisionDebugger settings.
diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py
deleted file mode 100644
index 32d63cbc..00000000
--- a/vllm_ascend/core/schedule_config.py
+++ /dev/null
@@ -1,105 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-
-from dataclasses import dataclass, fields
-from typing import Type, Union
-
-from vllm.config import SchedulerConfig
-
-MAX_INT = 2147483647
-
-
-@dataclass
-class AscendSchedulerConfig(SchedulerConfig):
-    enable_chunked_prefill: bool = False
-    max_long_partial_prefills: int = 1
-    long_prefill_token_threshold: int = MAX_INT
-    policy: str = "fcfs"
-    scheduler_cls: Union[str, Type[object]] = (
-        "vllm_ascend.core.scheduler.AscendScheduler")
-    enable_pd_transfer: bool = False
-    decode_max_num_seqs: int = 0
-
-    @classmethod
-    def initialize_from_config(
-        cls,
-        vllm_scheduler_config: SchedulerConfig,
-        ascend_scheduler_config,
-    ):
-        scheduler_config = {
-            field.name: getattr(vllm_scheduler_config, field.name)
-            for field in fields(vllm_scheduler_config) if field.init
-        }
-        # Override default values into original SchedulerConfig
-        scheduler_config["enable_chunked_prefill"] = False
-        scheduler_config["max_long_partial_prefills"] = None
-        scheduler_config["long_prefill_token_threshold"] = None
-        scheduler_config["policy"] = "fcfs"
-        scheduler_config["scheduler_cls"] = (
-            "vllm_ascend.core.scheduler.AscendScheduler")
-        scheduler_config["enable_pd_transfer"] = False
-        scheduler_config["decode_max_num_seqs"] = 0
-        # Override params in original SchedulerConfig with params in ascend_scheduler_config
-        for k, _ in scheduler_config.items():
-            if hasattr(ascend_scheduler_config, k):
-                scheduler_config[k] = getattr(ascend_scheduler_config, k)
-        return cls(**scheduler_config)
-
-    def __post_init__(self, *args) -> None:
-        self.max_num_encoder_input_tokens = self.max_num_batched_tokens
-        self.encoder_cache_size = self.max_num_batched_tokens
-        self.chunked_prefill_enabled = self.enable_chunked_prefill
-        if (self.max_num_batched_tokens < self.max_model_len
-                and not self.chunked_prefill_enabled):
-            raise ValueError(
-                "Ascend scheduler is enabled without chunked prefill feature. "
-                f"Argument max_num_batched_tokens ({self.max_num_batched_tokens}) is "
-                f"smaller than max_model_len ({self.max_model_len}). "
-                "This effectively limits the maximum sequence length to "
-                "max_num_batched_tokens and makes vLLM reject longer "
-                "sequences. Please increase max_num_batched_tokens or "
-                "decrease max_model_len.")
-        # concurrent partial prefills. Default is 1 meaning not enabled.
-        if self.max_long_partial_prefills is None:
-            self.max_long_partial_prefills = 1
-            self.long_prefill_token_threshold = MAX_INT
-
-        if self.long_prefill_token_threshold is None or \
-            self.long_prefill_token_threshold <= 0:
-            if self.max_model_len is None:
-                self.long_prefill_token_threshold = MAX_INT
-            else:
-                self.long_prefill_token_threshold = \
-                    max(1, int(self.max_model_len * 0.04))
-
-        if self.max_long_partial_prefills < 0:
-            raise ValueError(
-                f"max_long_partial_prefills must be non-negative, but got "
-                f"{self.max_long_partial_prefills}")
-        if self.long_prefill_token_threshold < 0:
-            raise ValueError(
-                f"long_prefill_token_threshold must be non-negative, but got "
-                f"{self.long_prefill_token_threshold}")
-
-        if self.policy != "fcfs":
-            raise NotImplementedError(
-                f"currently AscendScheduler only supports fcfs policy, got {self.policy}"
-            )
-        if getattr(self, "scheduler_delay_factor", 0) > 0:
-            raise NotImplementedError(
-                "currently AscendScheduler doesn't support scheduler_delay_factor."
-            )
diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py
deleted file mode 100644
index acc7b8c5..00000000
--- a/vllm_ascend/core/scheduler.py
+++ /dev/null
@@ -1,592 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# This file is a part of the vllm-ascend project.
-#
-import time
-from collections import deque
-from typing import Iterable, Optional, Union
-
-from vllm.config import VllmConfig
-from vllm.distributed.kv_events import KVEventBatch
-from vllm.logger import logger
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
-from vllm.utils.math_utils import cdiv
-from vllm.v1.core.kv_cache_manager import KVCacheBlocks
-from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
-from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
-from vllm.v1.kv_cache_interface import KVCacheConfig
-from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.request import Request, RequestStatus
-from vllm.v1.structured_output import StructuredOutputManager
-
-
-class AscendScheduler(Scheduler):
-    """This Scheduler extends vllm's original v1 scheduler
-    with prefill-first scheduling strategy."""
-
-    def _initialize_common(self) -> None:
-        """Initialize common attributes shared across all versions."""
-        self.scheduled_req_ids: set[str] = set()
-        self.running: list[Request] = []
-        self.finished_prefill_reqs: deque[Request] = deque()
-
-        enable_pd_transfer = getattr(self.scheduler_config,
-                                     'enable_pd_transfer', False)
-        decode_max_num_seqs = getattr(self.scheduler_config,
-                                      'decode_max_num_seqs', 0)
-        self.phase = "" if not enable_pd_transfer else "prefill"
-        self.decode_max_num_running_reqs = max(self.max_num_running_reqs,
-                                               decode_max_num_seqs)
-
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        kv_cache_config: KVCacheConfig,
-        structured_output_manager: StructuredOutputManager,
-        block_size: Optional[int] = None,
-        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
-        include_finished_set: bool = False,
-        log_stats: bool = False,
-    ) -> None:
-        # Call the parent class's __init__ method
-        super().__init__(vllm_config, kv_cache_config,
-                         structured_output_manager, block_size, mm_registry,
-                         include_finished_set, log_stats)
-
-        # Initialize common attributes
-        self._initialize_common()
-
-    def schedule(self) -> SchedulerOutput:
-        if self.scheduler_config.enable_chunked_prefill:
-            return super().schedule()
-        scheduled_new_reqs: list[Request] = []
-        scheduled_resumed_reqs: list[Request] = []
-        scheduled_running_reqs: list[Request] = []
-        preempted_reqs: list[Request] = []
-
-        req_to_new_blocks: dict[str, KVCacheBlocks] = {}
-        num_scheduled_tokens: dict[str, int] = {}
-        token_budget = self.max_num_scheduled_tokens
-
-        # Encoder-related.
-        scheduled_encoder_inputs: dict[str, list[int]] = {}
-        encoder_budget = self.max_num_encoder_input_tokens
-
-        # Spec decode-related.
-        scheduled_spec_decode_tokens: dict[str, list[int]] = {}
-
-        # For logging.
-        scheduled_timestamp = time.monotonic()
-
-        # Record scheduled LoRA requests.
-        scheduled_loras: set[int] = set()
-
-        # Use a temporary deque to collect requests that need to be skipped
-        # and put back at the head of the waiting queue later
-        skipped_waiting_requests: deque[Request] = deque()
-
-        if self.phase == "prefill":
-            remaining_running_reqs = []
-            for request in self.running:
-                # move request has finished prefill to finished_prefill_reqs
-                if request.num_tokens > request.num_prompt_tokens:
-                    self.finished_prefill_reqs.append(request)
-                else:
-                    remaining_running_reqs.append(request)
-            self.running = remaining_running_reqs
-            # all request prefilled, change phase to decode
-            if not self.waiting and not self.running:
-                self.phase = "decode"
-        # Skip long prompt requests in prefill stage.
-        # long_prefill_budget is float('inf') if not use.
-        if self.vllm_config.scheduler_config.long_prefill_token_threshold == 0:
-            long_prefill_budget = float('inf')
-            long_prefill_token_threshold = float('inf')
-        else:
-            long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills
-            long_prefill_token_threshold = self.vllm_config.scheduler_config.long_prefill_token_threshold
-
-        # Schedule prefill requests first.
-        while self.waiting and token_budget > 0:
-            if len(self.running) == (self.decode_max_num_running_reqs
-                                     if self.phase == "decode" else
-                                     self.max_num_running_reqs):
-
-                break
-
-            request = self.waiting[0]
-
-            def skip_cur_request():
-                self.waiting.popleft()
-                skipped_waiting_requests.appendleft(request)
-
-            # P/D: skip request if still waiting for remote kvs.
-            if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
-                is_ready = self._update_waiting_for_remote_kv(request)
-                if is_ready:
-                    request.status = RequestStatus.WAITING
-                else:
-                    skip_cur_request()
-                    continue
-
-            # Check that adding the request still respects the max_loras
-            # constraint.
-            if (self.lora_config and request.lora_request and
-                (len(scheduled_loras) == self.lora_config.max_loras
-                 and request.lora_request.lora_int_id not in scheduled_loras)):
-                # Scheduling would exceed max_loras, skip.
-                skip_cur_request()
-                continue
-
-            num_external_computed_tokens = 0
-            load_kv_async = False
-
-            # Get already-cached tokens.
-            if request.num_computed_tokens == 0:
-                new_computed_blocks, num_new_local_computed_tokens = \
-                    self.kv_cache_manager.get_computed_blocks(
-                        request)
-
-                # Get externally-cached tokens if using a KVConnector.
-                if self.connector is not None:
-                    num_external_computed_tokens, load_kv_async = (
-                        self.connector.get_num_new_matched_tokens(
-                            request, num_new_local_computed_tokens))
-
-                # Total computed tokens (local + external).
-                num_computed_tokens = (num_new_local_computed_tokens +
-                                       num_external_computed_tokens)
-            else:
-                # P/D: skip checking prefix cache if loaded from remote kvs.
-                new_computed_blocks = (
-                    self.kv_cache_manager.create_empty_block_list())
-                num_new_local_computed_tokens = 0
-                num_computed_tokens = request.num_computed_tokens
-
-            encoder_inputs_to_schedule = None
-            new_encoder_budget = encoder_budget
-
-            # P/D: loading remote KV, do not allocate for new work.
-            if load_kv_async:
-                assert num_external_computed_tokens > 0
-                num_new_tokens = 0
-                blocks = None
-            # Number of tokens to be scheduled.
-            else:
-                prompt_limit = self._get_prompt_limit(request)
-                # We use `request.num_tokens` instead of
-                # `request.num_prompt_tokens` to consider the resumed
-                # requests, which have output tokens.
-                num_new_tokens = request.num_tokens - num_computed_tokens
-                max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
-                                         self.block_size)
-                prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
-
-                # Finish request that exceeds prompt_limit or kv cache size.
-                if num_new_tokens > prompt_limit:
-                    logger.warning(
-                        "Input prompt (%d tokens) is too long"
-                        " and exceeds limit of %d",
-                        num_new_tokens,
-                        prompt_limit,
-                    )
-                    request.status = RequestStatus.FINISHED_IGNORED
-                    self.finished_req_ids.add(  # type: ignore
-                        request.request_id)  # type: ignore
-                    self.waiting.popleft()
-                    continue
-
-                if num_new_tokens > token_budget:
-                    # Scheduling would exceed token_budget, skip.
-                    skip_cur_request()
-                    continue
-                assert num_new_tokens > 0
-                blocks = new_computed_blocks.blocks[0]
-
-            # Schedule encoder inputs.
-            if request.has_encoder_inputs:
-                (encoder_inputs_to_schedule, num_new_tokens,
-                 new_encoder_budget,
-                 _) = self._try_schedule_encoder_inputs(
-                     request, num_computed_tokens, num_new_tokens,
-                     encoder_budget)
-                if num_new_tokens == 0 or len(
-                        encoder_inputs_to_schedule) == 0:
-                    # The request cannot be scheduled.
-                    break
-
-            watermark = getattr(self.scheduler_config, "watermark", 0.01)
-            if not self._check_watermark_for_prefill(request, num_new_tokens,
-                                                     blocks, watermark):
-                # Scheduling would exceed watermark, skip.
-                skip_cur_request()
-                continue
-
-            if num_new_tokens > long_prefill_token_threshold \
-                and long_prefill_budget <= 0:
-                skip_cur_request()
-                continue
-
-            new_blocks = self.kv_cache_manager.allocate_slots(
-                request,
-                num_new_tokens + num_external_computed_tokens,
-                num_new_local_computed_tokens,
-                new_computed_blocks=new_computed_blocks,
-                num_lookahead_tokens=self.num_lookahead_tokens,
-                delay_cache_blocks=load_kv_async)
-            if new_blocks is None:
-                # The request cannot be scheduled.
-                break
-
-            # KVConnector: update internal state after allocation.
-            # This information is used to determine if a load is
-            # needed for this request.
-            if self.connector is not None:
-                self.connector.update_state_after_alloc(
-                    request,
-                    new_computed_blocks + new_blocks,
-                    num_external_computed_tokens,
-                )
-
-            self.waiting.popleft()
-            if load_kv_async:
-                # If loading async, allocate memory and put request
-                # into the WAITING_FOR_REMOTE_KV state.
-                skipped_waiting_requests.appendleft(request)
-                request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
-                continue
-
-            self.running.append(request)
-            if self.log_stats:
-                request.record_event(EngineCoreEventType.SCHEDULED,
-                                     scheduled_timestamp)
-            self.scheduled_req_ids.add(request.request_id)
-            # Check request status.
-            if request.status == RequestStatus.WAITING:
-                scheduled_new_reqs.append(request)
-            elif request.status == RequestStatus.PREEMPTED:
-                scheduled_resumed_reqs.append(request)
-            else:
-                raise RuntimeError(f"Invalid request status: {request.status}")
-
-            if self.lora_config and request.lora_request:
-                scheduled_loras.add(request.lora_request.lora_int_id)
-
-            req_to_new_blocks[
-                request.request_id] = self.kv_cache_manager.get_blocks(
-                    request.request_id)
-            # Update request info.
-            num_scheduled_tokens[request.request_id] = num_new_tokens
-            token_budget -= num_new_tokens
-            if num_new_tokens > long_prefill_token_threshold:
-                long_prefill_budget -= 1
-            request.status = RequestStatus.RUNNING
-            request.num_computed_tokens = num_computed_tokens
-            # Count the number of prefix cached tokens.
-            if request.num_cached_tokens < 0:
-                request.num_cached_tokens = num_computed_tokens
-
-            # Encoder-related.
-            if encoder_inputs_to_schedule:
-                scheduled_encoder_inputs[request.request_id] = (
-                    encoder_inputs_to_schedule)
-                # Allocate the encoder cache.
-                for i in encoder_inputs_to_schedule:
-                    self.encoder_cache_manager.allocate(request, i)
-                encoder_budget = new_encoder_budget
-
-        # Put back any skipped requests at the head of the waiting queue
-        if skipped_waiting_requests:
-            self.waiting.extendleft(skipped_waiting_requests)
-
-        if self.phase == "decode":
-            while len(
-                    self.running
-            ) < self.decode_max_num_running_reqs and self.finished_prefill_reqs:
-                request = self.finished_prefill_reqs.popleft()
-                self.running.append(request)
-
-        # If no prefill requests are scheduled,
-        # Schedule decode requests next.
-        if len(self.scheduled_req_ids) == 0:
-            req_index = 0
-            while req_index < len(self.running) and token_budget > 0:
-                request = self.running[req_index]
-                if request.request_id in self.scheduled_req_ids:
-                    # This request has already been scheduled.
-                    req_index += 1
-                    continue
-
-                num_new_tokens = (request.num_tokens_with_spec -
-                                  request.num_computed_tokens)
-                assert (request.num_tokens - request.num_computed_tokens) == 1
-                num_new_tokens = min(num_new_tokens, token_budget)
-                # Make sure the input position does not exceed the max model len.
-                # This is necessary when using spec decoding.
-                num_new_tokens = min(
-                    num_new_tokens,
-                    self.max_model_len - request.num_computed_tokens)
-
-                # Schedule encoder inputs.
-                encoder_inputs_to_schedule = None
-                new_encoder_budget = encoder_budget
-                if request.has_encoder_inputs:
-                    (encoder_inputs_to_schedule, num_new_tokens,
-                     new_encoder_budget) = self._try_schedule_encoder_inputs(
-                         request, request.num_computed_tokens, num_new_tokens,
-                         encoder_budget)
-
-                # Check that adding the request still respects the max_loras
-                # constraint.
-                if self.lora_config and request.lora_request and (
-                        len(scheduled_loras) == self.lora_config.max_loras
-                        and request.lora_request.lora_int_id
-                        not in scheduled_loras):
-                    # Scheduling would exceed max_loras, skip.
-                    num_new_tokens = 0
-
-                if num_new_tokens == 0:
-                    # The request cannot be scheduled because one of the following
-                    # reason:
-                    # 1. No new tokens to schedule. This may happen when PP>1 and
-                    #    we have already scheduled all prompt tokens but they are
-                    #    not finished yet.
-                    # 2. Adding the request exceeds the max_loras constraint.
-                    # NOTE(woosuk): Here, by doing `continue` instead of `break`,
-                    # we do not strictly follow the FCFS scheduling policy and
-                    # allow the lower-priority requests to be scheduled.
-                    req_index += 1
-                    continue
-
-                while True:
-                    new_blocks = self.kv_cache_manager.allocate_slots(
-                        request,
-                        num_new_tokens,
-                        num_lookahead_tokens=self.num_lookahead_tokens)
-                    if new_blocks is None:
-                        # The request cannot be scheduled.
-                        # Preempt the lowest-priority request.
-                        preempted_req = self.running.pop()
-                        self.kv_cache_manager.free(preempted_req)
-                        preempted_req.status = RequestStatus.PREEMPTED
-                        preempted_req.num_computed_tokens = 0
-                        if self.log_stats:
-                            preempted_req.record_event(
-                                EngineCoreEventType.PREEMPTED,
-                                scheduled_timestamp)
-                        self.waiting.appendleft(preempted_req)
-                        preempted_reqs.append(preempted_req)
-                        if preempted_req == request:
-                            # No more request to preempt.
-                            can_schedule = False
-                            break
-                    else:
-                        # The request can be scheduled.
-                        can_schedule = True
-                        break
-                if not can_schedule:
-                    break
-                assert new_blocks is not None
-
-                # Schedule the request.
-                scheduled_running_reqs.append(request)
-                self.scheduled_req_ids.add(request.request_id)
-                req_to_new_blocks[request.request_id] = new_blocks
-                num_scheduled_tokens[request.request_id] = num_new_tokens
-                token_budget -= num_new_tokens
-                req_index += 1
-
-                # Speculative decode related.
-                if request.spec_token_ids:
-                    num_scheduled_spec_tokens = (num_new_tokens +
-                                                 request.num_computed_tokens -
-                                                 request.num_tokens)
-                    if num_scheduled_spec_tokens > 0:
-                        # Trim spec_token_ids list to num_scheduled_spec_tokens.
-                        del request.spec_token_ids[num_scheduled_spec_tokens:]
-                        scheduled_spec_decode_tokens[request.request_id] = (
-                            request.spec_token_ids)
-
-                # Encoder-related.
-                if encoder_inputs_to_schedule:
-                    scheduled_encoder_inputs[request.request_id] = (
-                        encoder_inputs_to_schedule)
-                    # Allocate the encoder cache.
-                    for i in encoder_inputs_to_schedule:
-                        self.encoder_cache_manager.allocate(request, i)
-                    encoder_budget = new_encoder_budget
-
-                # Record scheduled LoRA requests.
-                if self.lora_config and request.lora_request:
-                    scheduled_loras.add(request.lora_request.lora_int_id)
-
-        # Check if the scheduling constraints are satisfied.
-        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
-        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
-        assert token_budget >= 0
-        assert len(
-            self.running
-        ) <= self.decode_max_num_running_reqs if self.phase == "decode" else self.max_num_running_reqs
-        assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len(
-            scheduled_running_reqs) <= len(self.running)
-
-        # Get the longest common prefix among all requests in the running queue.
-        # This can be potentially used for cascade attention.
-        num_common_prefix_blocks = [0] * len(
-            self.kv_cache_config.kv_cache_groups)
-        if self.running:
-            any_request = self.running[0]
-            num_common_prefix_blocks = (
-                self.kv_cache_manager.get_num_common_prefix_blocks(
-                    any_request.request_id))
-
-        # Construct the scheduler output.
-        new_reqs_data = [
-            NewRequestData.from_request(
-                req, req_to_new_blocks[req.request_id].get_block_ids())
-            for req in scheduled_new_reqs
-        ]
-
-        cached_reqs_data = self._make_cached_request_data(
-            scheduled_running_reqs, scheduled_resumed_reqs,
-            num_scheduled_tokens, scheduled_spec_decode_tokens,
-            req_to_new_blocks)
-        scheduled_cached_reqs = cached_reqs_data
-        scheduler_output = SchedulerOutput(
-            scheduled_new_reqs=new_reqs_data,
-            scheduled_cached_reqs=scheduled_cached_reqs,
-            num_scheduled_tokens=num_scheduled_tokens,
-            total_num_scheduled_tokens=total_num_scheduled_tokens,
-            scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
-            scheduled_encoder_inputs=scheduled_encoder_inputs,
-            num_common_prefix_blocks=num_common_prefix_blocks,
-            # finished_req_ids is an existing state in the scheduler,
-            # instead of being newly scheduled in this step.
-            # It contains the request IDs that are finished in between
-            # the previous and the current steps.
-            finished_req_ids=self.finished_req_ids,  # type: ignore
-            free_encoder_mm_hashes=self.encoder_cache_manager.
-            get_freed_mm_hashes(),
-        )
-        # NOTE(Kuntai): this function is designed for multiple purposes:
-        # 1. Plan the KV cache store
-        # 2. Wrap up all the KV cache load / save ops into an opaque object
-        # 3. Clear the internal states of the connector
-        if self.connector is not None:
-            meta = self.connector.build_connector_meta(scheduler_output)
-            scheduler_output.kv_connector_metadata = meta
-
-        events = self.kv_cache_manager.take_events()
-        if events:
-            batch = KVEventBatch(ts=time.time(), events=events)
-            self.kv_event_publisher.publish(batch)
-
-        # Advance the number of computed tokens for the request AFTER
-        # the request is scheduled.
-        # 1. The scheduler_output of the current step has to include the
-        #    original number of scheduled tokens to determine input IDs.
-        # 2. Advance the number of computed tokens here allowing us to
-        #    schedule the prefill request again immediately in the next
-        #    scheduling step.
-        # 3. If some tokens (e.g. spec tokens) are rejected later, the number of
-        #    computed tokens will be adjusted in update_from_output.
-        for req_id, num_scheduled_token in num_scheduled_tokens.items():
-            self.requests[req_id].num_computed_tokens += num_scheduled_token
-
-        self.finished_req_ids = set()  # type: ignore
-        return scheduler_output
-
-    def _check_watermark_for_prefill(self,
-                                     request,
-                                     num_new_tokens,
-                                     computed_blocks,
-                                     watermark=0.01):
-        computed_blocks = computed_blocks or []
-        watermark_blocks = self.kv_cache_config.num_blocks * watermark
-        num_computed_tokens = (request.num_computed_tokens +
-                               len(computed_blocks) * self.block_size)
-        num_required_blocks = cdiv(num_new_tokens + num_computed_tokens,
-                                   self.block_size)
-        req_blocks = self.kv_cache_manager.coordinator.get_blocks(
-            request.request_id)
-        num_new_blocks = (num_required_blocks - len(req_blocks[0]) -
-                          len(computed_blocks))
-        num_evictable_computed_blocks = sum(1 for blk in computed_blocks
-                                            if blk.ref_cnt == 0)
-        # If number of free blocks is less than water mark after allocating, don't allocate.
-        if (self.kv_cache_manager.block_pool.get_num_free_blocks() -
-                num_evictable_computed_blocks -
-                num_new_blocks) < watermark_blocks:
-            return False
-        return True
-
-    def _get_prompt_limit(self, request: Request) -> int:
-        if (self.scheduler_config.enable_chunked_prefill
-                and not self.scheduler_config.is_multi_step):
-            prompt_limit = self.vllm_config.model_config.max_model_len
-        else:
-            prompt_limit = min(
-                self.vllm_config.model_config.max_model_len,
-                self.scheduler_config.max_num_batched_tokens,
-            )
-
-        # Model is fine tuned with long context. Return the fine tuned max_len.
-        if request.lora_request and request.lora_request.long_lora_max_len:
-            assert prompt_limit <= request.lora_request.long_lora_max_len
-            return request.lora_request.long_lora_max_len
-        else:
-            return prompt_limit
-
-    def finish_requests(
-        self,
-        request_ids: Union[str, Iterable[str]],
-        finished_status: RequestStatus,
-    ) -> None:
-        """Handles the finish signal from outside the scheduler.
-
-        For example, the API server can abort a request when the client
-        disconnects.
-        """
-        for req_id in request_ids:
-            request = self.requests.get(req_id)
-            if request is None:
-                # Invalid request ID.
-                continue
-            if request.status == RequestStatus.RUNNING:
-                self.scheduled_req_ids.discard(request.request_id)
-        super().finish_requests(request_ids, finished_status)
-
-    def update_from_output(
-        self,
-        scheduler_output: SchedulerOutput,
-        model_runner_output: ModelRunnerOutput,
-    ) -> EngineCoreOutputs:
-        num_scheduled_tokens = scheduler_output.num_scheduled_tokens
-
-        # NOTE(woosuk): As len(self.running) can be up to 1K or more, the below
-        # loop can be a performance bottleneck. We should do our best to avoid
-        # expensive operations inside the loop.
-        for request in self.running:
-            req_id = request.request_id
-            num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
-            if num_tokens_scheduled == 0:
-                # The request was not scheduled in this step.
-                continue
-            if req_id in self.scheduled_req_ids:
-                self.scheduled_req_ids.remove(req_id)
-
-        return super().update_from_output(scheduler_output,
-                                          model_runner_output)
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 5ce3a2b2..9958d06f 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -30,6 +30,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
                                        init_ascend_config)
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                         delete_torchair_cache_file)
+from vllm_ascend.utils import refresh_block_size

 # isort: off
 from vllm_ascend.utils import (
@@ -160,7 +161,6 @@ class NPUPlatform(Platform):
         model_config = vllm_config.model_config
         parallel_config = vllm_config.parallel_config
         cache_config = vllm_config.cache_config
-        ascend_scheduler_config = ascend_config.ascend_scheduler_config
         ascend_compilation_config = ascend_config.ascend_compilation_config
         if ascend_compilation_config:
             vllm_config.additional_config.setdefault(
@@ -307,38 +307,13 @@ class NPUPlatform(Platform):
         else:
             parallel_config.worker_cls = "vllm_ascend.worker.worker_v1.NPUWorker"

-        if cache_config:
-            if cache_config.block_size is None:
-                cache_config.block_size = 128
-
-            if cache_config.enable_prefix_caching or \
-                not ascend_scheduler_config.enabled or \
-                getattr(ascend_scheduler_config, "enable_chunked_prefill", False):
-                logger.warning(
-                    "If chunked prefill or prefix caching is enabled, block size must be set to 128."
-                )
-                origin_block_size = cache_config.block_size
-                cache_config.block_size = 128
-                # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
-                if model_config and model_config.hf_config.model_type == "qwen3_next":
-                    logger.warning(
-                        "When running qwen3-next model, block_size needs to be restored to its original value."
-                    )
-                    cache_config.block_size = origin_block_size
+        refresh_block_size(vllm_config)

         # Activate custom ops for v1, except on 310P
         if get_ascend_device_type() != AscendDeviceType._310P:
             compilation_config.custom_ops = ["all"]

-        # If ascend_scheduler_config is enabled,
-        # extents original scheduler_config to use AscendScheduler.
-        if ascend_config.ascend_scheduler_config.enabled:
-            from vllm_ascend.core.schedule_config import AscendSchedulerConfig
-            ascend_scheduler_config = AscendSchedulerConfig.initialize_from_config(
-                vllm_config.scheduler_config,
-                ascend_config.ascend_scheduler_config)
-            vllm_config.scheduler_config = ascend_scheduler_config
-        elif ascend_config.recompute_scheduler_enable:
+        if ascend_config.recompute_scheduler_enable:
             from vllm_ascend.core.recompute_schedule_config import \
                 RecomputeSchedulerConfig
             recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
diff --git a/vllm_ascend/profiling_config.py b/vllm_ascend/profiling_config.py
index b6825933..8e0dfadf 100644
--- a/vllm_ascend/profiling_config.py
+++ b/vllm_ascend/profiling_config.py
@@ -44,11 +44,6 @@ SERVICE_PROFILING_SYMBOLS_YAML = """
   handler: msserviceprofiler.vllm_profiler.vllm_v1.batch_hookers:schedule
   name: batchFrameworkProcessing

-- symbol: vllm_ascend.core.scheduler:AscendScheduler.schedule
-  min_version: "0.9.1"
-  handler: msserviceprofiler.vllm_profiler.vllm_v1.batch_hookers:schedule
-  name: batchFrameworkProcessing
-
 - symbol: vllm.v1.core.sched.scheduler:Scheduler._free_request
   min_version: "0.9.1"
   handler: msserviceprofiler.vllm_profiler.vllm_v1.batch_hookers:free_request
diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py
index 16fcb385..4afa65e1 100644
--- a/vllm_ascend/torchair/torchair_attention.py
+++ b/vllm_ascend/torchair/torchair_attention.py
@@ -451,8 +451,7 @@ class AscendAttentionTorchairBackendImpl(AttentionImpl):
         else:
             raise NotImplementedError(
                 "Torchair graph mode with non-MLA attention backend is still experimental."
-                "v1 scheduler(chunked prefill) is not supported at this moment. Please"
-                "setting 'ascend_scheduler_config':{'enabled':true} in additional_config"
-                "to use ascend scheduler.")
+                " The v1 scheduler (chunked prefill) is not supported at this moment."
+            )

         return output.view(num_tokens, self.hidden_size)
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index ae5944d4..85031bf6 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1041,3 +1041,29 @@ def get_flashcomm2_reorgnized_batch_ids(global_tp_size) -> list[list[int]]:
         reorgnized_batch_ids.append(ranks)

     return reorgnized_batch_ids
+
+
+def refresh_block_size(vllm_config):
+    """
+    Refresh the block size in the cache config.
+    """
+    cache_config = vllm_config.cache_config
+    scheduler_config = vllm_config.scheduler_config
+    model_config = vllm_config.model_config
+
+    if not cache_config:
+        return
+
+    if cache_config.block_size is None:
+        cache_config.block_size = 128
+
+    if not scheduler_config or not model_config:
+        return
+
+    # TODO(MengqingCao): Remove the model_type check, after resolving the hidden error in get_kv_cache_groups.
+    if model_config.hf_config.model_type != "qwen3_next" and cache_config.block_size != 128:
+        if cache_config.enable_prefix_caching or scheduler_config.enable_chunked_prefill:
+            logger.info(
+                "Prefix caching or chunked prefill is enabled; setting block size to 128."
+            )
+            cache_config.block_size = 128
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 3558dd42..6ef7bab1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -332,10 +332,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):

         # Ascend-specific configurations
         self.ascend_config = get_ascend_config()
-        if self.ascend_config.ascend_scheduler_config.enabled:
-            self.chunked_prefill_enabled = self.scheduler_config.enable_chunked_prefill
-        else:
-            self.chunked_prefill_enabled = True
         self.weight_prefetch_method = WeightPrefetchMethod(
             self.ascend_config.weight_prefetch_config)
         # Dump / PrecisionDebugger configuration now comes from AscendConfig
@@ -1932,7 +1928,6 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):

     def _build_attn_state(self, num_reqs, num_scheduled_tokens,
                           num_valid_tokens):
-        ascend_config = get_ascend_config()
         if np.array_equal(self.seq_lens_np[:num_reqs], num_scheduled_tokens):
             attn_state = AscendAttentionState.PrefillNoCache
         # We assume it is the decode stage, where prefill occurs but only one token is not hit in cache.
@@ -1949,7 +1944,7 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin):
             else:
                 attn_state = AscendAttentionState.ChunkedPrefill
         # splitfuse
-        elif not ascend_config.ascend_scheduler_config.enabled or self.chunked_prefill_enabled:
+        elif self.scheduler_config.enable_chunked_prefill:
            attn_state = AscendAttentionState.ChunkedPrefill
         else:
             attn_state = AscendAttentionState.PrefillCacheHit
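
Note on the last hunk: the old splitfuse branch chose `ChunkedPrefill` whenever the Ascend scheduler was disabled *or* `chunked_prefill_enabled` was set; with the Ascend scheduler deleted, the choice reduces to the vLLM scheduler config alone. A self-contained sketch of the reduced branch (the enum below is a stand-in for `AscendAttentionState`, not the real class):

```python
# Sketch only -- illustrates the simplified splitfuse decision in
# _build_attn_state after this change. AttnState is a stand-in enum.
from enum import Enum


class AttnState(Enum):
    ChunkedPrefill = "chunked_prefill"
    PrefillCacheHit = "prefill_cache_hit"


def splitfuse_attn_state(enable_chunked_prefill: bool) -> AttnState:
    # Previously this also consulted ascend_scheduler_config.enabled and the
    # derived chunked_prefill_enabled flag; now only the scheduler config decides.
    return (AttnState.ChunkedPrefill
            if enable_chunked_prefill else AttnState.PrefillCacheHit)


assert splitfuse_attn_state(True) is AttnState.ChunkedPrefill
assert splitfuse_attn_state(False) is AttnState.PrefillCacheHit
```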
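The block-size handling formerly inlined in `NPUPlatform.check_and_update_config` now lives in `vllm_ascend.utils.refresh_block_size`. A minimal sketch of unit tests for the new helper, in the MagicMock style used elsewhere under `tests/ut` (the test module and test names here are hypothetical and not part of this diff):

```python
# Hypothetical tests/ut/test_utils.py additions -- a sketch, not part of this PR.
from types import SimpleNamespace
from unittest.mock import MagicMock

from vllm_ascend.utils import refresh_block_size


def test_block_size_defaults_to_128_when_unset():
    vllm_config = MagicMock()
    vllm_config.cache_config = SimpleNamespace(block_size=None,
                                               enable_prefix_caching=False)
    vllm_config.scheduler_config = SimpleNamespace(enable_chunked_prefill=False)
    vllm_config.model_config.hf_config.model_type = "qwen2"

    refresh_block_size(vllm_config)
    assert vllm_config.cache_config.block_size == 128


def test_chunked_prefill_forces_block_size_to_128():
    vllm_config = MagicMock()
    vllm_config.cache_config = SimpleNamespace(block_size=64,
                                               enable_prefix_caching=False)
    vllm_config.scheduler_config = SimpleNamespace(enable_chunked_prefill=True)
    vllm_config.model_config.hf_config.model_type = "qwen2"

    refresh_block_size(vllm_config)
    assert vllm_config.cache_config.block_size == 128


def test_qwen3_next_keeps_configured_block_size():
    # qwen3_next is exempted until the hidden error in get_kv_cache_groups
    # is resolved (see the TODO in refresh_block_size).
    vllm_config = MagicMock()
    vllm_config.cache_config = SimpleNamespace(block_size=64,
                                               enable_prefix_caching=True)
    vllm_config.scheduler_config = SimpleNamespace(enable_chunked_prefill=True)
    vllm_config.model_config.hf_config.model_type = "qwen3_next"

    refresh_block_size(vllm_config)
    assert vllm_config.cache_config.block_size == 64
```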
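For users, this removal means `additional_config` no longer reads an `ascend_scheduler_config` section: prefill-first scheduling via `AscendScheduler` is gone, and the stock vLLM v1 scheduler (with chunked prefill) is always used. A hedged before/after sketch of an engine launch, assuming `recompute_scheduler_enable` remains a top-level `additional_config` key as the `platform.py` hunk suggests (the model name is purely illustrative):

```python
from vllm import LLM

# Before this change (the scheduler section is no longer read):
# llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",
#           additional_config={"ascend_scheduler_config": {"enabled": True}})

# After this change: no scheduler section; the remaining Ascend knobs stay top-level.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct",
          additional_config={"recompute_scheduler_enable": False})
```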