diff --git a/tests/scheduler/test_scheduler.py b/tests/scheduler/test_scheduler.py index 6eddd4f..330cd27 100644 --- a/tests/scheduler/test_scheduler.py +++ b/tests/scheduler/test_scheduler.py @@ -1,8 +1,5 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. -# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py -# Copyright 2023 The vLLM team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# This file is a part of the vllm-ascend project. # from typing import Optional diff --git a/tests/singlecard/test_scheduler.py b/tests/singlecard/test_scheduler.py new file mode 100644 index 0000000..6eddd4f --- /dev/null +++ b/tests/singlecard/test_scheduler.py @@ -0,0 +1,396 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import Optional + +import pytest +import torch +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig +from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange +from vllm.sampling_params import SamplingParams +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec) +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.request import Request, RequestStatus +from vllm.v1.structured_output import StructuredOutputManager + +from vllm_ascend.core.scheduler import AscendScheduler + +EOS_TOKEN_ID = 50256 + + +def create_scheduler( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, +) -> AscendScheduler: + '''Create scheduler under test. + + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + :class:`Scheduler` instance + ''' + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_num_batched_tokens, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=10000, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(16, 1, 1, torch.float32, False)) + ], + ) + cache_config.num_gpu_blocks = 10000 + return AscendScheduler( + scheduler_config, + model_config, + cache_config, + lora_config=None, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests(num_requests: int, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt=None, + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=0, + ) + requests.append(request) + return requests + + +def test_add_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + + for i, request in enumerate(requests): + scheduler.add_request(request) + assert request.request_id in scheduler.requests + assert len(scheduler.waiting) == i + 1 + + +def test_finish_request(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_ABORTED) + assert request.request_id not in scheduler.requests + assert len(scheduler.waiting) == 9 - i + + +def test_get_num_unfinished_requests(): + scheduler = create_scheduler() + requests = create_requests(num_requests=10) + for request in requests: + scheduler.add_request(request) + + for i, request in enumerate(requests): + scheduler.finish_requests(request.request_id, + RequestStatus.FINISHED_STOPPED) + assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1 + + +@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [ + (None, None), + (True, 5), +]) +def test_schedule(enable_prefix_caching: Optional[bool], + prompt_logprobs: Optional[int]): + '''Test scheduling. + Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs + ''' + scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching) + requests = create_requests(num_requests=10, + prompt_logprobs=prompt_logprobs) + for request in requests: + scheduler.add_request(request) + + # Test initial scheduling + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == len(requests) + assert len(output.scheduled_cached_reqs) == 0 + assert len(output.finished_req_ids) == 0 + # Verify all requests are scheduled. + for req_id, num_tokens in output.num_scheduled_tokens.items(): + assert num_tokens == len(requests[int(req_id)].prompt_token_ids) + + # Verify requests moved from waiting to running + assert len(scheduler.waiting) == 0 + assert len(scheduler.running) == len(requests) + for i, request in enumerate(requests): + assert scheduler.running[i] == request + + +def test_stop_via_update_from_output(): + """Test stopping behavior through update_from_output""" + scheduler = create_scheduler() + + # Test case 1: Stop on EOS token + requests = create_requests(num_requests=2, max_tokens=10) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 1, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [], + requests[1].request_id: [10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[EOS_TOKEN_ID], + [10, + 11]], # First request hits EOS, second continues + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped, second continues + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID] + assert list(requests[1].output_token_ids) == [10, 11] + + # Test case 2: Stop on custom stop token + scheduler = create_scheduler() + requests = create_requests(num_requests=2, + max_tokens=10, + stop_token_ids=[42, 43]) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 2 + }, + total_num_scheduled_tokens=5, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 42], + requests[1].request_id: [13] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 42, 12], + [13, 14]], # First request hits stop token + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped on custom token + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_STOPPED + assert requests[0].stop_reason == 42 + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 42] + assert list(requests[1].output_token_ids) == [13, 14] + + # Test case 3: Stop on max tokens + scheduler = create_scheduler() + requests = create_requests(num_requests=2, max_tokens=2) + for req in requests: + req.num_computed_tokens = req.num_tokens + scheduler.requests[req.request_id] = req + scheduler.running.append(req) + scheduler.scheduled_req_ids.add(req.request_id) + + scheduler_output = SchedulerOutput(scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={ + requests[0].request_id: 3, + requests[1].request_id: 1 + }, + total_num_scheduled_tokens=4, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [10, 11], + requests[1].request_id: [] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in requests], + req_id_to_index={req.request_id: i + for i, req in enumerate(requests)}, + sampled_token_ids=[[10, 11, 12], + [13]], # First request exceeds max_tokens + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify first request stopped due to length + assert len(scheduler.running) == 1 + assert scheduler.running[0].request_id == requests[1].request_id + assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED + assert requests[0].request_id in scheduler.finished_req_ids + assert list(requests[0].output_token_ids) == [10, 11 + ] # Truncated to max_tokens + assert list(requests[1].output_token_ids) == [13] + + # Test case 4: Ignore EOS flag + scheduler = create_scheduler() + requests = create_requests(num_requests=1, max_tokens=10) + requests[0].sampling_params.ignore_eos = True + requests[0].num_computed_tokens = requests[0].num_tokens + scheduler.requests[requests[0].request_id] = requests[0] + scheduler.running.append(requests[0]) + scheduler.scheduled_req_ids.add(requests[0].request_id) + + scheduler_output = SchedulerOutput( + scheduled_new_reqs=[], + scheduled_cached_reqs=[], + num_scheduled_tokens={requests[0].request_id: 3}, + total_num_scheduled_tokens=3, + scheduled_encoder_inputs={}, + scheduled_spec_decode_tokens={ + requests[0].request_id: [EOS_TOKEN_ID, 10] + }, + num_common_prefix_blocks=0, + finished_req_ids=set(), + free_encoder_input_ids=[], + structured_output_request_ids={}, + grammar_bitmask=None) + + model_output = ModelRunnerOutput( + req_ids=[requests[0].request_id], + req_id_to_index={requests[0].request_id: 0}, + sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}) + + scheduler.update_from_output(scheduler_output, model_output) + + # Verify request continues past EOS + assert len(scheduler.running) == 1 + assert not requests[0].is_finished() + assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11] diff --git a/vllm_ascend/core/schedule_config.py b/vllm_ascend/core/schedule_config.py index 4861bfc..9194411 100644 --- a/vllm_ascend/core/schedule_config.py +++ b/vllm_ascend/core/schedule_config.py @@ -1,6 +1,5 @@ # # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# This file is a part of the vllm-ascend project. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# This file is a part of the vllm-ascend project. # from dataclasses import dataclass, fields diff --git a/vllm_ascend/core/scheduler.py b/vllm_ascend/core/scheduler.py index 348d7e7..22b503e 100644 --- a/vllm_ascend/core/scheduler.py +++ b/vllm_ascend/core/scheduler.py @@ -1,3 +1,19 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# from collections import deque from vllm.logger import init_logger