### What this PR does / why we need it?
1. Update `CachedRequestState` to match the `NewRequestData` changes made in
https://github.com/vllm-project/vllm/pull/22570
2. Drop maintenance of vLLM v0.10.0 on the main branch
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with the existing tests.
- vLLM version: v0.10.0
- vLLM main:
92ff41abea
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
239 lines
9.4 KiB
Python
239 lines
9.4 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
|
#
|
|
import copy
|
|
|
|
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
|
from vllm.v1.request import RequestStatus
|
|
|
|
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
|
create_model_runner_output,
|
|
create_request, create_scheduler,
|
|
create_vllm_config)
|
|
|
|
|
|
def test_basic_lifecycle():
    """Drive one remote-prefill request from admission until it hits EOS.

    The request is created with ``do_remote_prefill=True`` and a prompt of
    two full blocks plus half a block. The test checks that it waits in
    ``WAITING_FOR_REMOTE_KVS`` with unhashed blocks, that reporting it via
    ``finished_recving`` lets the next ``schedule()`` run it, and that the
    scheduler is empty once EOS has been processed.
    """
    config = create_vllm_config()
    scheduler = create_scheduler(config)

    # Prompt length: two full blocks followed by half a block.
    block_size = config.cache_config.block_size
    full_block_count = 2
    prompt_len = int(block_size * (full_block_count + 0.5))
    free_blocks_at_start = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)

    request = create_request(
        request_id=1,
        num_tokens=prompt_len,
        do_remote_prefill=True,
    )
    scheduler.add_request(request)
    req_id = request.request_id

    def blocks_of_request():
        # Blocks currently held for this request by the first manager.
        manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
        return manager.req_to_blocks[req_id]

    # ---- STEP (1) ----
    # (1a) schedule(): the request must not be scheduled yet.
    step_output = scheduler.schedule()

    assert len(scheduler.running) == 0
    assert len(step_output.scheduled_new_reqs) == 0
    assert step_output.scheduled_cached_reqs.num_reqs == 0
    assert len(step_output.num_scheduled_tokens) == 0
    assert step_output.total_num_scheduled_tokens == 0

    # The request waits for remote KVs with nothing computed so far ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert request.status == RequestStatus.WAITING_FOR_REMOTE_KVS
    assert request.num_computed_tokens == 0

    # ... yet blocks were already taken from the pool, none of them hashed.
    pool = scheduler.kv_cache_manager.block_pool
    assert pool.free_block_queue.num_free_blocks < free_blocks_at_start
    assert len(pool.cached_block_hash_to_block) == 0
    for blk in blocks_of_request():
        assert blk._block_hash is None

    # (1b) forward(): nothing was scheduled, so the runner output is empty.
    runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (1c) update_from_output(): no outputs expected.
    core_outputs = scheduler.update_from_output(step_output, runner_output)
    assert not (core_outputs and core_outputs[0].outputs)

    # ---- STEP (2) ----
    # (2a) schedule(): still nothing to run.
    step_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0

    # (2b) forward(): the runner reports that the request finished receiving.
    runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import (  # type: ignore # noqa
        KVConnectorOutput)
    runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving=[req_id])

    # (2c) update_from_output(): the request is recorded as received but
    # remains in the waiting queue until the next schedule().
    scheduler.update_from_output(step_output, runner_output)
    assert len(scheduler.waiting) == 1
    assert req_id in scheduler.finished_recving_kv_req_ids

    # ---- STEP (3) ----
    # (3a) schedule(): now the request is actually scheduled.
    step_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    # Every block is referenced once; only the full blocks carry a hash.
    held_blocks = blocks_of_request()
    for blk in held_blocks:
        assert blk.ref_cnt == 1
    hashed_count = sum(1 for blk in held_blocks
                       if blk._block_hash is not None)
    assert hashed_count == full_block_count

    # The remainder of the prompt is scheduled in this very step.
    new_req = step_output.scheduled_new_reqs[0]
    scheduled_now = step_output.num_scheduled_tokens[req_id]
    already_computed = new_req.num_computed_tokens
    prompt_total = len(new_req.prompt_token_ids)
    assert scheduled_now == prompt_total - already_computed

    # (3b) execute_model() and (3c) update_from_output()
    runner_output = create_model_runner_output([request])
    scheduler.update_from_output(step_output, runner_output)

    # ---- STEP (4): hit EOS ----
    step_output = scheduler.schedule()
    runner_output = create_model_runner_output([request], use_eos=True)
    scheduler.update_from_output(step_output, runner_output)
    scheduler.schedule()

    assert_scheduler_empty(scheduler)
|
|
|
|
|
|
def test_no_spurious_prefix_caching():
    """Check that blocks awaiting remote KVs are not registered as cached.

    With P/D, blocks can stay allocated but uncomputed across several engine
    steps. After scheduling a remote-prefill request once, every block held
    for it must have a reference count of one and no block hash, so it cannot
    be picked up as a prefix-cache hit.
    """
    config = create_vllm_config()
    scheduler = create_scheduler(config)

    # Prompt covers two full blocks plus half a block.
    full_block_count = 2
    prompt_len = int(config.cache_config.block_size *
                     (full_block_count + 0.5))

    # The prompt consists solely of 1s: [1, 1, 1, ...].
    remote_req = create_request(
        request_id=1,
        num_tokens=prompt_len,
        do_remote_prefill=True,
        use_all_1s_for_prompt_tokens=True,
    )

    # One schedule/update round for the remote-prefill request; this must
    # not cause any of its blocks to become cached.
    scheduler.add_request(remote_req)
    step_output = scheduler.schedule()
    scheduler.update_from_output(step_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(scheduler.waiting) == 1

    manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    held_blocks = manager.req_to_blocks[remote_req.request_id]

    # Each block is held exactly once and carries no hash.
    for blk in held_blocks:
        assert blk.ref_cnt == 1
        assert blk._block_hash is None
|
|
|
|
|
|
def test_full_block_prompt():
    """Test that we handle a prompt that is an exact multiple of the block size.

    The remote-prefill request's prompt fills exactly
    ``NUM_EXTERNAL_FULL_BLOCKS`` blocks. After the KVs are reported as
    received, only the final prompt token is scheduled and no extra block is
    allocated. The request is then run to EOS and the scheduler must end up
    empty.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Exactly 2 full blocks (no partial block).
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore  # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving=[request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()

    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_TOKENS - 1)
    assert (scheduler_output.num_scheduled_tokens[request_id] == 1)

    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    # Feed the EOS output back to the scheduler; without this call the EOS
    # output built above is never applied (test_basic_lifecycle does the
    # same step with this call).
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()

    assert_scheduler_empty(scheduler)
|