Disaggregate prefill for kv cache register style (#950)
### What this PR does / why we need it?
This PR adopt `LLMDataDist` for kv cache register and `pull_blocks`
style disaggregate prefill implementation. The interface implementation
mainly follows the design of NIXL PR
https://github.com/vllm-project/vllm/pull/17751/files#diff-7eaad0b7dee0626bf29d10081b0f0c5e3ea15a4af97e7b182a4e0d35f8346953
.
This PR can be test with the following step:
- Generate the rank table for all machine.
- execute`toy_proxy.py` to launch the disaggregate prefill proxy server,
specify the prefill ip, port and the decode ip, port
- Run the prefill server and decode server.
- send the request to the disaggregate prefill proxy
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.9.2
- vLLM main:
8d0a01a5f2
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Signed-off-by: liziyu179 <3475441767@qq.com>
Signed-off-by: underfitc <hucong24@huawei.com>
Signed-off-by: zouyida2052 <zouyida@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: underfituu <hzhucong@163.com>
Co-authored-by: machenglong <machenglong_yewu@cmss.chinamobile.com>
Co-authored-by: liziyu179 <3475441767@qq.com>
Co-authored-by: underfitc <hucong24@huawei.com>
Co-authored-by: zouyida2052 <zouyida@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: underfituu <hzhucong@163.com>
This commit is contained in:
@@ -23,12 +23,18 @@ Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
from vllm import SamplingParams
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
True,
|
||||
reason=
|
||||
"Current disaggregated pd implementation may cause memory pulse, which will cause this test OOM, skip this test until the ringmla is ready "
|
||||
)
|
||||
@patch.dict(
|
||||
os.environ, {
|
||||
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
|
||||
@@ -54,6 +60,11 @@ def test_generate_with_allgather():
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
True,
|
||||
reason=
|
||||
"Current disaggregated pd implementation may cause memory pulse, which will cause this test OOM, skip this test until the ringmla is ready "
|
||||
)
|
||||
@patch.dict(os.environ, {
|
||||
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
|
||||
"TASK_QUEUE_ENABLE": "1"
|
||||
|
||||
@@ -23,6 +23,7 @@ Run `pytest tests/test_offline_inference.py`.
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
from vllm import SamplingParams
|
||||
from vllm.model_executor.models.registry import ModelRegistry
|
||||
@@ -93,6 +94,10 @@ def test_models_distributed_DeepSeek_dbo():
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"deepseek dbo dose not consider the support on half precision float, will enable this ut after we actually support it"
|
||||
)
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DBO": "1"})
|
||||
def test_models_distributed_DeepSeekV3_dbo():
|
||||
example_prompts = ["The president of the United States is"] * 41
|
||||
@@ -113,6 +118,7 @@ def test_models_distributed_DeepSeekV3_dbo():
|
||||
vllm_model.generate(example_prompts, sampling_params)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Due to OOM,waiting for 1311pr to merge in")
|
||||
def test_models_distributed_DeepSeek_W8A8():
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
|
||||
141
tests/e2e/pd_disaggreate/run_edge_case_test.sh
Normal file
141
tests/e2e/pd_disaggreate/run_edge_case_test.sh
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/bin/bash
|
||||
export LCCL_DETERMINISTIC=1
|
||||
export HCCL_DETERMINISTIC=true
|
||||
export CLOSE_MATMUL_K_SHIFT=1
|
||||
export VLLM_USE_V1=1
|
||||
|
||||
set -xe
|
||||
|
||||
# Models to run
|
||||
MODELS=(
|
||||
"Qwen/Qwen3-0.6B-Instruct"
|
||||
)
|
||||
|
||||
# Find the git repository root directory
|
||||
GIT_ROOT=$(git rev-parse --show-toplevel)
|
||||
|
||||
# Trap the SIGINT signal (triggered by Ctrl+C)
|
||||
trap 'kill $(jobs -pr)' SIGINT SIGTERM EXIT
|
||||
|
||||
# Gen ranktable
|
||||
RANKTABLE_PATH=${GIT_ROOT}/examples/disaggregate_prefill_v1/ranktable.json
|
||||
if [ -f "$RANKTABLE_PATH" ]; then
|
||||
rm "$RANKTABLE_PATH"
|
||||
fi
|
||||
cd ${GIT_ROOT}/examples/disaggregate_prefill_v1
|
||||
LOCAL_HOST=`hostname -I|awk -F " " '{print$1}'`
|
||||
bash gen_ranktable.sh --ips $LOCAL_HOST --network-card-name enp189s0f0 --prefill-device-cnt 1 --decode-device-cnt 1
|
||||
cd -
|
||||
export DISAGGREGATED_PREFILL_RANK_TABLE_PATH="$RANKTABLE_PATH"
|
||||
|
||||
# Waits for vLLM to start.
|
||||
wait_for_server() {
|
||||
local port=$1
|
||||
timeout 1200 bash -c "
|
||||
until curl -s localhost:${port}/health > /dev/null; do
|
||||
sleep 1
|
||||
done" && return 0 || return 1
|
||||
}
|
||||
|
||||
# Function to clean up previous instances
|
||||
cleanup_instances() {
|
||||
echo "Cleaning up any running vLLM instances..."
|
||||
pkill -f "vllm serve" || true
|
||||
sleep 2
|
||||
}
|
||||
|
||||
# Handle to get model-specific arguments for deepseek
|
||||
get_model_args() {
|
||||
local model_name=$1
|
||||
local extra_args=""
|
||||
|
||||
if [[ "$model_name" == *"deepseek"* ]]; then
|
||||
extra_args="--trust-remote-code"
|
||||
fi
|
||||
|
||||
echo "$extra_args"
|
||||
}
|
||||
|
||||
|
||||
# Function to run tests for a specific model
|
||||
run_tests_for_model() {
|
||||
local model_name=$1
|
||||
echo "================================"
|
||||
echo "Testing model: $model_name"
|
||||
echo "================================"
|
||||
|
||||
# Get model-specific arguments
|
||||
local model_args=$(get_model_args "$model_name")
|
||||
|
||||
# Start prefill instance
|
||||
PREFILL_PORT=8001
|
||||
|
||||
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=0 VLLM_LLMDD_RPC_PORT=5559 vllm serve $model_name \
|
||||
--port $PREFILL_PORT \
|
||||
--seed 1024 \
|
||||
--enforce-eager \
|
||||
--disable-log-requests \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_producer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
FULL_CMD="$BASE_CMD"
|
||||
fi
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Start decode instance
|
||||
DECODE_PORT=8002
|
||||
|
||||
# Build the command with or without model-specific args
|
||||
BASE_CMD="ASCEND_RT_VISIBLE_DEVICES=1 VLLM_LLMDD_RPC_PORT=6000 vllm serve $model_name \
|
||||
--port $DECODE_PORT \
|
||||
--seed 1024 \
|
||||
--enforce-eager \
|
||||
--disable-log-requests \
|
||||
--gpu-memory-utilization 0.8 \
|
||||
--kv-transfer-config '{\"kv_connector\":\"LLMDataDistCMgrConnector\",\"kv_role\":\"kv_consumer\",\"kv_buffer_device\":\"npu\",\"kv_parallel_size\":\"1\",\"kv_port\":\"20001\",\"engine_id\":\"0\",\"kv_connector_module_path\":\"vllm_ascend.distributed.llmdatadist_c_mgr_connector\"}'"
|
||||
|
||||
if [ -n "$model_args" ]; then
|
||||
FULL_CMD="$BASE_CMD $model_args"
|
||||
else
|
||||
FULL_CMD="$BASE_CMD"
|
||||
fi
|
||||
|
||||
eval "$FULL_CMD &"
|
||||
|
||||
# Wait for all instances to start
|
||||
echo "Waiting for prefill instance on port $PORT to start..."
|
||||
wait_for_server $PREFILL_PORT
|
||||
echo "Waiting for decode instance on port $PORT to start..."
|
||||
wait_for_server $DECODE_PORT
|
||||
|
||||
# Build the command for the proxy server with all the hosts and ports
|
||||
PROXY_PORT=8192
|
||||
PROXY_CMD="python ${GIT_ROOT}/examples/disaggregate_prefill_v1/toy_proxy_server.py --port $PROXY_PORT"
|
||||
PROXY_CMD+=" --prefiller-ports ${PREFILL_PORT}"
|
||||
PROXY_CMD+=" --decoder-ports ${DECODE_PORT}"
|
||||
# Start the proxy server
|
||||
echo "Starting proxy server with command: $PROXY_CMD"
|
||||
$PROXY_CMD &
|
||||
|
||||
# Wait for the proxy to start
|
||||
sleep 5
|
||||
|
||||
# Run lm eval for this model
|
||||
echo "Running tests for $model_name"
|
||||
PREFILL_PORT=$PREFILL_PORT DECODE_PORT=$DECODE_PORT PROXY_PORT=$PROXY_PORT python -m pytest -s -v ${GIT_ROOT}/tests/e2e/pd_disaggreate/test_edge_cases.py
|
||||
|
||||
# Clean up before running next model
|
||||
cleanup_instances
|
||||
sleep 3
|
||||
}
|
||||
|
||||
# Run tests for each model
|
||||
for model in "${MODELS[@]}"; do
|
||||
run_tests_for_model "$model"
|
||||
done
|
||||
|
||||
echo "All tests completed!"
|
||||
81
tests/e2e/pd_disaggreate/test_edge_cases.py
Normal file
81
tests/e2e/pd_disaggreate/test_edge_cases.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# This code is from: https://github.com/vllm-project/vllm/blob/main/tests/v1/kv_connector/nixl_integration/test_edge_cases.py
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
import os
|
||||
|
||||
import openai
|
||||
|
||||
PREFILL_PORT = os.getenv("PREFILL_PORT", None)
|
||||
DECODE_PORT = os.getenv("DECODE_PORT", None)
|
||||
PROXY_PORT = os.getenv("PROXY_PORT", None)
|
||||
|
||||
if PREFILL_PORT is None or DECODE_PORT is None or PROXY_PORT is None:
|
||||
raise ValueError(
|
||||
"Please set the PREFILL_PORT, DECODE_PORT, and PROXY_PORT.")
|
||||
|
||||
LONG_PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result, when working on projects like vLLM we are able to meet many amazing people from various organizations like AMD, Google, NVIDIA, " # noqa: E501
|
||||
PROMPT = "Red Hat is the best company in the world to work for because it works on open source software, which means that all the contributions are delivered to the community. As a result," # noqa: E501
|
||||
SHORT_PROMPT = "Red Hat is "
|
||||
|
||||
|
||||
def test_edge_cases():
|
||||
# Set the OpenAI API key and base URL
|
||||
decode_client = openai.OpenAI(
|
||||
api_key="MY_KEY",
|
||||
base_url=f"http://localhost:{DECODE_PORT}/v1",
|
||||
)
|
||||
prefill_client = openai.OpenAI(
|
||||
api_key="MY_KEY",
|
||||
base_url=f"http://localhost:{PREFILL_PORT}/v1",
|
||||
)
|
||||
proxy_client = openai.OpenAI(
|
||||
api_key="MY_KEY",
|
||||
base_url=f"http://localhost:{PROXY_PORT}/v1",
|
||||
)
|
||||
|
||||
# Get the list of models
|
||||
models = decode_client.models.list()
|
||||
MODEL = models.data[0].id
|
||||
|
||||
# (1) Check that we can handle a very short prompt,
|
||||
# less than the length of the block size.
|
||||
completion = proxy_client.completions.create(model=MODEL,
|
||||
prompt=SHORT_PROMPT,
|
||||
temperature=0)
|
||||
proxy_response = completion.choices[0].text
|
||||
completion = prefill_client.completions.create(model=MODEL,
|
||||
prompt=SHORT_PROMPT,
|
||||
temperature=0)
|
||||
prefill_response = completion.choices[0].text
|
||||
print(f"SMALL PROMPT: {proxy_response=}")
|
||||
print(f"SMALL PROMPT: {prefill_response=}")
|
||||
assert proxy_response == prefill_response
|
||||
|
||||
# (2) Check that we can handle a full prefix cache
|
||||
# hit on the D worker but not on the P worker.
|
||||
# (2a): prime the D worker.
|
||||
completion = decode_client.completions.create(model=MODEL,
|
||||
prompt=PROMPT,
|
||||
temperature=0)
|
||||
decode_response = completion.choices[0].text
|
||||
# (2b): send via the P/D setup
|
||||
completion = proxy_client.completions.create(model=MODEL,
|
||||
prompt=PROMPT,
|
||||
temperature=0)
|
||||
proxy_response = completion.choices[0].text
|
||||
print(f"FULL CACHE HIT: {proxy_response=}")
|
||||
assert proxy_response == decode_response
|
||||
|
||||
# (3) Check that we can handle a partial prefix cache
|
||||
# hit on the D worker.
|
||||
completion = proxy_client.completions.create(model=MODEL,
|
||||
prompt=LONG_PROMPT,
|
||||
temperature=0)
|
||||
proxy_response = completion.choices[0].text
|
||||
completion = prefill_client.completions.create(model=MODEL,
|
||||
prompt=LONG_PROMPT,
|
||||
temperature=0)
|
||||
prefill_response = completion.choices[0].text
|
||||
print(f"PARTIAL CACHE HIT: {proxy_response=}")
|
||||
assert proxy_response == prefill_response
|
||||
@@ -251,7 +251,10 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
query = torch.randn(10, 8 * 64)
|
||||
key = torch.randn(10, 8 * 64)
|
||||
value = torch.randn(10, 8 * 64)
|
||||
kv_cache = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
|
||||
k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
|
||||
v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
|
||||
kv_cache = [k_cache, v_cache]
|
||||
ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
|
||||
@@ -261,7 +264,7 @@ class TestAscendAttentionBackendImpl(TestBase):
|
||||
metadata.query_lens = torch.randn(10, 8 * 64)
|
||||
layer = self.layer
|
||||
layer.quant_method = MagicMock()
|
||||
layer.quant_method.apply.return_value = kv_cache
|
||||
layer.quant_method.apply.return_value = ret_value
|
||||
|
||||
output = self.impl.forward(layer,
|
||||
query,
|
||||
|
||||
42
tests/ut/kv_connector/test_llmdatadist_connector.py
Normal file
42
tests/ut/kv_connector/test_llmdatadist_connector.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import \
|
||||
LLMDataDistCMgrConnectorMetadata
|
||||
|
||||
|
||||
def test_basic_inferface():
|
||||
"""Unit test for basic LLMDataDistCMgrConnector interface functionality."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
request_id = request.request_id
|
||||
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata.
|
||||
scheduler_output = scheduler.schedule()
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata)
|
||||
|
||||
assert len(kv_connector_metadata.requests) == 1
|
||||
assert request_id in kv_connector_metadata.requests
|
||||
req_meta = kv_connector_metadata.requests[request_id]
|
||||
|
||||
for block_id, block in zip(
|
||||
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id]):
|
||||
assert block_id == block.block_id
|
||||
163
tests/ut/kv_connector/test_remote_decode_lifecycle.py
Normal file
163
tests/ut/kv_connector/test_remote_decode_lifecycle.py
Normal file
@@ -0,0 +1,163 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
#
|
||||
import copy
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
def test_basic_lifecycle():
|
||||
"""Test lifecycle of a Remote Decode request."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request = create_request(request_id=1,
|
||||
max_tokens=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# STEP (1): Prefill.
|
||||
# (1a): schedule()
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 1
|
||||
|
||||
# (1b): execute_model()
|
||||
model_runner_output = create_model_runner_output(reqs=[request])
|
||||
|
||||
# (1c): update_from_output()
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
|
||||
# Ensure the request is finished after 1 tokens.
|
||||
assert request.is_finished()
|
||||
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
output = engine_core_outputs[0].outputs[0]
|
||||
assert output.finish_reason == FinishReason.LENGTH
|
||||
assert output.kv_transfer_params is not None
|
||||
|
||||
# Request freed in Scheduler and blocks should be freed
|
||||
assert request_id in scheduler.finished_req_ids
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.waiting) == 0
|
||||
|
||||
# ... but blocks should not be freed.
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler_output.finished_req_ids) == 1
|
||||
assert request_id in scheduler_output.finished_req_ids
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
|
||||
# (2b): execute_model()
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
# (2c): update_from_output()
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# STEP (3): Finished sending.
|
||||
# (3a): schedule() - pass finished request to PB.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler_output.finished_req_ids) == 0
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
|
||||
# (3b): execute_model()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.finished_sending = [request_id]
|
||||
|
||||
# (3c): update_from_output()
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Confirm we do not have any memory leaks after req lifecycle.
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_prefix_cache_lifecycle():
|
||||
"""Test that remote decode params still works with a prefix cache hit."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# Prime the KVCache.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 3
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)
|
||||
|
||||
scheduler.add_request(request_remote_a)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote_a],
|
||||
use_eos=True)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
scheduler.schedule()
|
||||
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
|
||||
#####################
|
||||
# Actual Test: confirm we send all blocks.
|
||||
|
||||
# Step (1): Send the KV Transfer.
|
||||
NUM_EXTERNAL_FULL_BLOCKS -= 1
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
request_remote = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_decode=True)
|
||||
|
||||
scheduler.add_request(request_remote)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output(reqs=[request_remote])
|
||||
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
|
||||
# Ensure we send all block ids, even if there is a cache hit.
|
||||
assert (len(
|
||||
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
|
||||
1))
|
||||
|
||||
# STEP (2): Ensure it is freed.
|
||||
scheduler_output = scheduler.schedule()
|
||||
scheduler.schedule()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.finished_sending = [request_remote.request_id]
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
_ = scheduler.schedule()
|
||||
assert_scheduler_empty(scheduler)
|
||||
248
tests/ut/kv_connector/test_remote_prefill_lifecycle.py
Normal file
248
tests/ut/kv_connector/test_remote_prefill_lifecycle.py
Normal file
@@ -0,0 +1,248 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
#
|
||||
import copy
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
def test_basic_lifecycle():
|
||||
"""Test lifecycle of a remote prefill."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
START_FREE_BLOCK_QUEUE_SIZE = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# STEP (1):
|
||||
# (1a): schedule()
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
# Nothing running and empty scheduler output.
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler_output.scheduled_new_reqs) == 0
|
||||
if vllm_version_is("0.9.1"):
|
||||
assert len(scheduler_output.scheduled_cached_reqs) == 0
|
||||
else:
|
||||
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(scheduler_output.num_scheduled_tokens) == 0
|
||||
assert scheduler_output.total_num_scheduled_tokens == 0
|
||||
|
||||
# Req waiting for KVs with no computed/scheduled toks ...
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert request in scheduler.waiting
|
||||
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
|
||||
assert (request.num_computed_tokens == 0)
|
||||
|
||||
# ... but should have (uncached) blocks allocated to it.
|
||||
block_pool = scheduler.kv_cache_manager.block_pool
|
||||
assert (block_pool.free_block_queue.num_free_blocks
|
||||
< START_FREE_BLOCK_QUEUE_SIZE)
|
||||
assert len(block_pool.cached_block_hash_to_block) == 0
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block._block_hash is None
|
||||
|
||||
# (1b): forward()
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
|
||||
# (1c): update_from_output()
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
assert not engine_core_outputs or not engine_core_outputs[0].outputs
|
||||
|
||||
# STEP (2):
|
||||
# (2a): schedule(): nothing happens!
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert len(scheduler.running) == 0
|
||||
|
||||
# (2b): forward(): request finishes recv.
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.finished_recving = [request_id]
|
||||
|
||||
# (2c): update_from_output():
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert (request_id in scheduler.finished_recving_kv_req_ids)
|
||||
|
||||
# STEP (3):
|
||||
# (3a): schedule(): this should actually schedule.
|
||||
scheduler_output = scheduler.schedule()
|
||||
assert len(scheduler.running) == 1
|
||||
|
||||
# Confirm the block are actually allocated.
|
||||
num_hashed_blocks = 0
|
||||
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_id]
|
||||
for block in blocks:
|
||||
assert block.ref_cnt == 1
|
||||
num_hashed_blocks += (1 if block._block_hash is not None else 0)
|
||||
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
|
||||
# Confirm the rest of the prompt is scheduled in this step.
|
||||
scheduled_req = scheduler_output.scheduled_new_reqs[0]
|
||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
|
||||
num_computed_tokens = scheduled_req.num_computed_tokens
|
||||
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
|
||||
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
|
||||
|
||||
# (3b): execute_model()
|
||||
model_runner_output = create_model_runner_output([request])
|
||||
# (3c): update_from_output()
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Step (4): Hit EOS.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output([request], use_eos=True)
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
scheduler.schedule()
|
||||
|
||||
if vllm_version_is("0.9.1"):
|
||||
outputs = engine_core_outputs[0].outputs
|
||||
assert len(outputs) == 1
|
||||
output = outputs[0]
|
||||
assert output.finish_reason == FinishReason.STOP
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_no_spurious_prefix_caching():
|
||||
"""
|
||||
With P/D, blocks can be allocated but uncomputed for
|
||||
multiple engine steps. This test confirms that we do
|
||||
not accidentally have cache hits against uncomputed
|
||||
blocks.
|
||||
"""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 and a half full external blocks.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
|
||||
|
||||
# Both of these requests have prompts like [1,1,1,1,1, ...]
|
||||
request_remote = create_request(
|
||||
request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True,
|
||||
use_all_1s_for_prompt_tokens=True,
|
||||
)
|
||||
|
||||
# Schedule the remote prefill request. This should not
|
||||
# cause any blocks to be cached.
|
||||
scheduler.add_request(request_remote)
|
||||
scheduler_output = scheduler.schedule()
|
||||
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
assert len(scheduler.waiting) == 1
|
||||
|
||||
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
0].req_to_blocks[request_remote.request_id]
|
||||
|
||||
# Remote blocks should not be cached.
|
||||
for block in remote_blocks:
|
||||
assert block.ref_cnt == 1
|
||||
assert block._block_hash is None
|
||||
|
||||
|
||||
def test_full_block_prompt():
|
||||
"""Test that we handle a prompt that is the full block size."""
|
||||
|
||||
vllm_config = create_vllm_config()
|
||||
scheduler = create_scheduler(vllm_config)
|
||||
|
||||
# 2 Full Blocks and 1 Half Block.
|
||||
BLOCK_SIZE = vllm_config.cache_config.block_size
|
||||
NUM_EXTERNAL_FULL_BLOCKS = 2
|
||||
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
|
||||
|
||||
request = create_request(request_id=1,
|
||||
num_tokens=NUM_TOKENS,
|
||||
do_remote_prefill=True)
|
||||
|
||||
scheduler.add_request(request)
|
||||
request_id = request.request_id
|
||||
|
||||
# STEP (1): Initialize a recv.
|
||||
scheduler_output = scheduler.schedule()
|
||||
# All blocks should be allocated.
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# # STEP (2): Recv.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.finished_recving = [request_id]
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
assert len(scheduler.waiting) == 1
|
||||
assert (request_id in scheduler.finished_recving_kv_req_ids)
|
||||
|
||||
# # STEP (3): Run as usual.
|
||||
scheduler_output = scheduler.schedule()
|
||||
|
||||
# We need to recompute the final token of the prompt to generate
|
||||
# the first new token, so we should not have a new block.
|
||||
num_blocks = len(scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks[request_id])
|
||||
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
|
||||
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
|
||||
NUM_TOKENS - 1)
|
||||
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
|
||||
|
||||
model_runner_output = create_model_runner_output([request])
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# # Step (4): Hit EOS.
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = create_model_runner_output([request], use_eos=True)
|
||||
engine_core_outputs = scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
scheduler.schedule()
|
||||
|
||||
if vllm_version_is("0.9.1"):
|
||||
outputs = engine_core_outputs[0].outputs
|
||||
assert len(outputs) == 1
|
||||
output = outputs[0]
|
||||
assert output.finish_reason == FinishReason.STOP
|
||||
assert_scheduler_empty(scheduler)
|
||||
201
tests/ut/kv_connector/utils.py
Normal file
201
tests/ut/kv_connector/utils.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
os.environ["VLLM_USE_V1"] = "1"
|
||||
|
||||
|
||||
def assert_scheduler_empty(scheduler: Scheduler):
    """Verify the scheduler retains no residual state - i.e. no leaks."""
    # Request bookkeeping held directly on the scheduler.
    assert not scheduler.requests
    assert not scheduler.waiting
    assert not scheduler.running
    assert not scheduler.finished_req_ids
    assert not scheduler.finished_recving_kv_req_ids

    # Encoder cache bookkeeping.
    encoder_cache = scheduler.encoder_cache_manager
    assert not encoder_cache.freed
    assert not encoder_cache.cached

    # KV cache manager bookkeeping.
    kv_mgr = scheduler.kv_cache_manager
    single_mgr = kv_mgr.coordinator.single_type_managers[0]
    assert not single_mgr.req_to_blocks
    assert not kv_mgr.req_to_block_hashes
    assert not single_mgr.num_cached_block
    block_pool = kv_mgr.block_pool
    # Every block except the reserved null block must be back in the free queue.
    assert block_pool.free_block_queue.num_free_blocks == (
        block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    assert all(block.ref_cnt == 0 for block in block_pool.blocks)
|
||||
|
||||
|
||||
def create_vllm_config(
    model: str = "facebook/opt-125m",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 1024,
    block_size: int = 128,
) -> VllmConfig:
    """Build a VllmConfig wired up for LLMDataDistCMgrConnector testing.

    The max model length is pinned to ``max_num_batched_tokens`` and prefix
    caching is always enabled; the device is forced to CPU so the config can
    be constructed without accelerator hardware.
    """
    # Cache config: prefix caching is always on for these tests.
    cache_cfg = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=True,
    )
    # KV transfer config pointing at the Ascend LLMDataDist connector.
    transfer_cfg = KVTransferConfig(
        kv_connector="LLMDataDistCMgrConnector",
        kv_role="kv_both",
        kv_connector_module_path=
        "vllm_ascend.distributed.llmdatadist_c_mgr_connector")
    sched_cfg = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )
    # Fixed seed keeps the dummy model construction deterministic.
    model_cfg = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    return VllmConfig(scheduler_config=sched_cfg,
                      model_config=model_cfg,
                      cache_config=cache_cfg,
                      kv_transfer_config=transfer_cfg,
                      device_config=DeviceConfig("cpu"))
