#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

import pytest
import torch
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from vllm_ascend.core.scheduler import AscendScheduler

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
) -> AscendScheduler:
    '''Create the scheduler under test.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batched_tokens: max number of tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default (None)
      long_prefill_token_threshold: token threshold above which a
                                    prefill counts as long
      disable_chunked_mm_input: whether to disable chunking of
                                multi-modal inputs

    Returns:
      :class:`AscendScheduler` instance
    '''
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=16,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    vllm_config = VllmConfig(scheduler_config=scheduler_config,
                             model_config=model_config,
                             cache_config=cache_config)
    kv_cache_config = KVCacheConfig(
        num_blocks=10000,  # A large number of blocks to hold all requests
        kv_cache_tensors=[KVCacheTensor(size=1024, shared_by=[1])],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(16, 1, 1, torch.float32, False,
                                               None))
        ],
    )
    cache_config.num_gpu_blocks = 10000
    return AscendScheduler(
        vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
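

# Illustrative sketch (not adapted from upstream): a freshly built scheduler
# is assumed to start with empty waiting/running queues, which the tests
# below rely on implicitly.
def test_create_scheduler_initial_state():
    scheduler = create_scheduler()
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert scheduler.get_num_unfinished_requests() == 0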


def create_requests(num_requests: int,
                    num_tokens: int = 10,
                    mm_positions: Optional[list[PlaceholderRange]] = None,
                    max_tokens: int = 16,
                    stop_token_ids: Optional[list[int]] = None,
                    prompt_logprobs: Optional[int] = None):
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
            pooling_params=None,
        )
        requests.append(request)
    return requests


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
    (None, None),
    (True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
                  prompt_logprobs: Optional[int]):
    '''Test scheduling.
    Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
    '''
    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
    requests = create_requests(num_requests=10,
                               prompt_logprobs=prompt_logprobs)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.scheduled_cached_reqs.num_reqs == 0
    assert len(output.finished_req_ids) == 0
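

# Illustrative sketch (an assumption, not adapted from upstream): with no
# requests queued, a scheduling step is expected to come back empty.
def test_schedule_no_requests():
    scheduler = create_scheduler()
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 0
    assert output.total_num_scheduled_tokens == 0
    assert len(scheduler.running) == 0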
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


def test_stop_via_update_from_output():
    """Test stopping behavior through update_from_output"""
    scheduler = create_scheduler()

    # Test case 1: Stop on EOS token
    requests = create_requests(num_requests=2, max_tokens=10)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 1,
                                           requests[1].request_id: 2
                                       },
                                       scheduled_spec_decode_tokens={},
                                       total_num_scheduled_tokens=3,
                                       scheduled_encoder_inputs={},
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[],
                                       structured_output_request_ids={},
                                       grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[EOS_TOKEN_ID],
                           [10, 11]],  # First request hits EOS, second continues
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped, second continues
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
    assert list(requests[1].output_token_ids) == [10, 11]

    # Test case 2: Stop on custom stop token
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2,
                               max_tokens=10,
                               stop_token_ids=[42, 43])
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 2
                                       },
                                       scheduled_spec_decode_tokens={},
                                       total_num_scheduled_tokens=5,
                                       scheduled_encoder_inputs={},
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[],
                                       structured_output_request_ids={},
                                       grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 42, 12],
                           [13, 14]],  # First request hits stop token
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped on custom token
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].stop_reason == 42
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 42]
    assert list(requests[1].output_token_ids) == [13, 14]
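
    # Note (descriptive, inferred from the asserts below): update_from_output
    # is expected to clip a request's sampled tokens once max_tokens is
    # reached; case 3 depends on that truncation.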
    # Test case 3: Stop on max tokens
    scheduler = create_scheduler()
    requests = create_requests(num_requests=2, max_tokens=2)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)
        scheduler.scheduled_req_ids.add(req.request_id)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 1
                                       },
                                       scheduled_spec_decode_tokens={},
                                       total_num_scheduled_tokens=4,
                                       scheduled_encoder_inputs={},
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[],
                                       structured_output_request_ids={},
                                       grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 11, 12],
                           [13]],  # First request exceeds max_tokens
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped due to length
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(
        requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
    assert list(requests[1].output_token_ids) == [13]

    # Test case 4: Ignore EOS flag
    scheduler = create_scheduler()
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
    scheduler.requests[requests[0].request_id] = requests[0]
    scheduler.running.append(requests[0])
    scheduler.scheduled_req_ids.add(requests[0].request_id)

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={requests[0].request_id: 3},
        scheduled_spec_decode_tokens={},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify request continues past EOS
    assert len(scheduler.running) == 1
    assert not requests[0].is_finished()
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]