[Scheduler][MTP] Add support for speculative decoding in AsecendScheduler. (#943)

This PR adds support for speculative decoding in AsecendScheduler.
Also inculde part of support for disaggregated prefill, full support
will be merged in follow-up PR.

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
whx
2025-06-11 20:55:44 +08:00
committed by GitHub
parent 4f5964420e
commit 3393d53b36
5 changed files with 1001 additions and 49 deletions

View File

@@ -180,18 +180,20 @@ jobs:
run: | run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/singlecard/test_scheduler.py
# guided decoding doesn't work, fix it later # guided decoding doesn't work, fix it later
# pytest -sv tests/singlecard/test_guided_decoding.py.py # pytest -sv tests/singlecard/test_guided_decoding.py.py
# test_ascend_config.py should be ran separately because it will regenerate the global config many times. # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
pytest -sv tests/singlecard/test_ascend_config.py pytest -sv tests/singlecard/test_ascend_config.py
pytest -sv tests/singlecard/test_camem.py pytest -sv tests/singlecard/test_camem.py
# pytest -sv tests/singlecard/core/test_ascend_scheduler.py
# pytest -sv tests/singlecard/core/test_ascend_scheduler_e2e.py
pytest -sv tests/singlecard/ \ pytest -sv tests/singlecard/ \
--ignore=tests/singlecard/test_offline_inference.py \ --ignore=tests/singlecard/test_offline_inference.py \
--ignore=tests/singlecard/test_scheduler.py \
--ignore=tests/singlecard/test_guided_decoding.py \ --ignore=tests/singlecard/test_guided_decoding.py \
--ignore=tests/singlecard/test_ascend_config.py \ --ignore=tests/singlecard/test_ascend_config.py \
--ignore=tests/singlecard/test_camem.py --ignore=tests/singlecard/test_camem.py \
--ignore=tests/singlecard/core/test_ascend_scheduler.py \
--ignore=tests/singlecard/core/test_ascend_scheduler_e2e.py
else else
pytest -sv tests/multicard/test_ilama_lora_tp2.py pytest -sv tests/multicard/test_ilama_lora_tp2.py
# To avoid oom, we need to run the test in a single process. # To avoid oom, we need to run the test in a single process.
@@ -209,20 +211,21 @@ jobs:
run: | run: |
if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then if [[ "${{ matrix.os }}" == "linux-arm64-npu-1" ]]; then
VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py VLLM_USE_MODELSCOPE=True pytest -sv tests/singlecard/test_offline_inference.py
pytest -sv tests/singlecard/test_scheduler.py
# guided decoding doesn't work, fix it later # guided decoding doesn't work, fix it later
# pytest -sv tests/singlecard/test_guided_decoding.py.py # pytest -sv tests/singlecard/test_guided_decoding.py.py
pytest -sv tests/singlecard/test_camem.py pytest -sv tests/singlecard/test_camem.py
# test_ascend_config.py should be ran separately because it will regenerate the global config many times. # test_ascend_config.py should be ran separately because it will regenerate the global config many times.
pytest -sv tests/singlecard/test_ascend_config.py pytest -sv tests/singlecard/test_ascend_config.py
pytest -sv tests/singlecard/test_prompt_embedding.py pytest -sv tests/singlecard/test_prompt_embedding.py
pytest -sv tests/singlecard/core/test_ascend_scheduler.py
pytest -sv tests/singlecard/ \ pytest -sv tests/singlecard/ \
--ignore=tests/singlecard/test_offline_inference.py \ --ignore=tests/singlecard/test_offline_inference.py \
--ignore=tests/singlecard/test_scheduler.py \
--ignore=tests/singlecard/test_guided_decoding.py \ --ignore=tests/singlecard/test_guided_decoding.py \
--ignore=tests/singlecard/test_camem.py \ --ignore=tests/singlecard/test_camem.py \
--ignore=tests/singlecard/test_ascend_config.py \ --ignore=tests/singlecard/test_ascend_config.py \
--ignore=tests/singlecard/test_prompt_embedding.py --ignore=tests/singlecard/test_prompt_embedding.py \
--ignore=tests/singlecard/core/test_ascend_scheduler.py \
--ignore=tests/singlecard/core/test_ascend_scheduler_e2e.py
else else
pytest -sv tests/multicard/test_ilama_lora_tp2.py pytest -sv tests/multicard/test_ilama_lora_tp2.py
# Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error. # Fixme: run VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py will raise error.

View File

View File

@@ -0,0 +1,792 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
def create_scheduler(
model: str = "Qwen/Qwen2.5-0.5B-Instruct",
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
enable_chunked_prefill: bool = False,
) -> AscendScheduler:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` instance
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=enable_chunked_prefill,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
**({
"tensors": {}
} if vllm_version_is("0.9.0") else {
"kv_cache_tensors": []
}),
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
return AscendScheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
**({
"arrival_time": 0.0
} if vllm_version_is("0.9.0") else {}),
)
requests.append(request)
return requests
def test_add_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
assert request.request_id in scheduler.requests
assert len(scheduler.waiting) == i + 1
def test_finish_request():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
assert request.request_id not in scheduler.requests
assert len(scheduler.waiting) == 9 - i
def test_get_num_unfinished_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10,
prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests.
This test verifies that: there are multiple long prefill requests in the
RUNNING state, and we can schedule them together.
"""
scheduler = create_scheduler(
model="facebook/opt-125m",
max_num_batched_tokens=1024,
long_prefill_token_threshold=400,
enable_prefix_caching=enable_prefix_caching,
enable_chunked_prefill=True,
)
requests = create_requests(
num_requests=3,
num_tokens=800,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[0].request_id] == 400
# The second request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[1].request_id] == 400
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
assert output.num_scheduled_tokens[requests[2].request_id] == 224
req_to_index = {
request.request_id: i
for i, request in enumerate(requests)
}
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step. All three requests are running.
# Processed the remaining prefills of the first and second requests.
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert len(output1.scheduled_cached_reqs) == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
# Schedule the third step. All three requests are running.
# First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert len(output2.scheduled_cached_reqs) == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
assert output2.num_scheduled_tokens[
requests[2].request_id] == 800 - 224 - 224
def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler(num_speculative_tokens=1)
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i
for i, req in enumerate(requests)},
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i
for i, req in enumerate(requests)},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={req.request_id: i
for i, req in enumerate(requests)},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11
] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={})
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
enable_prefix_caching=enable_prefix_caching,
enable_chunked_prefill=True,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
prompt_logprobs=prompt_logprobs,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.scheduled_new_reqs) == 1
assert scheduler_output0.num_scheduled_tokens[
requests[0].request_id] == 512
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.scheduled_new_reqs) == 1
assert scheduler_output1.num_scheduled_tokens[
requests[1].request_id] == 512
# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(scheduler_output0, model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler_output2 = scheduler.schedule()
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
scheduler.update_from_output(scheduler_output1, model_runner_output)
# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
(2, 3, 3, [2, 1])), # multiple sequences
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(2, 6, 3, [2, 1, 0])), # multiple mismatches
])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
if vllm_version_is("0.9.0"):
return
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token
assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == \
len(requests) + sum(len(ids) for ids in spec_tokens)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == \
len(spec_tokens[i])
else:
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None # type: ignore
else:
assert scheduler_stats.spec_decoding_stats is not None # type: ignore
stats = scheduler_stats.spec_decoding_stats # type: ignore
assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]
def _assert_right_scheduler_output(
output: SchedulerOutput,
num_requests: int,
expected_num_scheduled_tokens: int,
):
"""Check if SchedulerOutput is correct after remote KV cache hit."""
# We should inject the kv_connector_metadata.
assert len(output.kv_connector_metadata.requests) == num_requests
# Only num_tokens - matched_num_new_tokens should be scheduled.
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
assert num_scheduled_tokens == expected_num_scheduled_tokens
def _assert_right_kv_cache_manager(
scheduler: AscendScheduler,
req_ids: list[str],
num_tokens: int,
block_size: int,
num_requests: int,
num_total_blocks: int,
):
"""Check whether KVCacheManager is correct after allocate."""
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS = num_tokens // block_size
for req_id in req_ids:
blocks = (scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[req_id])
hashes = scheduler.kv_cache_manager.req_to_block_hashes[req_id]
assert (scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block[req_id] == EXPECTED_TOTAL_BLOCKS)
assert len(blocks) == EXPECTED_TOTAL_BLOCKS
assert len(hashes) == EXPECTED_TOTAL_BLOCKS
# Make sure we actually touched all the blocks.
BLOCKS_PER_REQ = num_tokens / block_size
assert (scheduler.kv_cache_manager.block_pool.get_num_free_blocks() ==
num_total_blocks - num_requests * BLOCKS_PER_REQ)
def _step_until_done(
scheduler: AscendScheduler,
output: SchedulerOutput,
model_runner_output: ModelRunnerOutput,
):
"""Loop over schedule(), update_from_output() until finished."""
all_finished = False
_ = scheduler.update_from_output(output, model_runner_output)
while not all_finished:
# Schedule + a few iterations until stopping.
output = scheduler.schedule()
assert len(scheduler.running)
for _, num_scheduled_tokens in output.num_scheduled_tokens.items():
# We should be in the decode phase now.
assert num_scheduled_tokens == 1
assert len(output.kv_connector_metadata.requests) == 0
ecos = scheduler.update_from_output(output, model_runner_output)[0]
all_done = True
for eco in ecos.outputs:
if eco.finish_reason is None:
all_done = False
all_finished = all_done
def make_output(scheduler: AscendScheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
req_id_to_index={
req.request_id: i
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
def assert_scheduler_empty(scheduler: AscendScheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
if not vllm_version_is("0.9.0"):
assert len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].num_cached_block) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def test_memory_leak():
"""Test that we do not have a memory leak."""
scheduler = create_scheduler(enable_prefix_caching=True)
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import pytest
from vllm import LLM
if os.getenv("VLLM_USE_V1", "0") != "1":
pytest.skip("Test package requires V1", allow_module_level=True)
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
PROMPT = "Hello my name is Robert and I"
@pytest.fixture(scope="module")
def model() -> LLM:
return LLM(
MODEL,
enforce_eager=True,
enable_prefix_caching=True,
max_num_batched_tokens=200,
max_num_seqs=3,
additional_config={"ascend_scheduler_config": {
"enabled": True,
}})
def test_concurrent_partial_prefill(model):
outputs = model.generate([PROMPT] * 3)
assert len(outputs) == 3
for output in outputs:
assert len(output.outputs) == 1
def test_prefix_cache_stats_is_recorded(model):
# 17 tokens will make sure first 16 tokens are cached in a block
input_tokens = {"prompt_token_ids": [101] * 129}
_ = model.generate([input_tokens])
outputs = model.generate([input_tokens])
assert outputs[0].num_cached_tokens == 128

View File

@@ -14,16 +14,19 @@
# limitations under the License. # limitations under the License.
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
# #
import time
from collections import deque from collections import deque
from typing import Iterable, Union from typing import Iterable, Union
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed.kv_events import KVEventBatch
from vllm.logger import logger from vllm.logger import logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.utils import cdiv from vllm.utils import cdiv
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput from vllm.v1.core.sched.output import NewRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.engine import EngineCoreOutputs from vllm.v1.engine import EngineCoreEventType, EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
@@ -49,11 +52,6 @@ class AscendScheduler(Scheduler):
self.scheduled_req_ids: set[str] = set() self.scheduled_req_ids: set[str] = set()
self.running: list[Request] = [] self.running: list[Request] = []
if self.vllm_config.kv_transfer_config is not None and \
self.vllm_config.kv_transfer_config.is_kv_consumer:
raise ValueError(
"AscendScheduler cannot be used for decode nodes. ")
def schedule(self) -> SchedulerOutput: def schedule(self) -> SchedulerOutput:
if self.scheduler_config.chunked_prefill_enabled: if self.scheduler_config.chunked_prefill_enabled:
return super().schedule() return super().schedule()
@@ -68,6 +66,9 @@ class AscendScheduler(Scheduler):
# Spec decode-related. # Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {} scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# For logging.
scheduled_timestamp = time.monotonic()
# Record scheduled LoRA requests. # Record scheduled LoRA requests.
scheduled_loras: set[int] = set() scheduled_loras: set[int] = set()
@@ -86,6 +87,18 @@ class AscendScheduler(Scheduler):
self.waiting.popleft() self.waiting.popleft()
skipped_waiting_requests.appendleft(request) skipped_waiting_requests.appendleft(request)
num_prealloc_computed_tokens = 0
# P/D: skip request if still waiting for remote kvs.
if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
is_ready = self._update_waiting_for_remote_kv(request)
if is_ready:
request.status = RequestStatus.WAITING
num_prealloc_computed_tokens = (
request.num_computed_tokens)
else:
skip_cur_request()
continue
# Check that adding the request still respects the max_loras # Check that adding the request still respects the max_loras
# constraint. # constraint.
if (self.lora_config and request.lora_request and if (self.lora_config and request.lora_request and
@@ -95,15 +108,47 @@ class AscendScheduler(Scheduler):
skip_cur_request() skip_cur_request()
continue continue
num_external_computed_tokens = 0
load_kv_async = False
# Get already-cached tokens.
if num_prealloc_computed_tokens == 0:
new_computed_blocks, num_native_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(
request)
# Get externally-cached tokens if using a KVConnector.
if self.connector is not None:
num_external_computed_tokens, load_kv_async = (
self.connector.get_num_new_matched_tokens(
request, num_native_computed_tokens))
# Total computed tokens (local + external).
num_computed_tokens = (num_native_computed_tokens +
num_external_computed_tokens)
else:
# P/D: skip checking prefix cache if loaded from remote kvs.
new_computed_blocks = KVCacheBlocks.create_empty()
num_native_computed_tokens = 0
# Total computed tokens (allocated in prior step).
num_computed_tokens = num_prealloc_computed_tokens
# P/D: loading remote KV, do not allocate for new work.
if load_kv_async:
assert num_external_computed_tokens > 0
num_new_tokens = 0
blocks = None
# Number of tokens to be scheduled.
else:
prompt_limit = self._get_prompt_limit(request) prompt_limit = self._get_prompt_limit(request)
# Get already-cached tokens. # Get already-cached tokens.
computed_blocks, num_computed_tokens = ( computed_blocks, num_computed_tokens = (
self.kv_cache_manager.get_computed_blocks(request)) self.kv_cache_manager.get_computed_blocks(request))
# We use `request.num_tokens` instead of
# `request.num_prompt_tokens` to consider the resumed
# requests, which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens num_new_tokens = request.num_tokens - num_computed_tokens
if (0 < self.scheduler_config.long_prefill_token_threshold <
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
max_tokens_in_kvcache = (self.kv_cache_config.num_blocks * max_tokens_in_kvcache = (self.kv_cache_config.num_blocks *
self.block_size) self.block_size)
prompt_limit = min(prompt_limit, max_tokens_in_kvcache) prompt_limit = min(prompt_limit, max_tokens_in_kvcache)
@@ -117,7 +162,8 @@ class AscendScheduler(Scheduler):
prompt_limit, prompt_limit,
) )
request.status = RequestStatus.FINISHED_IGNORED request.status = RequestStatus.FINISHED_IGNORED
self.finished_req_ids.add(request.request_id) # type: ignore self.finished_req_ids.add( # type: ignore
request.request_id) # type: ignore
self.waiting.popleft() self.waiting.popleft()
continue continue
@@ -125,9 +171,9 @@ class AscendScheduler(Scheduler):
# Scheduling would exceed token_budget, skip. # Scheduling would exceed token_budget, skip.
skip_cur_request() skip_cur_request()
continue continue
assert num_new_tokens > 0 assert num_new_tokens > 0
blocks = computed_blocks.blocks[0] blocks = computed_blocks.blocks[0]
watermark = getattr(self.scheduler_config, "watermark", 0.01) watermark = getattr(self.scheduler_config, "watermark", 0.01)
if not self._check_watermark_for_prefill(request, num_new_tokens, if not self._check_watermark_for_prefill(request, num_new_tokens,
blocks, watermark): blocks, watermark):
@@ -136,13 +182,38 @@ class AscendScheduler(Scheduler):
continue continue
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, new_computed_blocks=computed_blocks) request,
num_new_tokens + num_external_computed_tokens,
num_native_computed_tokens,
new_computed_blocks=computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
break break
# KVConnector: update internal state after allocation.
# This information is used to determine if a load is
# needed for this request.
if num_external_computed_tokens:
assert self.connector is not None
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)
self.waiting.popleft() self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue
self.running.append(request) self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)
self.scheduled_req_ids.add(request.request_id) self.scheduled_req_ids.add(request.request_id)
# Check request status. # Check request status.
if request.status == RequestStatus.WAITING: if request.status == RequestStatus.WAITING:
@@ -161,6 +232,9 @@ class AscendScheduler(Scheduler):
token_budget -= num_new_tokens token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens request.num_computed_tokens = num_computed_tokens
# Count the number of prifix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Put back any skipped requests at the head of the waiting queue # Put back any skipped requests at the head of the waiting queue
if skipped_waiting_requests: if skipped_waiting_requests:
@@ -179,16 +253,45 @@ class AscendScheduler(Scheduler):
num_new_tokens = (request.num_tokens_with_spec - num_new_tokens = (request.num_tokens_with_spec -
request.num_computed_tokens) request.num_computed_tokens)
if (0 < self.scheduler_config.long_prefill_token_threshold < assert (request.num_tokens - request.num_computed_tokens) == 1
num_new_tokens):
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)
num_new_tokens = min(num_new_tokens, token_budget) num_new_tokens = min(num_new_tokens, token_budget)
assert num_new_tokens == 1 # Make sure the input position does not exceed the max model len.
# This is necessary when using spec decoding.
num_new_tokens = min(
num_new_tokens,
self.max_model_len - request.num_computed_tokens)
# Check that adding the request still respects the max_loras
# constraint.
if self.lora_config and request.lora_request and (
len(scheduled_loras) == self.lora_config.max_loras
and request.lora_request.lora_int_id
not in scheduled_loras):
# Scheduling would exceed max_loras, skip.
num_new_tokens = 0
if num_new_tokens == 0:
# The request cannot be scheduled because one of the following
# reason:
# 1. No new tokens to schedule. This may happen when PP>1 and
# we have already scheduled all prompt tokens but they are
# not finished yet.
# 2. Adding the request exceeds the max_loras constraint.
# NOTE(woosuk): Here, by doing `continue` instead of `break`,
# we do not strictly follow the FCFS scheduling policy and
# allow the lower-priority requests to be scheduled.
req_index += 1
continue
num_draft_tokens = max(
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
while True: while True:
new_blocks = self.kv_cache_manager.allocate_slots( new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens) request,
num_new_tokens,
num_draft_tokens=num_draft_tokens,
num_lookahead_tokens=self.num_lookahead_tokens)
if new_blocks is None: if new_blocks is None:
# The request cannot be scheduled. # The request cannot be scheduled.
# Preempt the lowest-priority request. # Preempt the lowest-priority request.
@@ -196,6 +299,10 @@ class AscendScheduler(Scheduler):
self.kv_cache_manager.free(preempted_req) self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0 preempted_req.num_computed_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED,
scheduled_timestamp)
self.waiting.appendleft(preempted_req) self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req) preempted_reqs.append(preempted_req)
if preempted_req == request: if preempted_req == request:
@@ -230,6 +337,10 @@ class AscendScheduler(Scheduler):
scheduled_spec_decode_tokens[request.request_id] = ( scheduled_spec_decode_tokens[request.request_id] = (
request.spec_token_ids) request.spec_token_ids)
# Record scheduled LoRA requests.
if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
# Check if the scheduling constraints are satisfied. # Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
@@ -297,6 +408,11 @@ class AscendScheduler(Scheduler):
meta = self.connector.build_connector_meta(scheduler_output) meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
# Advance the number of computed tokens for the request AFTER # Advance the number of computed tokens for the request AFTER
# the request is scheduled. # the request is scheduled.
# 1. The scheduler_output of the current step has to include the # 1. The scheduler_output of the current step has to include the
@@ -388,6 +504,7 @@ class AscendScheduler(Scheduler):
if num_tokens_scheduled == 0: if num_tokens_scheduled == 0:
# The request was not scheduled in this step. # The request was not scheduled in this step.
continue continue
if req_id in self.scheduled_req_ids:
self.scheduled_req_ids.remove(req_id) self.scheduled_req_ids.remove(req_id)
return super().update_from_output(scheduler_output, return super().update_from_output(scheduler_output,