[Test] Add uts for files in /core (#1957)
### What this PR does / why we need it?
Add uts for files in folder /core
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.9.2
- vLLM main:
5a19a6c670
---------
Signed-off-by: lwq <liwenquan5@huawei.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
This commit is contained in:
@@ -1,743 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
from tests.e2e.model_utils import check_outputs_equal
|
||||
from vllm_ascend.core.scheduler import AscendScheduler
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
MODEL = "Qwen/Qwen3-0.6B"
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
model: str = MODEL,
|
||||
max_num_seqs: int = 16,
|
||||
max_num_batched_tokens: int = 8192,
|
||||
enable_prefix_caching: Optional[bool] = None,
|
||||
long_prefill_token_threshold: int = 0,
|
||||
disable_chunked_mm_input: bool = False,
|
||||
use_kv_connector: bool = False,
|
||||
num_blocks: int = 10000,
|
||||
block_size: int = 16,
|
||||
max_model_len: Optional[int] = None,
|
||||
num_speculative_tokens: Optional[int] = None,
|
||||
enable_chunked_prefill: bool = False,
|
||||
) -> AscendScheduler:
|
||||
'''Create scheduler under test.
|
||||
|
||||
Args:
|
||||
model: model under test
|
||||
max_num_seqs: max sequences to schedule
|
||||
max_num_batch_tokens: max num tokens to batch
|
||||
enable_prefix_caching: optionally force APC config
|
||||
(True/False) or use default
|
||||
(None)
|
||||
|
||||
Returns:
|
||||
{class}`Scheduler` instance
|
||||
'''
|
||||
if max_model_len is None:
|
||||
max_model_len = max_num_batched_tokens
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
max_model_len=max_model_len,
|
||||
long_prefill_token_threshold=long_prefill_token_threshold,
|
||||
disable_chunked_mm_input=disable_chunked_mm_input,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
)
|
||||
model_config = ModelConfig(
|
||||
model=model,
|
||||
task="auto",
|
||||
tokenizer=model,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype="float16",
|
||||
seed=42,
|
||||
)
|
||||
# Cache config, optionally force APC
|
||||
kwargs_cache = ({} if enable_prefix_caching is None else {
|
||||
'enable_prefix_caching': enable_prefix_caching
|
||||
})
|
||||
cache_config = CacheConfig(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
**kwargs_cache,
|
||||
)
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="SharedStorageConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
) if use_kv_connector else None
|
||||
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
if num_speculative_tokens is not None:
|
||||
speculative_config = SpeculativeConfig(
|
||||
model="ngram", num_speculative_tokens=num_speculative_tokens)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
speculative_config=speculative_config,
|
||||
)
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=num_blocks, # A large number of blocks to hold all requests
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1, torch.float32,
|
||||
False))
|
||||
],
|
||||
)
|
||||
cache_config.num_gpu_blocks = num_blocks
|
||||
return AscendScheduler(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
log_stats=True,
|
||||
structured_output_manager=StructuredOutputManager(vllm_config),
|
||||
)
|
||||
|
||||
|
||||
def create_requests(num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
prompt_logprobs: Optional[int] = None):
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
if mm_positions is not None:
|
||||
mm_position = mm_positions[i]
|
||||
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
|
||||
else:
|
||||
mm_position = None
|
||||
mm_inputs = None
|
||||
request = Request(
|
||||
request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_inputs=mm_inputs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
)
|
||||
requests.append(request)
|
||||
return requests
|
||||
|
||||
|
||||
def test_add_requests():
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
assert request.request_id in scheduler.requests
|
||||
assert len(scheduler.waiting) == i + 1
|
||||
|
||||
|
||||
def test_finish_request():
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.finish_requests(request.request_id,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
assert request.request_id not in scheduler.requests
|
||||
assert len(scheduler.waiting) == 9 - i
|
||||
|
||||
|
||||
def test_get_num_unfinished_requests():
|
||||
scheduler = create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.finish_requests(request.request_id,
|
||||
RequestStatus.FINISHED_STOPPED)
|
||||
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
|
||||
(None, None),
|
||||
(True, 5),
|
||||
])
|
||||
def test_schedule(enable_prefix_caching: Optional[bool],
|
||||
prompt_logprobs: Optional[int]):
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
'''
|
||||
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
|
||||
requests = create_requests(num_requests=10,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Test initial scheduling
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == len(requests)
|
||||
assert output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(output.finished_req_ids) == 0
|
||||
# Verify all requests are scheduled.
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
||||
|
||||
# Verify requests moved from waiting to running
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == len(requests)
|
||||
for i, request in enumerate(requests):
|
||||
assert scheduler.running[i] == request
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
|
||||
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
||||
"""Test scheduling behavior with concurrent partial requests.
|
||||
|
||||
This test verifies that: there are multiple long prefill requests in the
|
||||
RUNNING state, and we can schedule them together.
|
||||
|
||||
"""
|
||||
scheduler = create_scheduler(
|
||||
model="facebook/opt-125m",
|
||||
max_num_batched_tokens=1024,
|
||||
long_prefill_token_threshold=400,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
requests = create_requests(
|
||||
num_requests=3,
|
||||
num_tokens=800,
|
||||
)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 3
|
||||
assert output.scheduled_cached_reqs.num_reqs == 0
|
||||
assert len(output.finished_req_ids) == 0
|
||||
|
||||
# The first request is scheduled partially - 400.
|
||||
assert output.num_scheduled_tokens[requests[0].request_id] == 400
|
||||
# The second request is scheduled partially - 400.
|
||||
assert output.num_scheduled_tokens[requests[1].request_id] == 400
|
||||
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
|
||||
assert output.num_scheduled_tokens[requests[2].request_id] == 224
|
||||
req_to_index = {
|
||||
request.request_id: i
|
||||
for i, request in enumerate(requests)
|
||||
}
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler.update_from_output(output, model_runner_output)
|
||||
|
||||
# Schedule the next step. All three requests are running.
|
||||
# Processed the remaining prefills of the first and second requests.
|
||||
output1 = scheduler.schedule()
|
||||
assert len(scheduler.running) == 3
|
||||
assert len(output1.scheduled_new_reqs) == 0
|
||||
assert output1.scheduled_cached_reqs.num_reqs == 3
|
||||
assert len(output1.finished_req_ids) == 0
|
||||
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
|
||||
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
|
||||
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
|
||||
|
||||
# Schedule the third step. All three requests are running.
|
||||
# First and second requests are in the decode stage.
|
||||
# All the remaining tokens in the third request are processed.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[request.request_id for request in requests],
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(output1, model_runner_output)
|
||||
output2 = scheduler.schedule()
|
||||
assert len(scheduler.running) == 3
|
||||
assert len(output2.scheduled_new_reqs) == 0
|
||||
assert output2.scheduled_cached_reqs.num_reqs == 3
|
||||
assert len(output2.finished_req_ids) == 0
|
||||
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
|
||||
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
|
||||
assert output2.num_scheduled_tokens[
|
||||
requests[2].request_id] == 800 - 224 - 224
|
||||
|
||||
|
||||
def test_stop_via_update_from_output():
|
||||
"""Test stopping behavior through update_from_output"""
|
||||
scheduler = create_scheduler(num_speculative_tokens=1)
|
||||
|
||||
# Test case 1: Stop on EOS token
|
||||
requests = create_requests(num_requests=2, max_tokens=10)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
if not vllm_version_is("0.9.2"):
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID],
|
||||
[10,
|
||||
11]], # First request hits EOS, second continues
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped, second continues
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
|
||||
assert list(requests[1].output_token_ids) == [10, 11]
|
||||
|
||||
# Test case 2: Stop on custom stop token
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=2,
|
||||
max_tokens=10,
|
||||
stop_token_ids=[42, 43])
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
if not vllm_version_is("0.9.2"):
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped on custom token
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_STOPPED
|
||||
assert requests[0].stop_reason == 42
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [10, 42]
|
||||
assert list(requests[1].output_token_ids) == [13, 14]
|
||||
|
||||
# Test case 3: Stop on max tokens
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=2, max_tokens=2)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
if not vllm_version_is("0.9.2"):
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped due to length
|
||||
assert len(scheduler.running) == 1
|
||||
assert scheduler.running[0].request_id == requests[1].request_id
|
||||
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
|
||||
assert requests[0].request_id in scheduler.finished_req_ids
|
||||
assert list(requests[0].output_token_ids) == [10, 11
|
||||
] # Truncated to max_tokens
|
||||
assert list(requests[1].output_token_ids) == [13]
|
||||
|
||||
# Test case 4: Ignore EOS flag
|
||||
scheduler = create_scheduler(num_speculative_tokens=2)
|
||||
requests = create_requests(num_requests=1, max_tokens=10)
|
||||
requests[0].sampling_params.ignore_eos = True
|
||||
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify request continues past EOS
|
||||
assert len(scheduler.running) == 1
|
||||
assert not requests[0].is_finished()
|
||||
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
|
||||
(None, None),
|
||||
(True, 5),
|
||||
])
|
||||
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
||||
prompt_logprobs: Optional[int]):
|
||||
scheduler = create_scheduler(
|
||||
max_num_batched_tokens=1024,
|
||||
max_num_seqs=2,
|
||||
enable_prefix_caching=enable_prefix_caching,
|
||||
enable_chunked_prefill=True,
|
||||
)
|
||||
requests = create_requests(
|
||||
num_requests=2,
|
||||
num_tokens=512,
|
||||
prompt_logprobs=prompt_logprobs,
|
||||
)
|
||||
|
||||
# Schedule the first request.
|
||||
scheduler.add_request(requests[0])
|
||||
scheduler_output0 = scheduler.schedule()
|
||||
assert len(scheduler_output0.scheduled_new_reqs) == 1
|
||||
assert scheduler_output0.num_scheduled_tokens[
|
||||
requests[0].request_id] == 512
|
||||
|
||||
# The first request is still running, so only schedule the second request.
|
||||
scheduler.add_request(requests[1])
|
||||
scheduler_output1 = scheduler.schedule()
|
||||
assert len(scheduler_output1.scheduled_new_reqs) == 1
|
||||
assert scheduler_output1.num_scheduled_tokens[
|
||||
requests[1].request_id] == 512
|
||||
|
||||
# Model output of the first request.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output0, model_runner_output)
|
||||
|
||||
# Schedule the next step.
|
||||
# The first request can be scheduled again while the second
|
||||
# request is still running.
|
||||
scheduler_output2 = scheduler.schedule()
|
||||
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
|
||||
|
||||
# Model output of the second request.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output1, model_runner_output)
|
||||
|
||||
|
||||
# Note - these test cases mirror some of those in test_rejection_sampler.py
|
||||
@pytest.mark.parametrize(
|
||||
"spec_tokens,output_tokens,expected",
|
||||
[
|
||||
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
|
||||
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
|
||||
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
|
||||
(2, 3, 3, [2, 1])), # multiple sequences
|
||||
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
|
||||
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
|
||||
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
|
||||
(2, 6, 3, [2, 1, 0])), # multiple mismatches
|
||||
])
|
||||
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
||||
"""Test scheduling behavior with speculative decoding.
|
||||
|
||||
This test verifies that:
|
||||
1. Speculated tokens get scheduled correctly
|
||||
2. Spec decoding stats properly count number of draft and accepted tokens
|
||||
"""
|
||||
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
|
||||
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
|
||||
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
req_ids.append(request.request_id)
|
||||
req_to_index[request.request_id] = i
|
||||
|
||||
# Schedule a decode, which will also draft speculative tokens
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == len(requests)
|
||||
assert output.total_num_scheduled_tokens == len(requests)
|
||||
for i in range(len(requests)):
|
||||
req_id = requests[i].request_id
|
||||
assert output.num_scheduled_tokens[req_id] == 1
|
||||
assert req_id not in output.scheduled_spec_decode_tokens
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
spec_token_ids=spec_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(output,
|
||||
model_runner_output)
|
||||
|
||||
for i in range(len(requests)):
|
||||
running_req = scheduler.running[i]
|
||||
# The prompt token
|
||||
assert running_req.num_computed_tokens == 1
|
||||
# The prompt token and the sampled token
|
||||
assert running_req.num_tokens == 2
|
||||
# The prompt token, the sampled token, and the speculated tokens
|
||||
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
|
||||
|
||||
# No draft or accepted tokens counted yet
|
||||
assert not engine_core_outputs or (
|
||||
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
|
||||
|
||||
# Schedule the speculated tokens for validation
|
||||
output = scheduler.schedule()
|
||||
assert len(output.scheduled_new_reqs) == 0
|
||||
# The sampled token and speculated tokens
|
||||
assert output.total_num_scheduled_tokens == \
|
||||
len(requests) + sum(len(ids) for ids in spec_tokens)
|
||||
for i in range(len(requests)):
|
||||
req_id = requests[i].request_id
|
||||
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
|
||||
if spec_tokens[i]:
|
||||
assert len(output.scheduled_spec_decode_tokens[req_id]) == \
|
||||
len(spec_tokens[i])
|
||||
else:
|
||||
assert req_id not in output.scheduled_spec_decode_tokens
|
||||
|
||||
model_runner_output = ModelRunnerOutput(req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(output,
|
||||
model_runner_output)
|
||||
|
||||
scheduler_stats = engine_core_outputs[0].scheduler_stats \
|
||||
if engine_core_outputs else None
|
||||
if expected[0] == 0:
|
||||
assert scheduler_stats.spec_decoding_stats is None # type: ignore
|
||||
else:
|
||||
assert scheduler_stats.spec_decoding_stats is not None # type: ignore
|
||||
stats = scheduler_stats.spec_decoding_stats # type: ignore
|
||||
assert stats.num_drafts == expected[0]
|
||||
assert stats.num_draft_tokens == expected[1]
|
||||
assert stats.num_accepted_tokens == expected[2]
|
||||
assert stats.num_accepted_tokens_per_pos == expected[3]
|
||||
|
||||
|
||||
def make_output(scheduler: AscendScheduler):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in scheduler.running],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(scheduler.running)
|
||||
},
|
||||
sampled_token_ids=[[1000]] * len(scheduler.running),
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
|
||||
def assert_scheduler_empty(scheduler: AscendScheduler):
|
||||
"""Confirm the scheduler is "empty" - i.e. no leaks."""
|
||||
# Scheduler Metadata.
|
||||
assert len(scheduler.requests) == 0
|
||||
assert len(scheduler.waiting) == 0
|
||||
assert len(scheduler.running) == 0
|
||||
assert len(scheduler.finished_req_ids) == 0
|
||||
|
||||
# EncoderCacheManager.
|
||||
assert len(scheduler.encoder_cache_manager.freed) == 0
|
||||
assert len(scheduler.encoder_cache_manager.cached) == 0
|
||||
|
||||
# KVCache Manager.
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks) == 0
|
||||
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block) == 0
|
||||
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
|
||||
num_free_blocks = (
|
||||
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
|
||||
assert num_free_blocks == (
|
||||
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
||||
|
||||
# NOTE(rob): just the ref count on blocks will be 0. The hash
|
||||
# value, etc will remain since we lazily evict for prefix cache.
|
||||
for block in scheduler.kv_cache_manager.block_pool.blocks:
|
||||
assert block.ref_cnt == 0
|
||||
|
||||
|
||||
def test_memory_leak():
|
||||
"""Test that we do not have a memory leak."""
|
||||
|
||||
scheduler = create_scheduler(enable_prefix_caching=True)
|
||||
|
||||
NUM_REQUESTS = 5
|
||||
NUM_TOKENS = 10
|
||||
MAX_TOKENS = 10
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
|
||||
# Add each request.
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Iterate until done.
|
||||
while True:
|
||||
scheduler_output = scheduler.schedule()
|
||||
if len(scheduler.running) == 0:
|
||||
break
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Confirm no memory leak.
|
||||
assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_concurrent_partial_prefill():
|
||||
with VllmRunner(MODEL,
|
||||
additional_config={
|
||||
|
||||
718
tests/ut/core/test_scheduler.py
Normal file
718
tests/ut/core/test_scheduler.py
Normal file
@@ -0,0 +1,718 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.core.scheduler import AscendScheduler
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
MODEL = "Qwen3-0.6B"
|
||||
ENABLE_PREFIX_CACHING = None
|
||||
PROMPT_LOGPROBS = None
|
||||
ENABLE_CHUNKED_PREFILL = False
|
||||
MAX_NUM_BATCHED_TOKENS = 10000
|
||||
LONG_PREFILL_TOKEN_THRESHOLD = 0
|
||||
NUM_SPECULATIVE_TOKENS = None
|
||||
MAX_NUM_SEQS = 16
|
||||
|
||||
|
||||
def create_requests(
|
||||
num_requests: int,
|
||||
num_tokens: int = 10,
|
||||
mm_positions: Optional[list[PlaceholderRange]] = None,
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
):
|
||||
prompt_logprobs = PROMPT_LOGPROBS
|
||||
sampling_params = SamplingParams(ignore_eos=False,
|
||||
max_tokens=max_tokens,
|
||||
stop_token_ids=stop_token_ids,
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
mm_position = None
|
||||
mm_inputs = None
|
||||
request = Request(
|
||||
request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_inputs=mm_inputs,
|
||||
multi_modal_placeholders=mm_position,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
)
|
||||
requests.append(request)
|
||||
return requests
|
||||
|
||||
|
||||
def make_output(scheduler):
|
||||
return ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in scheduler.running],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(scheduler.running)
|
||||
},
|
||||
sampled_token_ids=[[1000]] * len(scheduler.running),
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
|
||||
class TestAscendScheduler(TestBase):
|
||||
|
||||
@patch("vllm.config.ModelConfig.__post_init__", MagicMock())
|
||||
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
|
||||
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
|
||||
def create_scheduler(self, mock_compute_encoder_budget):
|
||||
mock_compute_encoder_budget.return_value = [10, 20]
|
||||
use_kv_connector = False
|
||||
block_size = 16
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
max_num_seqs=16,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKENS,
|
||||
long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
|
||||
disable_chunked_mm_input=False,
|
||||
enable_chunked_prefill=ENABLE_CHUNKED_PREFILL,
|
||||
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
|
||||
scheduler_config.max_num_encoder_input_tokens = 10000
|
||||
scheduler_config.encoder_cache_size = 10000
|
||||
scheduler_config.chunked_prefill_enabled = False
|
||||
|
||||
model_config = ModelConfig(
|
||||
model=MODEL,
|
||||
task="auto",
|
||||
tokenizer=MODEL,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype="float16",
|
||||
seed=42,
|
||||
max_model_len=MAX_NUM_BATCHED_TOKENS,
|
||||
)
|
||||
model_config.pooler_config = MagicMock()
|
||||
model_config.multimodal_config = MagicMock()
|
||||
# Cache config, optionally force APC
|
||||
kwargs_cache: Dict[str,
|
||||
Any] = ({} if ENABLE_PREFIX_CACHING is None else {
|
||||
'enable_prefix_caching':
|
||||
ENABLE_PREFIX_CACHING
|
||||
})
|
||||
cache_config = CacheConfig(
|
||||
block_size=block_size,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
**kwargs_cache,
|
||||
)
|
||||
|
||||
kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="SharedStorageConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={"shared_storage_path": "local_storage"},
|
||||
) if use_kv_connector else None
|
||||
|
||||
speculative_config: Optional[SpeculativeConfig] = None
|
||||
if NUM_SPECULATIVE_TOKENS is not None:
|
||||
speculative_config = SpeculativeConfig(
|
||||
model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
scheduler_config=scheduler_config,
|
||||
model_config=model_config,
|
||||
cache_config=cache_config,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
speculative_config=speculative_config,
|
||||
)
|
||||
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=10000, # A large number of blocks to hold all requests
|
||||
kv_cache_tensors=[],
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(['layer'],
|
||||
FullAttentionSpec(block_size, 1, 1,
|
||||
torch.float32, False))
|
||||
],
|
||||
)
|
||||
cache_config.num_gpu_blocks = 10000
|
||||
|
||||
scheduler = AscendScheduler(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_config=kv_cache_config,
|
||||
log_stats=True,
|
||||
structured_output_manager=MagicMock(spec=StructuredOutputManager),
|
||||
)
|
||||
|
||||
should_advance = MagicMock()
|
||||
should_advance.return_value = False
|
||||
scheduler.structured_output_manager.should_advance = should_advance
|
||||
|
||||
return scheduler
|
||||
|
||||
def test_add_requests(self):
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
self.assertIn(request.request_id, scheduler.requests)
|
||||
self.assertEqual(len(scheduler.waiting), i + 1)
|
||||
|
||||
def test_finish_request(self):
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.finish_requests(request.request_id,
|
||||
RequestStatus.FINISHED_ABORTED)
|
||||
self.assertNotIn(request.request_id, scheduler.requests)
|
||||
self.assertEqual(len(scheduler.waiting), 9 - i)
|
||||
|
||||
def test_get_num_unfinished_requests(self):
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=10)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.finish_requests(request.request_id,
|
||||
RequestStatus.FINISHED_STOPPED)
|
||||
self.assertEqual(scheduler.get_num_unfinished_requests(),
|
||||
len(requests) - i - 1)
|
||||
|
||||
def test_schedule(self):
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
'''
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||
requests = create_requests(num_requests=10)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Test initial scheduling
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
|
||||
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(output.finished_req_ids), 0)
|
||||
# Verify all requests are scheduled.
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
self.assertEqual(num_tokens,
|
||||
len(requests[int(req_id)].prompt_token_ids))
|
||||
|
||||
# Verify requests moved from waiting to running
|
||||
self.assertEqual(len(scheduler.waiting), 0)
|
||||
self.assertEqual(len(scheduler.running), len(requests))
|
||||
for i, request in enumerate(requests):
|
||||
self.assertEqual(scheduler.running[i], request)
|
||||
|
||||
def test_schedule_enable_prefix_caching(self):
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
'''
|
||||
global ENABLE_PREFIX_CACHING
|
||||
ENABLE_PREFIX_CACHING = True
|
||||
global PROMPT_LOGPROBS
|
||||
PROMPT_LOGPROBS = 5
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||
requests = create_requests(num_requests=10)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Test initial scheduling
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
|
||||
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(output.finished_req_ids), 0)
|
||||
# Verify all requests are scheduled.
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
self.assertEqual(num_tokens,
|
||||
len(requests[int(req_id)].prompt_token_ids))
|
||||
|
||||
# Verify requests moved from waiting to running
|
||||
self.assertEqual(len(scheduler.waiting), 0)
|
||||
self.assertEqual(len(scheduler.running), len(requests))
|
||||
for i, request in enumerate(requests):
|
||||
self.assertEqual(scheduler.running[i], request)
|
||||
|
||||
def test_stop_via_update_from_output(self):
|
||||
"""Test stopping behavior through update_from_output"""
|
||||
global NUM_SPECULATIVE_TOKENS
|
||||
NUM_SPECULATIVE_TOKENS = 1
|
||||
scheduler = self.create_scheduler()
|
||||
|
||||
# Test case 1: Stop on EOS token
|
||||
requests = create_requests(num_requests=2, max_tokens=10)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
if not vllm_version_is("0.9.2"):
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
|
||||
], # First request hits EOS, second continues
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped, second continues
|
||||
self.assertEqual(len(scheduler.running), 1)
|
||||
self.assertEqual(scheduler.running[0].request_id,
|
||||
requests[1].request_id)
|
||||
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
|
||||
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
|
||||
self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID])
|
||||
self.assertEqual(list(requests[1].output_token_ids), [10, 11])
|
||||
|
||||
# Test case 2: Stop on custom stop token
|
||||
NUM_SPECULATIVE_TOKENS = 2
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=2,
|
||||
max_tokens=10,
|
||||
stop_token_ids=[42, 43])
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
if not vllm_version_is("0.9.2"):
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id:
|
||||
[10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped on custom token
|
||||
self.assertEqual(len(scheduler.running), 1)
|
||||
self.assertEqual(scheduler.running[0].request_id,
|
||||
requests[1].request_id)
|
||||
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
|
||||
self.assertEqual(requests[0].stop_reason, 42)
|
||||
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
|
||||
self.assertEqual(list(requests[0].output_token_ids), [10, 42])
|
||||
self.assertEqual(list(requests[1].output_token_ids), [13, 14])
|
||||
|
||||
# Test case 3: Stop on max tokens
|
||||
NUM_SPECULATIVE_TOKENS = 2
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=2, max_tokens=2)
|
||||
for req in requests:
|
||||
req.num_computed_tokens = req.num_tokens
|
||||
scheduler.requests[req.request_id] = req
|
||||
scheduler.running.append(req)
|
||||
if not vllm_version_is("0.9.2"):
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id:
|
||||
[10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped due to length
|
||||
self.assertEqual(len(scheduler.running), 1)
|
||||
self.assertEqual(scheduler.running[0].request_id,
|
||||
requests[1].request_id)
|
||||
self.assertEqual(requests[0].status,
|
||||
RequestStatus.FINISHED_LENGTH_CAPPED)
|
||||
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
|
||||
self.assertEqual(list(requests[0].output_token_ids), [10, 11])
|
||||
self.assertEqual(list(requests[1].output_token_ids), [13])
|
||||
|
||||
# Test case 4: Ignore EOS flag
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=1, max_tokens=10)
|
||||
requests[0].sampling_params.ignore_eos = True
|
||||
requests[0].num_computed_tokens = requests[0].num_tokens
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify request continues past EOS
|
||||
self.assertEqual(len(scheduler.running), 1)
|
||||
self.assertFalse(requests[0].is_finished())
|
||||
self.assertEqual(list(requests[0].output_token_ids),
|
||||
[EOS_TOKEN_ID, 10, 11])
|
||||
|
||||
def test_schedule_concurrent_batches(self):
|
||||
global MAX_NUM_BATCHED_TOKENS
|
||||
global ENABLE_PREFIX_CACHING
|
||||
global ENABLE_CHUNKED_PREFILL
|
||||
global MAX_NUM_SEQS
|
||||
global PROMPT_LOGPROBS
|
||||
ENABLE_PREFIX_CACHING = None
|
||||
MAX_NUM_BATCHED_TOKENS = 1024
|
||||
MAX_NUM_SEQS = 2
|
||||
ENABLE_CHUNKED_PREFILL = True
|
||||
PROMPT_LOGPROBS = None
|
||||
|
||||
enable_prefix_caching_list = [None, True]
|
||||
prompt_logprobs_list = [None, 5]
|
||||
|
||||
for i in range(len(enable_prefix_caching_list)):
|
||||
ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i]
|
||||
PROMPT_LOGPROBS = prompt_logprobs_list[i]
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(
|
||||
num_requests=2,
|
||||
num_tokens=512,
|
||||
)
|
||||
|
||||
# Schedule the first request.
|
||||
scheduler.add_request(requests[0])
|
||||
scheduler_output0 = scheduler.schedule()
|
||||
self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
|
||||
self.assertEqual(
|
||||
scheduler_output0.num_scheduled_tokens[requests[0].request_id],
|
||||
512)
|
||||
|
||||
# The first request is still running, so only schedule the second request.
|
||||
scheduler.add_request(requests[1])
|
||||
scheduler_output1 = scheduler.schedule()
|
||||
self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
|
||||
self.assertEqual(
|
||||
scheduler_output1.num_scheduled_tokens[requests[1].request_id],
|
||||
512)
|
||||
|
||||
# Model output of the first request.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output0,
|
||||
model_runner_output)
|
||||
|
||||
# Schedule the next step.
|
||||
# The first request can be scheduled again while the second
|
||||
# request is still running.
|
||||
scheduler.schedule()
|
||||
# Model output of the second request.
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output1,
|
||||
model_runner_output)
|
||||
|
||||
def test_schedule_spec_decoding_stats(self):
|
||||
"""Test scheduling behavior with speculative decoding.
|
||||
|
||||
This test verifies that:
|
||||
1. Speculated tokens get scheduled correctly
|
||||
2. Spec decoding stats properly count number of draft and accepted tokens
|
||||
"""
|
||||
spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
|
||||
[[1, 2], [3]], [[1]], [[]],
|
||||
[[1, 2, 3], [4, 5, 6]]]
|
||||
output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
|
||||
[[1, 2, 5], [3, 4]],
|
||||
[[1, 2]], [[5]],
|
||||
[[1, 2, 7], [4, 8]]]
|
||||
expected_list: List[Tuple[int, int,
|
||||
int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
|
||||
(1, 3, 1, [1, 0, 0]),
|
||||
(2, 3, 3, [2, 1]),
|
||||
(1, 1, 1, [1]),
|
||||
(0, 0, 0, [0]),
|
||||
(2, 6, 3, [2, 1, 0])]
|
||||
|
||||
global NUM_SPECULATIVE_TOKENS
|
||||
for idx in range(len(spec_tokens_list)):
|
||||
spec_tokens = spec_tokens_list[idx]
|
||||
output_tokens = output_tokens_list[idx]
|
||||
expected = expected_list[idx]
|
||||
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
|
||||
NUM_SPECULATIVE_TOKENS = num_spec_tokens
|
||||
scheduler = self.create_scheduler()
|
||||
requests = create_requests(num_requests=len(spec_tokens),
|
||||
num_tokens=1)
|
||||
req_ids = []
|
||||
req_to_index = {}
|
||||
for i, request in enumerate(requests):
|
||||
scheduler.add_request(request)
|
||||
req_ids.append(request.request_id)
|
||||
req_to_index[request.request_id] = i
|
||||
|
||||
# Schedule a decode, which will also draft speculative tokens
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
|
||||
self.assertEqual(output.total_num_scheduled_tokens, len(requests))
|
||||
for i in range(len(requests)):
|
||||
req_id = requests[i].request_id
|
||||
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
|
||||
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
spec_token_ids=spec_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
output, model_runner_output)
|
||||
|
||||
for i in range(len(requests)):
|
||||
running_req = scheduler.running[i]
|
||||
# The prompt token
|
||||
self.assertEqual(running_req.num_computed_tokens, 1)
|
||||
# The prompt token and the sampled token
|
||||
self.assertEqual(running_req.num_tokens, 2)
|
||||
# The prompt token, the sampled token, and the speculated tokens
|
||||
self.assertEqual(running_req.num_tokens_with_spec,
|
||||
2 + len(spec_tokens[i]))
|
||||
|
||||
# No draft or accepted tokens counted yet
|
||||
self.assertTrue(
|
||||
not engine_core_outputs
|
||||
or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
|
||||
is None))
|
||||
|
||||
# Schedule the speculated tokens for validation
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs), 0)
|
||||
# The sampled token and speculated tokens
|
||||
self.assertEqual(
|
||||
output.total_num_scheduled_tokens,
|
||||
len(requests) + sum(len(ids) for ids in spec_tokens))
|
||||
for i in range(len(requests)):
|
||||
req_id = requests[i].request_id
|
||||
self.assertEqual(output.num_scheduled_tokens[req_id],
|
||||
1 + len(spec_tokens[i]))
|
||||
if spec_tokens[i]:
|
||||
self.assertEqual(
|
||||
len(output.scheduled_spec_decode_tokens[req_id]),
|
||||
len(spec_tokens[i]))
|
||||
else:
|
||||
self.assertNotIn(req_id,
|
||||
output.scheduled_spec_decode_tokens)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
output, model_runner_output)
|
||||
|
||||
scheduler_stats = engine_core_outputs[0].scheduler_stats \
|
||||
if engine_core_outputs else None
|
||||
if expected[0] == 0:
|
||||
self.assertIsNone(scheduler_stats.spec_decoding_stats)
|
||||
else:
|
||||
self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
|
||||
stats = scheduler_stats.spec_decoding_stats
|
||||
self.assertEqual(stats.num_drafts, expected[0])
|
||||
self.assertEqual(stats.num_draft_tokens, expected[1])
|
||||
self.assertEqual(stats.num_accepted_tokens, expected[2])
|
||||
self.assertEqual(stats.num_accepted_tokens_per_pos,
|
||||
expected[3])
|
||||
|
||||
def assert_scheduler_empty(self, scheduler):
|
||||
"""Confirm the scheduler is "empty" - i.e. no leaks."""
|
||||
# Scheduler Metadata.
|
||||
scheduler = self.create_scheduler()
|
||||
self.assertEqual(len(scheduler.requests), 0)
|
||||
self.assertEqual(len(scheduler.waiting), 0)
|
||||
self.assertEqual(len(scheduler.running), 0)
|
||||
self.assertEqual(len(scheduler.finished_req_ids), 0)
|
||||
|
||||
# EncoderCacheManager.
|
||||
self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
|
||||
self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
|
||||
|
||||
# KVCache Manager.
|
||||
self.assertEqual(
|
||||
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
req_to_blocks), 0)
|
||||
self.assertEqual(
|
||||
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
|
||||
num_cached_block), 0)
|
||||
self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
|
||||
0)
|
||||
self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
|
||||
0)
|
||||
num_free_blocks = (scheduler.kv_cache_manager.block_pool.
|
||||
free_block_queue.num_free_blocks)
|
||||
self.assertEqual(
|
||||
num_free_blocks,
|
||||
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
|
||||
|
||||
# NOTE(rob): just the ref count on blocks will be 0. The hash
|
||||
# value, etc will remain since we lazily evict for prefix cache.
|
||||
for block in scheduler.kv_cache_manager.block_pool.blocks:
|
||||
self.assertEqual(block.ref_cnt, 0)
|
||||
|
||||
def test_memory_leak(self):
|
||||
"""Test that we do not have a memory leak."""
|
||||
scheduler = self.create_scheduler()
|
||||
NUM_REQUESTS = 5
|
||||
NUM_TOKENS = 10
|
||||
MAX_TOKENS = 10
|
||||
requests = create_requests(num_requests=NUM_REQUESTS,
|
||||
num_tokens=NUM_TOKENS,
|
||||
max_tokens=MAX_TOKENS)
|
||||
|
||||
# Add each request.
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Iterate until done.
|
||||
while True:
|
||||
scheduler_output = scheduler.schedule()
|
||||
if len(scheduler.running) == 0:
|
||||
break
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
# Confirm no memory leak.
|
||||
self.assert_scheduler_empty(scheduler)
|
||||
Reference in New Issue
Block a user