|
||||
|
||||
|
||||
def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Build a Scheduler over a single full-attention KV cache group.

    ``num_blocks`` defaults to a large pool so tests never run out of
    KV cache blocks.
    """
    block_size = vllm_config.cache_config.block_size
    # One KV cache group with a single dummy layer of full attention.
    attn_spec = FullAttentionSpec(block_size, 1, 1, torch.float16, False)
    groups = [KVCacheGroupSpec(['layer'], attn_spec)]
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=groups,
    )
    vllm_config.cache_config.num_gpu_blocks = num_blocks
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )
|
||||
|
||||
|
||||
def create_request(
    request_id: int,
    num_tokens: int = 10,
    max_tokens: int = 128,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    use_all_1s_for_prompt_tokens: bool = False,
    num_remote_blocks: int = 3,
) -> Request:
    """Make dummy request for testing.

    ``do_remote_decode`` and ``do_remote_prefill`` are mutually exclusive
    and select which side of the disaggregated-prefill handshake the
    request's kv_transfer_params describe.
    """
    kv_transfer_params: Optional[dict[str, Any]] = None
    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = {
            "do_remote_prefill": False,
            "do_remote_decode": True,
        }
    elif do_remote_prefill:
        kv_transfer_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_engine_id": "my-engine-id",
            "remote_block_ids": list(range(num_remote_blocks)),
            "remote_host": "my-host",
            "remote_port": 1234,
            "remote_tp_size": 1,
        }

    # A remote-decode request only produces the handshake token locally.
    if do_remote_decode:
        max_tokens = 1
    sampling_params = SamplingParams(max_tokens=max_tokens)

    if use_all_1s_for_prompt_tokens:
        prompt_token_ids = [1] * num_tokens
    else:
        prompt_token_ids = [i * request_id for i in range(num_tokens)]

    # vLLM >= 0.9.2 requires pooling_params on the Request constructor.
    extra_kwargs: dict[str, Any] = {}
    if not vllm_version_is("0.9.1"):
        extra_kwargs["pooling_params"] = []

    req = Request(
        request_id=f"id-{request_id}",
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        multi_modal_inputs=None,
        multi_modal_placeholders=None,
        multi_modal_hashes=None,
        eos_token_id=EOS_TOKEN_ID,
        **extra_kwargs,
    )
    req.kv_transfer_params = kv_transfer_params
    return req
|
||||
|
||||
|
||||
def create_model_runner_output(
    reqs: list[Request],
    finished_sending: Optional[list[str]] = None,
    finished_recving: Optional[list[str]] = None,
    use_eos: bool = False,
) -> ModelRunnerOutput:
    """Make dummy model runner output for testing.

    Each request is assigned a single sampled token: EOS when ``use_eos``
    is set, otherwise token id 0.
    """
    # Request ids in batch order, plus their index mapping.
    req_ids = [req.request_id for req in reqs]
    index_of = {rid: pos for pos, rid in enumerate(req_ids)}

    # One sampled token per request.
    token = EOS_TOKEN_ID if use_eos else 0
    sampled = [[token]] * len(req_ids)

    # vLLM >= 0.9.2 requires pooler_output on the constructor.
    extra_kwargs: dict[str, Any] = {}
    if not vllm_version_is("0.9.1"):
        extra_kwargs["pooler_output"] = []

    return ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=index_of,
        sampled_token_ids=sampled,
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        finished_sending=finished_sending,
        finished_recving=finished_recving,
        **extra_kwargs,
    )
|
||||
Reference in New Issue
Block a user