[Test] Add uts for files in /core (#1957)

### What this PR does / why we need it?

Add uts for files in folder /core

### Does this PR introduce _any_ user-facing change?

No

- vLLM version: v0.9.2
- vLLM main:
5a19a6c670

---------

Signed-off-by: lwq <liwenquan5@huawei.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
This commit is contained in:
SunnyLee151064
2025-07-25 09:48:19 +08:00
committed by GitHub
parent 6bc82cf6a7
commit ae560f7131
2 changed files with 718 additions and 730 deletions

View File

@@ -1,743 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import pytest
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
MODEL = "Qwen/Qwen3-0.6B"
def create_scheduler(
model: str = MODEL,
max_num_seqs: int = 16,
max_num_batched_tokens: int = 8192,
enable_prefix_caching: Optional[bool] = None,
long_prefill_token_threshold: int = 0,
disable_chunked_mm_input: bool = False,
use_kv_connector: bool = False,
num_blocks: int = 10000,
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
enable_chunked_prefill: bool = False,
) -> AscendScheduler:
'''Create scheduler under test.
Args:
model: model under test
max_num_seqs: max sequences to schedule
max_num_batch_tokens: max num tokens to batch
enable_prefix_caching: optionally force APC config
(True/False) or use default
(None)
Returns:
{class}`Scheduler` instance
'''
if max_model_len is None:
max_model_len = max_num_batched_tokens
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
long_prefill_token_threshold=long_prefill_token_threshold,
disable_chunked_mm_input=disable_chunked_mm_input,
enable_chunked_prefill=enable_chunked_prefill,
)
model_config = ModelConfig(
model=model,
task="auto",
tokenizer=model,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
'enable_prefix_caching': enable_prefix_caching
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if num_speculative_tokens is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=num_speculative_tokens)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
],
)
cache_config.num_gpu_blocks = num_blocks
return AscendScheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
def create_requests(num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[int] = None):
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if mm_positions is not None:
mm_position = mm_positions[i]
mm_inputs = [MultiModalKwargs({})] * len(mm_position)
else:
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
)
requests.append(request)
return requests
def test_add_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
assert request.request_id in scheduler.requests
assert len(scheduler.waiting) == i + 1
def test_finish_request():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
assert request.request_id not in scheduler.requests
assert len(scheduler.waiting) == 9 - i
def test_get_num_unfinished_requests():
scheduler = create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
assert scheduler.get_num_unfinished_requests() == len(requests) - i - 1
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
requests = create_requests(num_requests=10,
prompt_logprobs=prompt_logprobs)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
# Verify requests moved from waiting to running
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == len(requests)
for i, request in enumerate(requests):
assert scheduler.running[i] == request
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
"""Test scheduling behavior with concurrent partial requests.
This test verifies that: there are multiple long prefill requests in the
RUNNING state, and we can schedule them together.
"""
scheduler = create_scheduler(
model="facebook/opt-125m",
max_num_batched_tokens=1024,
long_prefill_token_threshold=400,
enable_prefix_caching=enable_prefix_caching,
enable_chunked_prefill=True,
)
requests = create_requests(
num_requests=3,
num_tokens=800,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[0].request_id] == 400
# The second request is scheduled partially - 400.
assert output.num_scheduled_tokens[requests[1].request_id] == 400
# The third request is also scheduled partially - 1024 - 400 - 400 = 224.
assert output.num_scheduled_tokens[requests[2].request_id] == 224
req_to_index = {
request.request_id: i
for i, request in enumerate(requests)
}
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(output, model_runner_output)
# Schedule the next step. All three requests are running.
# Processed the remaining prefills of the first and second requests.
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
assert output1.num_scheduled_tokens[requests[2].request_id] == 224
# Schedule the third step. All three requests are running.
# First and second requests are in the decode stage.
# All the remaining tokens in the third request are processed.
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(output1, model_runner_output)
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
assert output2.num_scheduled_tokens[
requests[2].request_id] == 800 - 224 - 224
def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler(num_speculative_tokens=1)
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
if not vllm_version_is("0.9.2"):
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]
# Test case 2: Stop on custom stop token
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
if not vllm_version_is("0.9.2"):
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]
# Test case 3: Stop on max tokens
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
if not vllm_version_is("0.9.2"):
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11
] # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]
# Test case 4: Ignore EOS flag
scheduler = create_scheduler(num_speculative_tokens=2)
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]
@pytest.mark.parametrize("enable_prefix_caching, prompt_logprobs", [
(None, None),
(True, 5),
])
def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
scheduler = create_scheduler(
max_num_batched_tokens=1024,
max_num_seqs=2,
enable_prefix_caching=enable_prefix_caching,
enable_chunked_prefill=True,
)
requests = create_requests(
num_requests=2,
num_tokens=512,
prompt_logprobs=prompt_logprobs,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
assert len(scheduler_output0.scheduled_new_reqs) == 1
assert scheduler_output0.num_scheduled_tokens[
requests[0].request_id] == 512
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
assert len(scheduler_output1.scheduled_new_reqs) == 1
assert scheduler_output1.num_scheduled_tokens[
requests[1].request_id] == 512
# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output0, model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler_output2 = scheduler.schedule()
assert scheduler_output2.num_scheduled_tokens[requests[0].request_id] == 1
# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output1, model_runner_output)
# Note - these test cases mirror some of those in test_rejection_sampler.py
@pytest.mark.parametrize(
"spec_tokens,output_tokens,expected",
[
([[1, 2, 3]], [[1, 2, 3, 4]], (1, 3, 3, [1, 1, 1])), # perfect match
([[1, 2, 3]], [[1, 5]], (1, 3, 1, [1, 0, 0])), # early mismatch
([[1, 2], [3]], [[1, 2, 5], [3, 4]],
(2, 3, 3, [2, 1])), # multiple sequences
([[1]], [[1, 2]], (1, 1, 1, [1])), # single token sequence
([[]], [[5]], (0, 0, 0, [0])), # empty sequence
([[1, 2, 3], [4, 5, 6]], [[1, 2, 7], [4, 8]],
(2, 6, 3, [2, 1, 0])), # multiple mismatches
])
def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
scheduler = create_scheduler(num_speculative_tokens=num_spec_tokens)
requests = create_requests(num_requests=len(spec_tokens), num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.total_num_scheduled_tokens == len(requests)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
assert running_req.num_computed_tokens == 1
# The prompt token and the sampled token
assert running_req.num_tokens == 2
# The prompt token, the sampled token, and the speculated tokens
assert running_req.num_tokens_with_spec == 2 + len(spec_tokens[i])
# No draft or accepted tokens counted yet
assert not engine_core_outputs or (
engine_core_outputs[0].scheduler_stats.spec_decoding_stats is None)
# Schedule the speculated tokens for validation
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
# The sampled token and speculated tokens
assert output.total_num_scheduled_tokens == \
len(requests) + sum(len(ids) for ids in spec_tokens)
for i in range(len(requests)):
req_id = requests[i].request_id
assert output.num_scheduled_tokens[req_id] == 1 + len(spec_tokens[i])
if spec_tokens[i]:
assert len(output.scheduled_spec_decode_tokens[req_id]) == \
len(spec_tokens[i])
else:
assert req_id not in output.scheduled_spec_decode_tokens
model_runner_output = ModelRunnerOutput(req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
assert scheduler_stats.spec_decoding_stats is None # type: ignore
else:
assert scheduler_stats.spec_decoding_stats is not None # type: ignore
stats = scheduler_stats.spec_decoding_stats # type: ignore
assert stats.num_drafts == expected[0]
assert stats.num_draft_tokens == expected[1]
assert stats.num_accepted_tokens == expected[2]
assert stats.num_accepted_tokens_per_pos == expected[3]
def make_output(scheduler: AscendScheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
req_id_to_index={
req.request_id: i
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
def assert_scheduler_empty(scheduler: AscendScheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def test_memory_leak():
"""Test that we do not have a memory leak."""
scheduler = create_scheduler(enable_prefix_caching=True)
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
assert_scheduler_empty(scheduler)
def test_concurrent_partial_prefill():
with VllmRunner(MODEL,
additional_config={

View File

@@ -0,0 +1,718 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import MagicMock, patch
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from tests.ut.base import TestBase
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
MODEL = "Qwen3-0.6B"
ENABLE_PREFIX_CACHING = None
PROMPT_LOGPROBS = None
ENABLE_CHUNKED_PREFILL = False
MAX_NUM_BATCHED_TOKENS = 10000
LONG_PREFILL_TOKEN_THRESHOLD = 0
NUM_SPECULATIVE_TOKENS = None
MAX_NUM_SEQS = 16
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
):
prompt_logprobs = PROMPT_LOGPROBS
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
mm_position = None
mm_inputs = None
request = Request(
request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
)
requests.append(request)
return requests
def make_output(scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
req_id_to_index={
req.request_id: i
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
class TestAscendScheduler(TestBase):
@patch("vllm.config.ModelConfig.__post_init__", MagicMock())
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget):
mock_compute_encoder_budget.return_value = [10, 20]
use_kv_connector = False
block_size = 16
scheduler_config = SchedulerConfig(
max_num_seqs=16,
max_model_len=MAX_NUM_BATCHED_TOKENS,
long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
disable_chunked_mm_input=False,
enable_chunked_prefill=ENABLE_CHUNKED_PREFILL,
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
)
scheduler_config.max_num_encoder_input_tokens = 10000
scheduler_config.encoder_cache_size = 10000
scheduler_config.chunked_prefill_enabled = False
model_config = ModelConfig(
model=MODEL,
task="auto",
tokenizer=MODEL,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
max_model_len=MAX_NUM_BATCHED_TOKENS,
)
model_config.pooler_config = MagicMock()
model_config.multimodal_config = MagicMock()
# Cache config, optionally force APC
kwargs_cache: Dict[str,
Any] = ({} if ENABLE_PREFIX_CACHING is None else {
'enable_prefix_caching':
ENABLE_PREFIX_CACHING
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if NUM_SPECULATIVE_TOKENS is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1,
torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
scheduler = AscendScheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=MagicMock(spec=StructuredOutputManager),
)
should_advance = MagicMock()
should_advance.return_value = False
scheduler.structured_output_manager.should_advance = should_advance
return scheduler
def test_add_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
self.assertIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), i + 1)
def test_finish_request(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
self.assertNotIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), 9 - i)
def test_get_num_unfinished_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
self.assertEqual(scheduler.get_num_unfinished_requests(),
len(requests) - i - 1)
def test_schedule(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_enable_prefix_caching(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
global ENABLE_PREFIX_CACHING
ENABLE_PREFIX_CACHING = True
global PROMPT_LOGPROBS
PROMPT_LOGPROBS = 5
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_stop_via_update_from_output(self):
"""Test stopping behavior through update_from_output"""
global NUM_SPECULATIVE_TOKENS
NUM_SPECULATIVE_TOKENS = 1
scheduler = self.create_scheduler()
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
if not vllm_version_is("0.9.2"):
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID])
self.assertEqual(list(requests[1].output_token_ids), [10, 11])
# Test case 2: Stop on custom stop token
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
if not vllm_version_is("0.9.2"):
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id:
[10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertEqual(requests[0].stop_reason, 42)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 42])
self.assertEqual(list(requests[1].output_token_ids), [13, 14])
# Test case 3: Stop on max tokens
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
if not vllm_version_is("0.9.2"):
req.status = RequestStatus.RUNNING
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id:
[10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status,
RequestStatus.FINISHED_LENGTH_CAPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 11])
self.assertEqual(list(requests[1].output_token_ids), [13])
# Test case 4: Ignore EOS flag
scheduler = self.create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
self.assertEqual(len(scheduler.running), 1)
self.assertFalse(requests[0].is_finished())
self.assertEqual(list(requests[0].output_token_ids),
[EOS_TOKEN_ID, 10, 11])
def test_schedule_concurrent_batches(self):
global MAX_NUM_BATCHED_TOKENS
global ENABLE_PREFIX_CACHING
global ENABLE_CHUNKED_PREFILL
global MAX_NUM_SEQS
global PROMPT_LOGPROBS
ENABLE_PREFIX_CACHING = None
MAX_NUM_BATCHED_TOKENS = 1024
MAX_NUM_SEQS = 2
ENABLE_CHUNKED_PREFILL = True
PROMPT_LOGPROBS = None
enable_prefix_caching_list = [None, True]
prompt_logprobs_list = [None, 5]
for i in range(len(enable_prefix_caching_list)):
ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i]
PROMPT_LOGPROBS = prompt_logprobs_list[i]
scheduler = self.create_scheduler()
requests = create_requests(
num_requests=2,
num_tokens=512,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output0.num_scheduled_tokens[requests[0].request_id],
512)
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output1.num_scheduled_tokens[requests[1].request_id],
512)
# Model output of the first request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output0,
model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler.schedule()
# Model output of the second request.
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output1,
model_runner_output)
def test_schedule_spec_decoding_stats(self):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
[[1, 2], [3]], [[1]], [[]],
[[1, 2, 3], [4, 5, 6]]]
output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
[[1, 2, 5], [3, 4]],
[[1, 2]], [[5]],
[[1, 2, 7], [4, 8]]]
expected_list: List[Tuple[int, int,
int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
(1, 3, 1, [1, 0, 0]),
(2, 3, 3, [2, 1]),
(1, 1, 1, [1]),
(0, 0, 0, [0]),
(2, 6, 3, [2, 1, 0])]
global NUM_SPECULATIVE_TOKENS
for idx in range(len(spec_tokens_list)):
spec_tokens = spec_tokens_list[idx]
output_tokens = output_tokens_list[idx]
expected = expected_list[idx]
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
NUM_SPECULATIVE_TOKENS = num_spec_tokens
scheduler = self.create_scheduler()
requests = create_requests(num_requests=len(spec_tokens),
num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.total_num_scheduled_tokens, len(requests))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
self.assertEqual(running_req.num_computed_tokens, 1)
# The prompt token and the sampled token
self.assertEqual(running_req.num_tokens, 2)
# The prompt token, the sampled token, and the speculated tokens
self.assertEqual(running_req.num_tokens_with_spec,
2 + len(spec_tokens[i]))
# No draft or accepted tokens counted yet
self.assertTrue(
not engine_core_outputs
or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
is None))
# Schedule the speculated tokens for validation
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), 0)
# The sampled token and speculated tokens
self.assertEqual(
output.total_num_scheduled_tokens,
len(requests) + sum(len(ids) for ids in spec_tokens))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id],
1 + len(spec_tokens[i]))
if spec_tokens[i]:
self.assertEqual(
len(output.scheduled_spec_decode_tokens[req_id]),
len(spec_tokens[i]))
else:
self.assertNotIn(req_id,
output.scheduled_spec_decode_tokens)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
self.assertIsNone(scheduler_stats.spec_decoding_stats)
else:
self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
stats = scheduler_stats.spec_decoding_stats
self.assertEqual(stats.num_drafts, expected[0])
self.assertEqual(stats.num_draft_tokens, expected[1])
self.assertEqual(stats.num_accepted_tokens, expected[2])
self.assertEqual(stats.num_accepted_tokens_per_pos,
expected[3])
def assert_scheduler_empty(self, scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
scheduler = self.create_scheduler()
self.assertEqual(len(scheduler.requests), 0)
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), 0)
self.assertEqual(len(scheduler.finished_req_ids), 0)
# EncoderCacheManager.
self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
# KVCache Manager.
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks), 0)
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block), 0)
self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
0)
self.assertEqual(len(scheduler.kv_cache_manager.req_to_block_hashes),
0)
num_free_blocks = (scheduler.kv_cache_manager.block_pool.
free_block_queue.num_free_blocks)
self.assertEqual(
num_free_blocks,
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
self.assertEqual(block.ref_cnt, 0)
def test_memory_leak(self):
"""Test that we do not have a memory leak."""
scheduler = self.create_scheduler()
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)