[Scheduler][MTP] Add support for speculative decoding in AscendScheduler. (#943)
This PR adds support for speculative decoding in AscendScheduler. It also includes partial support for disaggregated prefill; full support will be merged in a follow-up PR. --------- Signed-off-by: whx-sjtu <2952154980@qq.com>
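For readers who want to exercise the new scheduler path end to end, here is a minimal usage sketch (not part of this commit). The `additional_config={"ascend_scheduler_config": {"enabled": True}}` key mirrors the e2e test added in this PR; the `speculative_config` dictionary is an assumption about the vLLM offline API and its exact keys may differ across vLLM versions (the unit tests in this PR construct SpeculativeConfig(model="ngram", num_speculative_tokens=...) directly instead).

# Minimal sketch, assuming a vllm-ascend install on an Ascend NPU and a vLLM
# version whose LLM() accepts a speculative_config dict (an assumption).
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    enforce_eager=True,
    # Enable the AscendScheduler, as done in the e2e test added by this PR.
    additional_config={"ascend_scheduler_config": {"enabled": True}},
    # Assumed spelling of the ngram speculative-decoding settings; see the
    # SpeculativeConfig(model="ngram", num_speculative_tokens=...) used in
    # the new unit tests for the engine-level equivalent.
    speculative_config={"method": "ngram",
                        "num_speculative_tokens": 2,
                        "prompt_lookup_max": 4},
)

outputs = llm.generate(["Hello my name is Robert and I"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)

Running this should drive the AscendScheduler together with the speculative-decoding accounting added in this PR.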
.github/workflows/vllm_ascend_test.yaml (15 changed lines)
@@ -180,18 +180,20 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
             pytest -sv tests/singlecard/test_ascend_config.py
             pytest -sv tests/singlecard/test_camem.py
+            # pytest -sv tests/singlecard/core/test_ascend_scheduler.py
+            # pytest -sv tests/singlecard/core/test_ascend_scheduler_e2e.py
             pytest -sv tests/singlecard/ \
               --ignore=tests/singlecard/test_offline_inference.py \
-              --ignore=tests/singlecard/test_scheduler.py \
               --ignore=tests/singlecard/test_guided_decoding.py \
               --ignore=tests/singlecard/test_ascend_config.py \
-              --ignore=tests/singlecard/test_camem.py
+              --ignore=tests/singlecard/test_camem.py \
+              --ignore=tests/singlecard/core/test_ascend_scheduler.py \
+              --ignore=tests/singlecard/core/test_ascend_scheduler_e2e.py
           else
             pytest -sv tests/multicard/test_ilama_lora_tp2.py
             # To avoid oom, we need to run the test in a single process.
@@ -209,20 +211,21 @@ jobs:
         run: |
           if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
             VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
-            pytest -sv tests/singlecard/test_scheduler.py
             # guided decoding doesn't work, fix it later
             # pytest -sv tests/singlecard/test_guided_decoding.py.py
             pytest -sv tests/singlecard/test_camem.py
             # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
             pytest -sv tests/singlecard/test_ascend_config.py
             pytest -sv tests/singlecard/test_prompt_embedding.py
+            pytest -sv tests/singlecard/core/test_ascend_scheduler.py
             pytest -sv tests/singlecard/ \
               --ignore=tests/singlecard/test_offline_inference.py \
-              --ignore=tests/singlecard/test_scheduler.py \
               --ignore=tests/singlecard/test_guided_decoding.py \
               --ignore=tests/singlecard/test_camem.py \
               --ignore=tests/singlecard/test_ascend_config.py \
-              --ignore=tests/singlecard/test_prompt_embedding.py
+              --ignore=tests/singlecard/test_prompt_embedding.py \
+              --ignore=tests/singlecard/core/test_ascend_scheduler.py \
+              --ignore=tests/singlecard/core/test_ascend_scheduler_e2e.py
           else
             pytest -sv tests/multicard/test_ilama_lora_tp2.py
             # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.
tests/singlecard/core/__init__.py (new file, empty)

tests/singlecard/core/test_ascend_scheduler.py (new file, 792 lines)
@@ -0,0 +1,792 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import pytest
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                         SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager

from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is

EOS_TOKEN_ID = 50256


def create_scheduler(
    model: str = "Qwen/Qwen2.5-0.5B-Instruct",
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 8192,
    enable_prefix_caching: Optional[bool] = None,
    long_prefill_token_threshold: int = 0,
    disable_chunked_mm_input: bool = False,
    use_kv_connector: bool = False,
    num_blocks: int = 10000,
    block_size: int = 16,
    max_model_len: Optional[int] = None,
    num_speculative_tokens: Optional[int] = None,
    enable_chunked_prefill: bool = False,
) -> AscendScheduler:
    '''Create scheduler under test.

    Args:
      model: model under test
      max_num_seqs: max sequences to schedule
      max_num_batch_tokens: max num tokens to batch
      enable_prefix_caching: optionally force APC config
                             (True/False) or use default
                             (None)

    Returns:
      {class}`Scheduler` instance
    '''
    if max_model_len is None:
        max_model_len = max_num_batched_tokens
    scheduler_config = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_model_len,
        long_prefill_token_threshold=long_prefill_token_threshold,
        disable_chunked_mm_input=disable_chunked_mm_input,
        enable_chunked_prefill=enable_chunked_prefill,
    )
    model_config = ModelConfig(
        model=model,
        task="auto",
        tokenizer=model,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
    )
    # Cache config, optionally force APC
    kwargs_cache = ({} if enable_prefix_caching is None else {
        'enable_prefix_caching': enable_prefix_caching
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )
    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if num_speculative_tokens is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=num_speculative_tokens)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,  # A large number of blocks to hold all requests
        **({
            "tensors": {}
        } if vllm_version_is("0.9.0") else {
            "kv_cache_tensors": []
        }),
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )
    cache_config.num_gpu_blocks = num_blocks
    return AscendScheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=StructuredOutputManager(vllm_config),
    )


def create_requests(num_requests: int,
                    num_tokens: int = 10,
                    mm_positions: Optional[list[PlaceholderRange]] = None,
                    max_tokens: int = 16,
                    stop_token_ids: Optional[list[int]] = None,
                    prompt_logprobs: Optional[int] = None):
    sampling_params = SamplingParams(ignore_eos=False,
                                     max_tokens=max_tokens,
                                     stop_token_ids=stop_token_ids,
                                     prompt_logprobs=prompt_logprobs)
    requests = []
    for i in range(num_requests):
        if mm_positions is not None:
            mm_position = mm_positions[i]
            mm_inputs = [MultiModalKwargs({})] * len(mm_position)
        else:
            mm_position = None
            mm_inputs = None
        request = Request(
            request_id=f"{i}",
            prompt_token_ids=[i] * num_tokens,
            sampling_params=sampling_params,
            multi_modal_inputs=mm_inputs,
            multi_modal_placeholders=mm_position,
            multi_modal_hashes=None,
            eos_token_id=EOS_TOKEN_ID,
            **({
                "arrival_time": 0.0
            } if vllm_version_is("0.9.0") else {}),
        )
        requests.append(request)
    return requests


def test_add_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)

    for i, request in enumerate(requests):
        scheduler.add_request(request)
        assert request.request_id in scheduler.requests
        assert len(scheduler.waiting) == i + 1


def test_finish_request():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        assert request.request_id not in scheduler.requests
        assert len(scheduler.waiting) == 9 - i


def test_get_num_unfinished_requests():
    scheduler = create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    for i, request in enumerate(requests):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1


@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
    (None, None),
    (True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
                  prompt_logprobs: Optional[int]):
    '''Test scheduling.
    Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
    '''
    scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
    requests = create_requests(num_requests=10,
                               prompt_logprobs=prompt_logprobs)
    for request in requests:
        scheduler.add_request(request)

    # Test initial scheduling
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0
    # Verify all requests are scheduled.
    for req_id, num_tokens in output.num_scheduled_tokens.items():
        assert num_tokens == len(requests[int(req_id)].prompt_token_ids)

    # Verify requests moved from waiting to running
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == len(requests)
    for i, request in enumerate(requests):
        assert scheduler.running[i] == request


@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
    """Test scheduling behavior with concurrent partial requests.

    This test verifies that: there are multiple long prefill requests in the
    RUNNING state, and we can schedule them together.

    """
    scheduler = create_scheduler(
        model="facebook/opt-125m",
        max_num_batched_tokens=1024,
        long_prefill_token_threshold=400,
        enable_prefix_caching=enable_prefix_caching,
        enable_chunked_prefill=True,
    )
    requests = create_requests(
        num_requests=3,
        num_tokens=800,
    )
    for request in requests:
        scheduler.add_request(request)

    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 3
    assert len(output.scheduled_cached_reqs) == 0
    assert len(output.finished_req_ids) == 0

    # The first request is scheduled partially - 400.
    assert output.num_scheduled_tokens[requests[0].request_id] == 400
    # The second request is scheduled partially - 400.
    assert output.num_scheduled_tokens[requests[1].request_id] == 400
    # The third request is also scheduled partially - 1024 - 400 - 400 = 224.
    assert output.num_scheduled_tokens[requests[2].request_id] == 224
    req_to_index = {
        request.request_id: i
        for i, request in enumerate(requests)
    }
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[] for _ in range(len(requests))],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(output, model_runner_output)

    # Schedule the next step. All three requests are running.
    # Processed the remaining prefills of the first and second requests.
    output1 = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output1.scheduled_new_reqs) == 0
    assert len(output1.scheduled_cached_reqs) == 3
    assert len(output1.finished_req_ids) == 0
    assert output1.num_scheduled_tokens[requests[0].request_id] == 400
    assert output1.num_scheduled_tokens[requests[1].request_id] == 400
    assert output1.num_scheduled_tokens[requests[2].request_id] == 224

    # Schedule the third step. All three requests are running.
    # First and second requests are in the decode stage.
    # All the remaining tokens in the third request are processed.
    model_runner_output = ModelRunnerOutput(
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(output1, model_runner_output)
    output2 = scheduler.schedule()
    assert len(scheduler.running) == 3
    assert len(output2.scheduled_new_reqs) == 0
    assert len(output2.scheduled_cached_reqs) == 3
    assert len(output2.finished_req_ids) == 0
    assert output2.num_scheduled_tokens[requests[0].request_id] == 1
    assert output2.num_scheduled_tokens[requests[1].request_id] == 1
    assert output2.num_scheduled_tokens[
        requests[2].request_id] == 800 - 224 - 224


def test_stop_via_update_from_output():
    """Test stopping behavior through update_from_output"""
    scheduler = create_scheduler(num_speculative_tokens=1)

    # Test case 1: Stop on EOS token
    requests = create_requests(num_requests=2, max_tokens=10)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 1,
                                           requests[1].request_id: 2
                                       },
                                       total_num_scheduled_tokens=3,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [],
                                           requests[1].request_id: [10]
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[],
                                       structured_output_request_ids={},
                                       grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[EOS_TOKEN_ID],
                           [10,
                            11]],  # First request hits EOS, second continues
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped, second continues
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
    assert list(requests[1].output_token_ids) == [10, 11]

    # Test case 2: Stop on custom stop token
    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=2,
                               max_tokens=10,
                               stop_token_ids=[42, 43])
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 2
                                       },
                                       total_num_scheduled_tokens=5,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [10, 42],
                                           requests[1].request_id: [13]
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[],
                                       structured_output_request_ids={},
                                       grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 42, 12],
                           [13, 14]],  # First request hits stop token
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped on custom token
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_STOPPED
    assert requests[0].stop_reason == 42
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 42]
    assert list(requests[1].output_token_ids) == [13, 14]

    # Test case 3: Stop on max tokens
    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=2, max_tokens=2)
    for req in requests:
        req.num_computed_tokens = req.num_tokens
        scheduler.requests[req.request_id] = req
        scheduler.running.append(req)

    scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
                                       scheduled_cached_reqs=[],
                                       num_scheduled_tokens={
                                           requests[0].request_id: 3,
                                           requests[1].request_id: 1
                                       },
                                       total_num_scheduled_tokens=4,
                                       scheduled_encoder_inputs={},
                                       scheduled_spec_decode_tokens={
                                           requests[0].request_id: [10, 11],
                                           requests[1].request_id: []
                                       },
                                       num_common_prefix_blocks=0,
                                       finished_req_ids=set(),
                                       free_encoder_input_ids=[],
                                       structured_output_request_ids={},
                                       grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[req.request_id for req in requests],
        req_id_to_index={req.request_id: i
                         for i, req in enumerate(requests)},
        sampled_token_ids=[[10, 11, 12],
                           [13]],  # First request exceeds max_tokens
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify first request stopped due to length
    assert len(scheduler.running) == 1
    assert scheduler.running[0].request_id == requests[1].request_id
    assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
    assert requests[0].request_id in scheduler.finished_req_ids
    assert list(requests[0].output_token_ids) == [10, 11
                                                  ]  # Truncated to max_tokens
    assert list(requests[1].output_token_ids) == [13]

    # Test case 4: Ignore EOS flag
    scheduler = create_scheduler(num_speculative_tokens=2)
    requests = create_requests(num_requests=1, max_tokens=10)
    requests[0].sampling_params.ignore_eos = True
    requests[0].num_computed_tokens = requests[0].num_tokens
    scheduler.requests[requests[0].request_id] = requests[0]
    scheduler.running.append(requests[0])

    scheduler_output = SchedulerOutput(
        scheduled_new_reqs=[],
        scheduled_cached_reqs=[],
        num_scheduled_tokens={requests[0].request_id: 3},
        total_num_scheduled_tokens=3,
        scheduled_encoder_inputs={},
        scheduled_spec_decode_tokens={
            requests[0].request_id: [EOS_TOKEN_ID, 10]
        },
        num_common_prefix_blocks=0,
        finished_req_ids=set(),
        free_encoder_input_ids=[],
        structured_output_request_ids={},
        grammar_bitmask=None)

    model_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={})

    scheduler.update_from_output(scheduler_output, model_output)

    # Verify request continues past EOS
    assert len(scheduler.running) == 1
    assert not requests[0].is_finished()
    assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
    (None, None),
    (True, 5),
])
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
                                     prompt_logprobs: Optional[int]):
    scheduler = create_scheduler(
        max_num_batched_tokens=1024,
        max_num_seqs=2,
        enable_prefix_caching=enable_prefix_caching,
        enable_chunked_prefill=True,
    )
    requests = create_requests(
        num_requests=2,
        num_tokens=512,
        prompt_logprobs=prompt_logprobs,
    )

    # Schedule the first request.
    scheduler.add_request(requests[0])
    scheduler_output0 = scheduler.schedule()
    assert len(scheduler_output0.scheduled_new_reqs) == 1
    assert scheduler_output0.num_scheduled_tokens[
        requests[0].request_id] == 512

    # The first request is still running, so only schedule the second request.
    scheduler.add_request(requests[1])
    scheduler_output1 = scheduler.schedule()
    assert len(scheduler_output1.scheduled_new_reqs) == 1
    assert scheduler_output1.num_scheduled_tokens[
        requests[1].request_id] == 512

    # Model output of the first request.
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[0]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(scheduler_output0, model_runner_output)

    # Schedule the next step.
    # The first request can be scheduled again while the second
    # request is still running.
    scheduler_output2 = scheduler.schedule()
    assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1

    # Model output of the second request.
    model_runner_output = ModelRunnerOutput(
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[[0]],
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    scheduler.update_from_output(scheduler_output1, model_runner_output)


# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
    "spec_tokens,output_tokens,expected",
    [
        ([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])),  # perfect match
        ([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])),  # early mismatch
        ([[1, 2], [3]], [[1, 2, 5], [3, 4]],
         (2, 3, 3, [2, 1])),  # multiple sequences
        ([[1]], [[1, 2]], (1, 1, 1, [1])),  # single token sequence
        ([[]], [[5]], (0, 0, 0, [0])),  # empty sequence
        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
         (2, 6, 3, [2, 1, 0])),  # multiple mismatches
    ])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
    """Test scheduling behavior with speculative decoding.

    This test verifies that:
    1. Speculated tokens get scheduled correctly
    2. Spec decoding stats properly count number of draft and accepted tokens
    """
    if vllm_version_is("0.9.0"):
        return
    num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
    scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
    requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
    req_ids = []
    req_to_index = {}
    for i, request in enumerate(requests):
        scheduler.add_request(request)
        req_ids.append(request.request_id)
        req_to_index[request.request_id] = i

    # Schedule a decode, which will also draft speculative tokens
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == len(requests)
    assert output.total_num_scheduled_tokens == len(requests)
    for i in range(len(requests)):
        req_id = requests[i].request_id
        assert output.num_scheduled_tokens[req_id] == 1
        assert req_id not in output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in range(len(requests))],
        spec_token_ids=spec_tokens,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    engine_core_outputs = scheduler.update_from_output(output,
                                                       model_runner_output)

    for i in range(len(requests)):
        running_req = scheduler.running[i]
        # The prompt token
        assert running_req.num_computed_tokens == 1
        # The prompt token and the sampled token
        assert running_req.num_tokens == 2
        # The prompt token, the sampled token, and the speculated tokens
        assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])

    # No draft or accepted tokens counted yet
    assert not engine_core_outputs or (
        engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)

    # Schedule the speculated tokens for validation
    output = scheduler.schedule()
    assert len(output.scheduled_new_reqs) == 0
    # The sampled token and speculated tokens
    assert output.total_num_scheduled_tokens == \
        len(requests) + sum(len(ids) for ids in spec_tokens)
    for i in range(len(requests)):
        req_id = requests[i].request_id
        assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
        if spec_tokens[i]:
            assert len(output.scheduled_spec_decode_tokens[req_id]) == \
                len(spec_tokens[i])
        else:
            assert req_id not in output.scheduled_spec_decode_tokens

    model_runner_output = ModelRunnerOutput(
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=output_tokens,
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )
    engine_core_outputs = scheduler.update_from_output(output,
                                                       model_runner_output)

    scheduler_stats = engine_core_outputs[0].scheduler_stats \
        if engine_core_outputs else None
    if expected[0] == 0:
        assert scheduler_stats.spec_decoding_stats is None  # type: ignore
    else:
        assert scheduler_stats.spec_decoding_stats is not None  # type: ignore
        stats = scheduler_stats.spec_decoding_stats  # type: ignore
        assert stats.num_drafts == expected[0]
        assert stats.num_draft_tokens == expected[1]
        assert stats.num_accepted_tokens == expected[2]
        assert stats.num_accepted_tokens_per_pos == expected[3]


def _assert_right_scheduler_output(
    output: SchedulerOutput,
    num_requests: int,
    expected_num_scheduled_tokens: int,
):
    """Check if SchedulerOutput is correct after remote KV cache hit."""

    # We should inject the kv_connector_metadata.
    assert len(output.kv_connector_metadata.requests) == num_requests

    # Only num_tokens - matched_num_new_tokens should be scheduled.
    for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
        assert num_scheduled_tokens == expected_num_scheduled_tokens


def _assert_right_kv_cache_manager(
    scheduler: AscendScheduler,
    req_ids: list[str],
    num_tokens: int,
    block_size: int,
    num_requests: int,
    num_total_blocks: int,
):
    """Check whether KVCacheManager is correct after allocate."""

    # Make sure the request stats are right.
    EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
    for req_id in req_ids:
        blocks = (scheduler.kv_cache_manager.coordinator.
                  single_type_managers[0].req_to_blocks[req_id])
        hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
        assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
                num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
        assert len(blocks) == EXPECTED_TOTAL_BLOCKS
        assert len(hashes) == EXPECTED_TOTAL_BLOCKS

    # Make sure we actually touched all the blocks.
    BLOCKS_PER_REQ = num_tokens / block_size
    assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
            num_total_blocks - num_requests * BLOCKS_PER_REQ)


def _step_until_done(
    scheduler: AscendScheduler,
    output: SchedulerOutput,
    model_runner_output: ModelRunnerOutput,
):
    """Loop over schedule(), update_from_output() until finished."""

    all_finished = False
    _ = scheduler.update_from_output(output, model_runner_output)
    while not all_finished:
        # Schedule + a few iterations until stopping.
        output = scheduler.schedule()
        assert len(scheduler.running)
        for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
            # We should be in the decode phase now.
            assert num_scheduled_tokens == 1
        assert len(output.kv_connector_metadata.requests) == 0
        ecos = scheduler.update_from_output(output, model_runner_output)[0]
        all_done = True
        for eco in ecos.outputs:
            if eco.finish_reason is None:
                all_done = False
        all_finished = all_done


def make_output(scheduler: AscendScheduler):
    return ModelRunnerOutput(
        req_ids=[req.request_id for req in scheduler.running],
        req_id_to_index={
            req.request_id: i
            for i, req in enumerate(scheduler.running)
        },
        sampled_token_ids=[[1000]] * len(scheduler.running),
        spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
    )


def assert_scheduler_empty(scheduler: AscendScheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks."""
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    assert len(scheduler._cached_reqs_data) == 0

    # EncoderCacheManager.
    assert len(scheduler.encoder_cache_manager.freed) == 0
    assert len(scheduler.encoder_cache_manager.cached) == 0

    # KVCache Manager.
    if not vllm_version_is("0.9.0"):
        assert len(scheduler.kv_cache_manager.coordinator.
                   single_type_managers[0].req_to_blocks) == 0
        assert len(scheduler.kv_cache_manager.coordinator.
                   single_type_managers[0].num_cached_block) == 0
    assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
    num_free_blocks = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
    assert num_free_blocks == (
        scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in scheduler.kv_cache_manager.block_pool.blocks:
        assert block.ref_cnt == 0


def test_memory_leak():
    """Test that we do not have a memory leak."""

    scheduler = create_scheduler(enable_prefix_caching=True)

    NUM_REQUESTS = 5
    NUM_TOKENS = 10
    MAX_TOKENS = 10
    requests = create_requests(num_requests=NUM_REQUESTS,
                               num_tokens=NUM_TOKENS,
                               max_tokens=MAX_TOKENS)

    # Add each request.
    for request in requests:
        scheduler.add_request(request)
    scheduler_output = scheduler.schedule()
    model_runner_output = make_output(scheduler)
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Iterate until done.
    while True:
        scheduler_output = scheduler.schedule()
        if len(scheduler.running) == 0:
            break
        model_runner_output = make_output(scheduler)
        scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm no memory leak.
    assert_scheduler_empty(scheduler)
tests/singlecard/core/test_ascend_scheduler_e2e.py (new file, 40 lines)
@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import pytest
from vllm import LLM

if os.getenv("VLLM_USE_V1", "0") != "1":
    pytest.skip("Test package requires V1", allow_module_level=True)

MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
PROMPT = "Hello my name is Robert and I"


@pytest.fixture(scope="module")
def model() -> LLM:
    return LLM(
        MODEL,
        enforce_eager=True,
        enable_prefix_caching=True,
        max_num_batched_tokens=200,
        max_num_seqs=3,
        additional_config={"ascend_scheduler_config": {
            "enabled": True,
        }})


def test_concurrent_partial_prefill(model):
    outputs = model.generate([PROMPT] * 3)
    assert len(outputs) == 3
    for output in outputs:
        assert len(output.outputs) == 1


def test_prefix_cache_stats_is_recorded(model):
    # 17 tokens will make sure first 16 tokens are cached in a block
    input_tokens = {"prompt_token_ids": [101] * 129}
    _ = model.generate([input_tokens])
    outputs = model.generate([input_tokens])
    assert outputs[0].num_cached_tokens == 128
vllm_ascend/core/scheduler.py (modified)

@@ -14,16 +14,19 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
+import time
 from collections import deque
 from typing import Iterable, Union
 
 from vllm.config import VllmConfig
+from vllm.distributed.kv_events import KVEventBatch
 from vllm.logger import logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.utils import cdiv
+from vllm.v1.core.kv_cache_manager import KVCacheBlocks
 from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
-from vllm.v1.engine import EngineCoreOutputs
+from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
@@ -49,11 +52,6 @@ class AscendScheduler(Scheduler):
         self.scheduled_req_ids: set[str] = set()
         self.running: list[Request] = []
 
-        if self.vllm_config.kv_transfer_config is not None and \
-            self.vllm_config.kv_transfer_config.is_kv_consumer:
-            raise ValueError(
-                "AscendScheduler cannot be used for decode nodes. ")
-
     def schedule(self) -> SchedulerOutput:
         if self.scheduler_config.chunked_prefill_enabled:
             return super().schedule()
@@ -68,6 +66,9 @@ class AscendScheduler(Scheduler):
         # Spec decode-related.
         scheduled_spec_decode_tokens: dict[str, list[int]] = {}
 
+        # For logging.
+        scheduled_timestamp = time.monotonic()
+
         # Record scheduled LoRA requests.
         scheduled_loras: set[int] = set()
 
@@ -86,6 +87,18 @@ class AscendScheduler(Scheduler):
                 self.waiting.popleft()
                 skipped_waiting_requests.appendleft(request)
 
+            num_prealloc_computed_tokens = 0
+            # P/D: skip request if still waiting for remote kvs.
+            if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
+                is_ready = self._update_waiting_for_remote_kv(request)
+                if is_ready:
+                    request.status = RequestStatus.WAITING
+                    num_prealloc_computed_tokens = (
+                        request.num_computed_tokens)
+                else:
+                    skip_cur_request()
+                    continue
+
             # Check that adding the request still respects the max_loras
             # constraint.
             if (self.lora_config and request.lora_request and
@@ -95,15 +108,47 @@ class AscendScheduler(Scheduler):
                 skip_cur_request()
                 continue
 
+            num_external_computed_tokens = 0
+            load_kv_async = False
+
+            # Get already-cached tokens.
+            if num_prealloc_computed_tokens == 0:
+                new_computed_blocks, num_native_computed_tokens = \
+                    self.kv_cache_manager.get_computed_blocks(
+                        request)
+
+                # Get externally-cached tokens if using a KVConnector.
+                if self.connector is not None:
+                    num_external_computed_tokens, load_kv_async = (
+                        self.connector.get_num_new_matched_tokens(
+                            request, num_native_computed_tokens))
+
+                # Total computed tokens (local + external).
+                num_computed_tokens = (num_native_computed_tokens +
+                                       num_external_computed_tokens)
+            else:
+                # P/D: skip checking prefix cache if loaded from remote kvs.
+                new_computed_blocks = KVCacheBlocks.create_empty()
+                num_native_computed_tokens = 0
+
+                # Total computed tokens (allocated in prior step).
+                num_computed_tokens = num_prealloc_computed_tokens
+
+            # P/D: loading remote KV, do not allocate for new work.
+            if load_kv_async:
+                assert num_external_computed_tokens > 0
+                num_new_tokens = 0
+                blocks = None
+            # Number of tokens to be scheduled.
+            else:
                 prompt_limit = self._get_prompt_limit(request)
                 # Get already-cached tokens.
                 computed_blocks, num_computed_tokens = (
                     self.kv_cache_manager.get_computed_blocks(request))
+                # We use `request.num_tokens` instead of
+                # `request.num_prompt_tokens` to consider the resumed
+                # requests, which have output tokens.
                 num_new_tokens = request.num_tokens - num_computed_tokens
-                if (0 < self.scheduler_config.long_prefill_token_threshold <
-                        num_new_tokens):
-                    num_new_tokens = (
-                        self.scheduler_config.long_prefill_token_threshold)
                 max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
                                          self.block_size)
                 prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
@@ -117,7 +162,8 @@ class AscendScheduler(Scheduler):
                         prompt_limit,
                     )
                     request.status = RequestStatus.FINISHED_IGNORED
-                    self.finished_req_ids.add(request.request_id)  # type: ignore
+                    self.finished_req_ids.add(  # type: ignore
+                        request.request_id)  # type: ignore
                     self.waiting.popleft()
                     continue
 
@@ -125,9 +171,9 @@ class AscendScheduler(Scheduler):
                     # Scheduling would exceed token_budget, skip.
                     skip_cur_request()
                     continue
 
                 assert num_new_tokens > 0
                 blocks = computed_blocks.blocks[0]
 
             watermark = getattr(self.scheduler_config, "watermark", 0.01)
             if not self._check_watermark_for_prefill(request, num_new_tokens,
                                                      blocks, watermark):
@@ -136,13 +182,38 @@ class AscendScheduler(Scheduler):
                 continue
 
             new_blocks = self.kv_cache_manager.allocate_slots(
-                request, num_new_tokens, new_computed_blocks=computed_blocks)
+                request,
+                num_new_tokens + num_external_computed_tokens,
+                num_native_computed_tokens,
+                new_computed_blocks=computed_blocks,
+                num_lookahead_tokens=self.num_lookahead_tokens,
+                delay_cache_blocks=load_kv_async)
             if new_blocks is None:
                 # The request cannot be scheduled.
                 break
 
+            # KVConnector: update internal state after allocation.
+            # This information is used to determine if a load is
+            # needed for this request.
+            if num_external_computed_tokens:
+                assert self.connector is not None
+                self.connector.update_state_after_alloc(
+                    request,
+                    new_computed_blocks + new_blocks,
+                    num_external_computed_tokens,
+                )
+
             self.waiting.popleft()
+            if load_kv_async:
+                # If loading async, allocate memory and put request
+                # into the WAITING_FOR_REMOTE_KV state.
+                skipped_waiting_requests.appendleft(request)
+                request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
+                continue
             self.running.append(request)
+            if self.log_stats:
+                request.record_event(EngineCoreEventType.SCHEDULED,
+                                     scheduled_timestamp)
             self.scheduled_req_ids.add(request.request_id)
             # Check request status.
             if request.status == RequestStatus.WAITING:
@@ -161,6 +232,9 @@ class AscendScheduler(Scheduler):
             token_budget -= num_new_tokens
             request.status = RequestStatus.RUNNING
             request.num_computed_tokens = num_computed_tokens
+            # Count the number of prifix cached tokens.
+            if request.num_cached_tokens < 0:
+                request.num_cached_tokens = num_computed_tokens
 
         # Put back any skipped requests at the head of the waiting queue
         if skipped_waiting_requests:
@@ -179,16 +253,45 @@ class AscendScheduler(Scheduler):
 
             num_new_tokens = (request.num_tokens_with_spec -
                               request.num_computed_tokens)
-            if (0 < self.scheduler_config.long_prefill_token_threshold <
-                    num_new_tokens):
-                num_new_tokens = (
-                    self.scheduler_config.long_prefill_token_threshold)
+            assert (request.num_tokens - request.num_computed_tokens) == 1
             num_new_tokens = min(num_new_tokens, token_budget)
-            assert num_new_tokens == 1
+            # Make sure the input position does not exceed the max model len.
+            # This is necessary when using spec decoding.
+            num_new_tokens = min(
+                num_new_tokens,
+                self.max_model_len - request.num_computed_tokens)
+            # Check that adding the request still respects the max_loras
+            # constraint.
+            if self.lora_config and request.lora_request and (
+                    len(scheduled_loras) == self.lora_config.max_loras
+                    and request.lora_request.lora_int_id
+                    not in scheduled_loras):
+                # Scheduling would exceed max_loras, skip.
+                num_new_tokens = 0
+
+            if num_new_tokens == 0:
+                # The request cannot be scheduled because one of the following
+                # reason:
+                # 1. No new tokens to schedule. This may happen when PP>1 and
+                #    we have already scheduled all prompt tokens but they are
+                #    not finished yet.
+                # 2. Adding the request exceeds the max_loras constraint.
+                # NOTE(woosuk): Here, by doing `continue` instead of `break`,
+                # we do not strictly follow the FCFS scheduling policy and
+                # allow the lower-priority requests to be scheduled.
+                req_index += 1
+                continue
+
+            num_draft_tokens = max(
+                num_new_tokens + request.num_computed_tokens -
+                request.num_tokens, 0)
+
             while True:
                 new_blocks = self.kv_cache_manager.allocate_slots(
-                    request, num_new_tokens)
+                    request,
+                    num_new_tokens,
+                    num_draft_tokens=num_draft_tokens,
+                    num_lookahead_tokens=self.num_lookahead_tokens)
                 if new_blocks is None:
                     # The request cannot be scheduled.
                     # Preempt the lowest-priority request.
@@ -196,6 +299,10 @@ class AscendScheduler(Scheduler):
                     self.kv_cache_manager.free(preempted_req)
                     preempted_req.status = RequestStatus.PREEMPTED
                     preempted_req.num_computed_tokens = 0
+                    if self.log_stats:
+                        preempted_req.record_event(
+                            EngineCoreEventType.PREEMPTED,
+                            scheduled_timestamp)
                     self.waiting.appendleft(preempted_req)
                     preempted_reqs.append(preempted_req)
                     if preempted_req == request:
@@ -230,6 +337,10 @@ class AscendScheduler(Scheduler):
                 scheduled_spec_decode_tokens[request.request_id] = (
                     request.spec_token_ids)
 
+            # Record scheduled LoRA requests.
+            if self.lora_config and request.lora_request:
+                scheduled_loras.add(request.lora_request.lora_int_id)
+
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
@@ -297,6 +408,11 @@ class AscendScheduler(Scheduler):
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta
 
+        events = self.kv_cache_manager.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
+
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
@@ -388,6 +504,7 @@ class AscendScheduler(Scheduler):
             if num_tokens_scheduled == 0:
                 # The request was not scheduled in this step.
                 continue
+            if req_id in self.scheduled_req_ids:
                 self.scheduled_req_ids.remove(req_id)
 
         return super().update_from_output(scheduler_output,
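As a reading aid (not part of the patch), the standalone sketch below reproduces the spec-decode token accounting used in the decode branch above, with made-up numbers: num_new_tokens covers the last sampled token plus every draft token still to be verified, and num_draft_tokens counts only the tokens beyond request.num_tokens, which need KV slots solely for verification.

# Standalone sketch of the spec-decode token accounting shown in the hunk
# above (illustrative numbers, not taken from the patch).
num_computed_tokens = 10      # prompt tokens already processed
num_tokens = 11               # prompt plus the last sampled token
num_spec_tokens = 2           # draft tokens proposed for this request
num_tokens_with_spec = num_tokens + num_spec_tokens

# New tokens to schedule: the sampled token plus every draft token.
num_new_tokens = num_tokens_with_spec - num_computed_tokens
assert num_new_tokens == 3

# Tokens beyond request.num_tokens are drafts that only need KV slots
# for verification.
num_draft_tokens = max(num_new_tokens + num_computed_tokens - num_tokens, 0)
assert num_draft_tokens == 2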