Disaggregate prefill for kv cache register style (#950)
### What this PR does / why we need it?
This PR adopts `LLMDataDist` for kv cache registration and implements
`pull_blocks`-style disaggregated prefill. The interface implementation
mainly follows the design of the NIXL PR:
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
This PR can be tested with the following steps:
- Generate the rank table for all machines.
- Execute `toy_proxy.py` to launch the disaggregate prefill proxy server,
specifying the prefill IP and port and the decode IP and port.
- Run the prefill server and the decode server.
- Send a request to the disaggregate prefill proxy (a minimal sketch follows).
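
For the last step, a minimal sketch of sending one request through the proxy, assuming `toy_proxy.py` exposes the usual OpenAI-compatible completions endpoint; the host, port, and model name below are placeholders, not values fixed by this PR:

```python
# Minimal smoke test against the disaggregate prefill proxy.
# Host, port, and model are assumptions; adjust to your deployment.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",  # placeholder model name
        "prompt": "The quick brown fox",
        "max_tokens": 32,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```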
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main: 8d0a01a5f2
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
@@ -251,7 +251,10 @@ class TestAscendAttentionBackendImpl(TestBase):
         query = torch.randn(10, 8 * 64)
         key = torch.randn(10, 8 * 64)
         value = torch.randn(10, 8 * 64)
-        kv_cache = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
+        k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
+        v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
+        kv_cache = [k_cache, v_cache]
+        ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
 
         metadata = MagicMock()
         metadata.num_actual_tokens = torch.randn(10, 8 * 64)
@@ -261,7 +264,7 @@ class TestAscendAttentionBackendImpl(TestBase):
         metadata.query_lens = torch.randn(10, 8 * 64)
         layer = self.layer
         layer.quant_method = MagicMock()
-        layer.quant_method.apply.return_value = kv_cache
+        layer.quant_method.apply.return_value = ret_value
 
         output = self.impl.forward(layer,
                                    query,
tests/ut/kv_connector/test_llmdatadist_connector.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.

from tests.ut.kv_connector.utils import (create_request, create_scheduler,
                                         create_vllm_config)
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \
    LLMDataDistCMgrConnectorMetadata


def test_basic_interface():
    """Unit test for basic LLMDataDistCMgrConnector interface functionality."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)
    request_id = request.request_id

    scheduler.add_request(request)

    # Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata.
    scheduler_output = scheduler.schedule()
    kv_connector_metadata = scheduler_output.kv_connector_metadata
    assert kv_connector_metadata is not None
    assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata)

    assert len(kv_connector_metadata.requests) == 1
    assert request_id in kv_connector_metadata.requests
    req_meta = kv_connector_metadata.requests[request_id]

    for block_id, block in zip(
            req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
            single_type_managers[0].req_to_blocks[request_id]):
        assert block_id == block.block_id
tests/ut/kv_connector/test_remote_decode_lifecycle.py (new file, 163 lines)
@@ -0,0 +1,163 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus

from tests.ut.kv_connector.utils import (assert_scheduler_empty,
                                         create_model_runner_output,
                                         create_request, create_scheduler,
                                         create_vllm_config)


def test_basic_lifecycle():
    """Test lifecycle of a Remote Decode request."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(request_id=1,
                             max_tokens=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_decode=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)

    # Ensure the request is finished after 1 token.
    assert request.is_finished()
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    output = engine_core_outputs[0].outputs[0]
    assert output.finish_reason == FinishReason.LENGTH
    assert output.kv_transfer_params is not None

    # Request freed in Scheduler ...
    assert request_id in scheduler.finished_req_ids
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0

    # ... but blocks should not be freed.
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1

    # STEP (2):
    # (2a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 1
    assert request_id in scheduler_output.finished_req_ids
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (2b): execute_model()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (2c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (3): Finished sending.
    # (3a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (3b): execute_model()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_sending = [request_id]

    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm we do not have any memory leaks after req lifecycle.
    assert_scheduler_empty(scheduler)


def test_prefix_cache_lifecycle():
    """Test that remote decode params still work with a prefix cache hit."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 3
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)

    scheduler.add_request(request_remote_a)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote_a],
                                                     use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    #####################
    # Actual Test: confirm we send all blocks.

    # Step (1): Send the KV Transfer.
    NUM_EXTERNAL_FULL_BLOCKS -= 1
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(request_id=1,
                                    num_tokens=NUM_TOKENS,
                                    do_remote_decode=True)

    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params
    # Ensure we send all block ids, even if there is a cache hit.
    assert (len(
        kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
                                                    1))

    # STEP (2): Ensure it is freed.
    scheduler_output = scheduler.schedule()
    scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_sending = [request_remote.request_id]
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
tests/ut/kv_connector/test_remote_prefill_lifecycle.py (new file, 248 lines)
@@ -0,0 +1,248 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy

from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus

from tests.ut.kv_connector.utils import (assert_scheduler_empty,
                                         create_model_runner_output,
                                         create_request, create_scheduler,
                                         create_vllm_config)
from vllm_ascend.utils import vllm_version_is


def test_basic_lifecycle():
    """Test lifecycle of a remote prefill."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    START_FREE_BLOCK_QUEUE_SIZE = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1):
    # (1a): schedule()
    scheduler_output = scheduler.schedule()

    # Nothing running and empty scheduler output.
    assert len(scheduler.running) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    if vllm_version_is("0.9.1"):
        assert len(scheduler_output.scheduled_cached_reqs) == 0
    else:
        assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler_output.num_scheduled_tokens) == 0
    assert scheduler_output.total_num_scheduled_tokens == 0

    # Req waiting for KVs with no computed/scheduled toks ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
    assert (request.num_computed_tokens == 0)

    # ... but should have (uncached) blocks allocated to it.
    block_pool = scheduler.kv_cache_manager.block_pool
    assert (block_pool.free_block_queue.num_free_blocks
            < START_FREE_BLOCK_QUEUE_SIZE)
    assert len(block_pool.cached_block_hash_to_block) == 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block._block_hash is None

    # (1b): forward()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert not engine_core_outputs or not engine_core_outputs[0].outputs

    # STEP (2):
    # (2a): schedule(): nothing happens!
    scheduler_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0

    # (2b): forward(): request finishes recv.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_recving = [request_id]

    # (2c): update_from_output():
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3):
    # (3a): schedule(): this should actually schedule.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    # Confirm the blocks are actually allocated.
    num_hashed_blocks = 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += (1 if block._block_hash is not None else 0)
    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS

    # Confirm the rest of the prompt is scheduled in this step.
    scheduled_req = scheduler_output.scheduled_new_reqs[0]
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
    num_computed_tokens = scheduled_req.num_computed_tokens
    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
    assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)

    # (3b): execute_model()
    model_runner_output = create_model_runner_output([request])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    scheduler.schedule()

    if vllm_version_is("0.9.1"):
        outputs = engine_core_outputs[0].outputs
        assert len(outputs) == 1
        output = outputs[0]
        assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)


def test_no_spurious_prefix_caching():
    """
    With P/D, blocks can be allocated but uncomputed for
    multiple engine steps. This test confirms that we do
    not accidentally have cache hits against uncomputed
    blocks.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 and a half full external blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    # This request has a prompt like [1, 1, 1, 1, 1, ...]
    request_remote = create_request(
        request_id=1,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
        use_all_1s_for_prompt_tokens=True,
    )

    # Schedule the remote prefill request. This should not
    # cause any blocks to be cached.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(scheduler.waiting) == 1

    remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_remote.request_id]

    # Remote blocks should not be cached.
    for block in remote_blocks:
        assert block.ref_cnt == 1
        assert block._block_hash is None


def test_full_block_prompt():
    """Test that we handle a prompt that is the full block size."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks, no partial block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    model_runner_output.finished_recving = [request_id]
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()

    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_TOKENS - 1)
    assert (scheduler_output.num_scheduled_tokens[request_id] == 1)

    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    scheduler.schedule()

    if vllm_version_is("0.9.1"):
        outputs = engine_core_outputs[0].outputs
        assert len(outputs) == 1
        output = outputs[0]
        assert output.finish_reason == FinishReason.STOP
    assert_scheduler_empty(scheduler)
tests/ut/kv_connector/utils.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.

import os
from typing import Any, Optional

import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
                         ModelConfig, SchedulerConfig, VllmConfig)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager

from vllm_ascend.utils import vllm_version_is

EOS_TOKEN_ID = 50256
os.environ["VLLM_USE_V1"] = "1"


def assert_scheduler_empty(scheduler: Scheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks."""
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    assert len(scheduler.finished_recving_kv_req_ids) == 0

    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager.
    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
               req_to_blocks) == 0
    assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
    assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
               num_cached_block) == 0
    num_free_blocks = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
    assert num_free_blocks == (
        scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in scheduler.kv_cache_manager.block_pool.blocks:
        assert block.ref_cnt == 0


def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 1024,
    block_size: int = 128,
) -> VllmConfig:
    """Initialize VllmConfig for testing."""
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=True,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="LLMDataDistCMgrConnector",
        kv_role="kv_both",
        kv_connector_module_path=
        "vllm_ascend.distributed.llmdatadist_c_mgr_connector")
    return VllmConfig(scheduler_config=scheduler_config,
                      model_config=model_config,
                      cache_config=cache_config,
                      kv_transfer_config=kv_transfer_config,
                      device_config=DeviceConfig("cpu"))


def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Initialize Scheduler for testing."""
    block_size = vllm_config.cache_config.block_size
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float16,
                                               False))
        ],
    )
    vllm_config.cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_request(
    request_id: int,
    num_tokens: int = 10,
    max_tokens: int = 128,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    use_all_1s_for_prompt_tokens: bool = False,
    num_remote_blocks: int = 3,
) -> Request:
    """Make a dummy request for testing."""

    kv_transfer_params: Optional[dict[str, Any]] = None

    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = dict(do_remote_prefill=False,
                                  do_remote_decode=True)
    elif do_remote_prefill:
        kv_transfer_params = dict(do_remote_prefill=True,
                                  do_remote_decode=False,
                                  remote_engine_id="my-engine-id",
                                  remote_block_ids=list(
                                      range(num_remote_blocks)),
                                  remote_host="my-host",
                                  remote_port=1234,
                                  remote_tp_size=1)

    max_tokens = 1 if do_remote_decode else max_tokens
    sampling_params = SamplingParams(max_tokens=max_tokens)

    if use_all_1s_for_prompt_tokens:
        prompt_token_ids = [1] * num_tokens
    else:
        prompt_token_ids = [i * request_id for i in range(num_tokens)]

    req = Request(
        request_id=f"id-{request_id}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        multi_modal_inputs=None,
        multi_modal_placeholders=None,
        multi_modal_hashes=None,
        **({
            "pooling_params": []
        } if not vllm_version_is("0.9.1") else {}),
        eos_token_id=EOS_TOKEN_ID,
    )
    req.kv_transfer_params = kv_transfer_params
    return req


def create_model_runner_output(
    reqs: list[Request],
    finished_sending: Optional[list[str]] = None,
    finished_recving: Optional[list[str]] = None,
    use_eos: bool = False,
) -> ModelRunnerOutput:
    """Make dummy model runner output for testing."""

    # Make request data.
    req_ids = [req.request_id for req in reqs]
    req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    # Make sampled tokens.
    sampled_token = EOS_TOKEN_ID if use_eos else 0
    sampled_token_ids = [[sampled_token] for _ in req_ids]

    # Make output data structure.
    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        **({
            "pooler_output": []
        } if not vllm_version_is("0.9.1") else {}),
        finished_sending=finished_sending,
        finished_recving=finished_recving,
    )