init v0.11.0rc0

2025-10-14 10:38:28 +08:00
parent 67afd0ea78
commit 66dc16f966
278 changed files with 28130 additions and 11708 deletions


@@ -6,25 +6,21 @@ from unittest.mock import MagicMock, patch
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem, PlaceholderRange)
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from tests.ut.base import TestBase
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
from vllm.v1.outputs import DraftTokenIds
else:
DraftTokenIds = None
EOS_TOKEN_ID = 50256
MODEL = "Qwen3-0.6B"
@@ -44,7 +40,7 @@ def create_requests(
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
block_size: int = 3,
hash_fn=hash,
hash_fn=sha256,
):
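    # init_none_hash seeds vLLM's NONE_HASH (the parent hash of a request's
    # first block); it must match the hash_fn passed to the block hasher.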
init_none_hash(hash_fn)
prompt_logprobs = PROMPT_LOGPROBS
@@ -54,25 +50,25 @@ def create_requests(
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
else:
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
mm_features = []
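        # Build one dummy image feature per placeholder position; the
        # identifier stands in for the multimodal item's content hash.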
if mm_positions is not None:
mm_position = mm_positions[i]
for j, position in enumerate(mm_position):
identifier = f"hash{i}_{j}"
mm_feature = MultiModalFeatureSpec(
data=MultiModalKwargsItem.dummy("dummy_m"),
mm_position=position,
identifier=identifier,
modality="image")
mm_features.append(mm_feature)
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
mm_features=mm_features if mm_features else None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
requests.append(request)
return requests
@@ -85,25 +81,15 @@ def make_output(scheduler):
}
sampled_token_ids = [[1000]] * len(scheduler.running)
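    # Pretend the model sampled a single token (id 1000) for every running
    # request; logprobs and pooler outputs are left empty.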
logprobs = None
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
else:
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
return modelrunner_output
@@ -113,7 +99,7 @@ class TestAscendScheduler(TestBase):
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget):
mock_compute_encoder_budget.return_value = [10, 20]
mock_compute_encoder_budget.return_value = [100, 100]
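        # compute_encoder_budget is mocked as [compute budget, cache size];
        # large values keep encoder inputs from being budget-limited here.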
use_kv_connector = False
block_size = 16
@@ -235,7 +221,7 @@ class TestAscendScheduler(TestBase):
len(requests) - i - 1)
def test_schedule(self):
        '''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = self.create_scheduler()
@@ -260,6 +246,60 @@ class TestAscendScheduler(TestBase):
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_multimodal_requests(self):
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
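        # With chunked prefill disabled, each prompt is scheduled in full.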
mm_positions = [[PlaceholderRange(offset=i, length=10)]
for i in range(10)]
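        # One 10-token image placeholder per request, at increasing offsets.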
requests = create_requests(
num_requests=10,
mm_positions=mm_positions,
)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
        # Verify all requests are scheduled for their full prompt length.
        for req_id, num_tokens in output.num_scheduled_tokens.items():
            self.assertEqual(num_tokens,
                             len(requests[int(req_id)].prompt_token_ids))
self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
            self.assertEqual(len(encoder_input), 1)
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_concurrent_partial_prefills_schedule(self):
        '''Test concurrent partial prefills scheduling.
        10 requests in total, each with a 20-token prompt.
        With long_prefill_token_threshold = 1, every prompt counts as a long
        prefill, so the scheduler can schedule at most
        max_long_partial_prefills long requests per step.
        '''
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
scheduler.scheduler_config.max_long_partial_prefills = 2
scheduler.scheduler_config.long_prefill_token_threshold = 1
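        # Every 20-token prompt exceeds the 1-token threshold, so at most
        # max_long_partial_prefills of them may prefill concurrently.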
requests = create_requests(num_requests=10, num_tokens=20)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs),
scheduler.scheduler_config.max_long_partial_prefills)
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
def test_schedule_enable_prefix_caching(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
@@ -304,69 +344,34 @@ class TestAscendScheduler(TestBase):
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [
10, 11
]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [
10, 11
]], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@@ -391,67 +396,35 @@ class TestAscendScheduler(TestBase):
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id:
[10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@@ -475,67 +448,35 @@ class TestAscendScheduler(TestBase):
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id:
[10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
@@ -556,52 +497,27 @@ class TestAscendScheduler(TestBase):
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
@@ -652,23 +568,13 @@ class TestAscendScheduler(TestBase):
512)
# Model output of the first request.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output0,
model_runner_output)
@@ -678,23 +584,13 @@ class TestAscendScheduler(TestBase):
# request is still running.
scheduler.schedule()
# Model output of the second request.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output1,
model_runner_output)
@@ -746,29 +642,19 @@ class TestAscendScheduler(TestBase):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
spec_token_ids=spec_tokens,
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
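        # Draft (speculative) tokens are reported to the scheduler separately
        # from sampled tokens, via update_draft_token_ids below.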
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
scheduler.update_draft_token_ids(draft_token_ids)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
@@ -804,23 +690,14 @@ class TestAscendScheduler(TestBase):
else:
self.assertNotIn(req_id,
output.scheduled_spec_decode_tokens)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
@@ -896,3 +773,34 @@ class TestAscendScheduler(TestBase):
# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)
def test_scheduler_with_pd_transfer(self):
scheduler = self.create_scheduler()
scheduler.phase = "prefill"
requests = create_requests(num_requests=32)
for request in requests:
scheduler.add_request(request)
# 1st iteration, move 16 requests from waiting to running for prefill
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
first_iter_prefilled_req_num = len(scheduler.running)
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
scheduler.max_num_running_reqs)
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(scheduler_output.finished_req_ids), 0)
# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
# and move 16 requests from waiting to running for prefill
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(len(scheduler.finished_prefill_reqs),
first_iter_prefilled_req_num)
        # 3rd iteration: all requests are prefilled, so the scheduler
        # switches its phase to decode
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
self.assertEqual(scheduler.phase, "decode")