init v0.11.0rc0
This commit is contained in:
@@ -6,25 +6,21 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
|
||||
MultiModalKwargsItem, PlaceholderRange)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.core.scheduler import AscendScheduler
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||
from vllm.v1.outputs import DraftTokenIds
|
||||
else:
|
||||
DraftTokenIds = None
|
||||
|
||||
EOS_TOKEN_ID = 50256
|
||||
MODEL = "Qwen3-0.6B"
|
||||
@@ -44,7 +40,7 @@ def create_requests(
|
||||
max_tokens: int = 16,
|
||||
stop_token_ids: Optional[list[int]] = None,
|
||||
block_size: int = 3,
|
||||
hash_fn=hash,
|
||||
hash_fn=sha256,
|
||||
):
|
||||
init_none_hash(hash_fn)
|
||||
prompt_logprobs = PROMPT_LOGPROBS
|
||||
@@ -54,25 +50,25 @@ def create_requests(
|
||||
prompt_logprobs=prompt_logprobs)
|
||||
requests = []
|
||||
for i in range(num_requests):
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
request = Request(request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
multi_modal_kwargs=None,
|
||||
multi_modal_placeholders=None,
|
||||
multi_modal_hashes=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
block_hasher=get_request_block_hasher(
|
||||
block_size, hash_fn))
|
||||
else:
|
||||
request = Request(request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
block_hasher=get_request_block_hasher(
|
||||
block_size, hash_fn))
|
||||
mm_features = []
|
||||
if mm_positions is not None:
|
||||
mm_position = mm_positions[i]
|
||||
for j, position in enumerate(mm_position):
|
||||
identifier = f"hash{i}_{j}"
|
||||
mm_feature = MultiModalFeatureSpec(
|
||||
data=MultiModalKwargsItem.dummy("dummy_m"),
|
||||
mm_position=position,
|
||||
identifier=identifier,
|
||||
modality="image")
|
||||
mm_features.append(mm_feature)
|
||||
request = Request(request_id=f"{i}",
|
||||
prompt_token_ids=[i] * num_tokens,
|
||||
sampling_params=sampling_params,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
pooling_params=None,
|
||||
mm_features=mm_features if mm_features else None,
|
||||
block_hasher=get_request_block_hasher(
|
||||
block_size, hash_fn))
|
||||
requests.append(request)
|
||||
return requests
|
||||
|
||||
@@ -85,25 +81,15 @@ def make_output(scheduler):
|
||||
}
|
||||
sampled_token_ids = [[1000]] * len(scheduler.running)
|
||||
logprobs = None
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
spec_token_ids=None,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
else:
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_id_to_index,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
logprobs=logprobs,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[],
|
||||
)
|
||||
return modelrunner_output
|
||||
|
||||
|
||||
@@ -113,7 +99,7 @@ class TestAscendScheduler(TestBase):
|
||||
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
|
||||
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
|
||||
def create_scheduler(self, mock_compute_encoder_budget):
|
||||
mock_compute_encoder_budget.return_value = [10, 20]
|
||||
mock_compute_encoder_budget.return_value = [100, 100]
|
||||
use_kv_connector = False
|
||||
block_size = 16
|
||||
|
||||
@@ -235,7 +221,7 @@ class TestAscendScheduler(TestBase):
|
||||
len(requests) - i - 1)
|
||||
|
||||
def test_schedule(self):
|
||||
'''Test scheduling.
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
'''
|
||||
scheduler = self.create_scheduler()
|
||||
@@ -260,6 +246,60 @@ class TestAscendScheduler(TestBase):
|
||||
for i, request in enumerate(requests):
|
||||
self.assertEqual(scheduler.running[i], request)
|
||||
|
||||
def test_schedule_multimodal_requests(self):
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||
mm_positions = [[PlaceholderRange(offset=i, length=10)]
|
||||
for i in range(10)]
|
||||
requests = create_requests(
|
||||
num_requests=10,
|
||||
mm_positions=mm_positions,
|
||||
)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
|
||||
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(output.finished_req_ids), 0)
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
assert num_tokens == len(requests[int(req_id)].prompt_token_ids)
|
||||
|
||||
# Verify all requests are scheduled.
|
||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||
self.assertEqual(num_tokens,
|
||||
len(requests[int(req_id)].prompt_token_ids))
|
||||
self.assertEqual(len(output.scheduled_encoder_inputs), len(requests))
|
||||
for req_id, encoder_input in output.scheduled_encoder_inputs.items():
|
||||
assert len(encoder_input) == 1
|
||||
|
||||
# Verify requests moved from waiting to running
|
||||
self.assertEqual(len(scheduler.waiting), 0)
|
||||
self.assertEqual(len(scheduler.running), len(requests))
|
||||
for i, request in enumerate(requests):
|
||||
self.assertEqual(scheduler.running[i], request)
|
||||
|
||||
def test_concurrent_partial_prefills_schedule(self):
|
||||
'''Test concurrent partial prefills scheduling.
|
||||
total requests = 10, every request has 10 token.
|
||||
while set long_prefill_token_threshold = 1, scheduler can
|
||||
only schedule max_long_partial_prefills long request.
|
||||
'''
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.scheduler_config.chunked_prefill_enabled = False
|
||||
scheduler.scheduler_config.max_long_partial_prefills = 2
|
||||
scheduler.scheduler_config.long_prefill_token_threshold = 1
|
||||
requests = create_requests(num_requests=10, num_tokens=20)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# Test initial scheduling
|
||||
output = scheduler.schedule()
|
||||
self.assertEqual(len(output.scheduled_new_reqs),
|
||||
scheduler.scheduler_config.max_long_partial_prefills)
|
||||
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(output.finished_req_ids), 0)
|
||||
|
||||
def test_schedule_enable_prefix_caching(self):
|
||||
'''Test scheduling.
|
||||
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
|
||||
@@ -304,69 +344,34 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [
|
||||
10, 11
|
||||
]], # First request hits EOS, second continues
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [
|
||||
10, 11
|
||||
]], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 1,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]
|
||||
], # First request hits EOS, second continues
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
@@ -391,67 +396,35 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 2
|
||||
},
|
||||
total_num_scheduled_tokens=5,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id:
|
||||
[10, 42],
|
||||
requests[1].request_id: [13]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 42, 12],
|
||||
[13, 14]], # First request hits stop token
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
@@ -475,67 +448,35 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.running.append(req)
|
||||
req.status = RequestStatus.RUNNING
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={
|
||||
requests[0].request_id: 3,
|
||||
requests[1].request_id: 1
|
||||
},
|
||||
total_num_scheduled_tokens=4,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id:
|
||||
[10, 11],
|
||||
requests[1].request_id: []
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[req.request_id for req in requests],
|
||||
req_id_to_index={
|
||||
req.request_id: i
|
||||
for i, req in enumerate(requests)
|
||||
},
|
||||
sampled_token_ids=[[10, 11, 12],
|
||||
[13]], # First request exceeds max_tokens
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
# Verify first request stopped due to length
|
||||
@@ -556,52 +497,27 @@ class TestAscendScheduler(TestBase):
|
||||
scheduler.requests[requests[0].request_id] = requests[0]
|
||||
scheduler.running.append(requests[0])
|
||||
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_input_ids=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
else:
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=[],
|
||||
scheduled_cached_reqs=[],
|
||||
num_scheduled_tokens={requests[0].request_id: 3},
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={
|
||||
requests[0].request_id: [EOS_TOKEN_ID, 10]
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
grammar_bitmask=None)
|
||||
model_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output, model_output)
|
||||
|
||||
@@ -652,23 +568,13 @@ class TestAscendScheduler(TestBase):
|
||||
512)
|
||||
|
||||
# Model output of the first request.
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[0].request_id],
|
||||
req_id_to_index={requests[0].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output0,
|
||||
model_runner_output)
|
||||
@@ -678,23 +584,13 @@ class TestAscendScheduler(TestBase):
|
||||
# request is still running.
|
||||
scheduler.schedule()
|
||||
# Model output of the second request.
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=[requests[1].request_id],
|
||||
req_id_to_index={requests[1].request_id: 0},
|
||||
sampled_token_ids=[[0]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
scheduler.update_from_output(scheduler_output1,
|
||||
model_runner_output)
|
||||
@@ -746,29 +642,19 @@ class TestAscendScheduler(TestBase):
|
||||
req_id = requests[i].request_id
|
||||
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
|
||||
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
spec_token_ids=spec_tokens,
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
output, model_runner_output)
|
||||
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
|
||||
scheduler.update_draft_token_ids(draft_token_ids)
|
||||
scheduler.update_draft_token_ids(draft_token_ids)
|
||||
|
||||
for i in range(len(requests)):
|
||||
running_req = scheduler.running[i]
|
||||
@@ -804,23 +690,14 @@ class TestAscendScheduler(TestBase):
|
||||
else:
|
||||
self.assertNotIn(req_id,
|
||||
output.scheduled_spec_decode_tokens)
|
||||
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
else:
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
model_runner_output = ModelRunnerOutput(
|
||||
req_ids=req_ids,
|
||||
req_id_to_index=req_to_index,
|
||||
sampled_token_ids=output_tokens,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[])
|
||||
|
||||
engine_core_outputs = scheduler.update_from_output(
|
||||
output, model_runner_output)
|
||||
@@ -896,3 +773,34 @@ class TestAscendScheduler(TestBase):
|
||||
|
||||
# Confirm no memory leak.
|
||||
self.assert_scheduler_empty(scheduler)
|
||||
|
||||
def test_scheduler_with_pd_transfer(self):
|
||||
scheduler = self.create_scheduler()
|
||||
scheduler.phase = "prefill"
|
||||
requests = create_requests(num_requests=32)
|
||||
for request in requests:
|
||||
scheduler.add_request(request)
|
||||
|
||||
# 1st iteration, move 16 requests from waiting to running for prefill
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
first_iter_prefilled_req_num = len(scheduler.running)
|
||||
self.assertEqual(len(scheduler_output.scheduled_new_reqs),
|
||||
scheduler.max_num_running_reqs)
|
||||
self.assertEqual(scheduler_output.scheduled_cached_reqs.num_reqs, 0)
|
||||
self.assertEqual(len(scheduler_output.finished_req_ids), 0)
|
||||
|
||||
# 2nd iteration, move 16 prefilled requests to finished_prefill_reqs
|
||||
# and move 16 requests from waiting to running for prefill
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
self.assertEqual(len(scheduler.finished_prefill_reqs),
|
||||
first_iter_prefilled_req_num)
|
||||
|
||||
# 3rd iteration, all requests prefilled, change scheduler phase to decode
|
||||
scheduler_output = scheduler.schedule()
|
||||
model_runner_output = make_output(scheduler)
|
||||
scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
self.assertEqual(scheduler.phase, "decode")
|
||||
|
||||
Reference in New Issue
Block a user