[ModelRunner] Use shared CachedRequestData cross request to fix ci (#1546)
### What this PR does / why we need it?
This PR (adapted from
2863befce3)
updates the CachedRequestData definition to use a single instance shared
across all requests in a batch, instead of creating a new instance per
request.
Found ci boken by the vllm's model_runner change: `ERROR 07-01 09:53:53
[core.py:521] TypeError: 'CachedRequestData' object is not iterable`,
Modify the model_runner to fix it.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
pass ci will verify this.
---------
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
@@ -201,7 +201,10 @@ def test_schedule(enable_prefix_caching: Optional[bool],
|
|||||||
# Test initial scheduling
|
# Test initial scheduling
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(output.scheduled_new_reqs) == len(requests)
|
assert len(output.scheduled_new_reqs) == len(requests)
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
assert len(output.scheduled_cached_reqs) == 0
|
assert len(output.scheduled_cached_reqs) == 0
|
||||||
|
else:
|
||||||
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
||||||
assert len(output.finished_req_ids) == 0
|
assert len(output.finished_req_ids) == 0
|
||||||
# Verify all requests are scheduled.
|
# Verify all requests are scheduled.
|
||||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||||
@@ -238,7 +241,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
|
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(output.scheduled_new_reqs) == 3
|
assert len(output.scheduled_new_reqs) == 3
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
assert len(output.scheduled_cached_reqs) == 0
|
assert len(output.scheduled_cached_reqs) == 0
|
||||||
|
else:
|
||||||
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
||||||
assert len(output.finished_req_ids) == 0
|
assert len(output.finished_req_ids) == 0
|
||||||
|
|
||||||
# The first request is scheduled partially - 400.
|
# The first request is scheduled partially - 400.
|
||||||
@@ -268,7 +274,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
output1 = scheduler.schedule()
|
output1 = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 3
|
assert len(scheduler.running) == 3
|
||||||
assert len(output1.scheduled_new_reqs) == 0
|
assert len(output1.scheduled_new_reqs) == 0
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
assert len(output1.scheduled_cached_reqs) == 3
|
assert len(output1.scheduled_cached_reqs) == 3
|
||||||
|
else:
|
||||||
|
assert output1.scheduled_cached_reqs.num_reqs == 3
|
||||||
assert len(output1.finished_req_ids) == 0
|
assert len(output1.finished_req_ids) == 0
|
||||||
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
|
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
|
||||||
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
|
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
|
||||||
@@ -292,7 +301,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
output2 = scheduler.schedule()
|
output2 = scheduler.schedule()
|
||||||
assert len(scheduler.running) == 3
|
assert len(scheduler.running) == 3
|
||||||
assert len(output2.scheduled_new_reqs) == 0
|
assert len(output2.scheduled_new_reqs) == 0
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
assert len(output2.scheduled_cached_reqs) == 3
|
assert len(output2.scheduled_cached_reqs) == 3
|
||||||
|
else:
|
||||||
|
assert output2.scheduled_cached_reqs.num_reqs == 3
|
||||||
assert len(output2.finished_req_ids) == 0
|
assert len(output2.finished_req_ids) == 0
|
||||||
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
|
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
|
||||||
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
|
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
|
||||||
@@ -762,7 +774,6 @@ def assert_scheduler_empty(scheduler: AscendScheduler):
|
|||||||
assert len(scheduler.waiting) == 0
|
assert len(scheduler.waiting) == 0
|
||||||
assert len(scheduler.running) == 0
|
assert len(scheduler.running) == 0
|
||||||
assert len(scheduler.finished_req_ids) == 0
|
assert len(scheduler.finished_req_ids) == 0
|
||||||
assert len(scheduler._cached_reqs_data) == 0
|
|
||||||
|
|
||||||
# EncoderCacheManager.
|
# EncoderCacheManager.
|
||||||
assert len(scheduler.encoder_cache_manager.freed) == 0
|
assert len(scheduler.encoder_cache_manager.freed) == 0
|
||||||
|
|||||||
@@ -192,7 +192,10 @@ def test_schedule(enable_prefix_caching: Optional[bool],
|
|||||||
# Test initial scheduling
|
# Test initial scheduling
|
||||||
output = scheduler.schedule()
|
output = scheduler.schedule()
|
||||||
assert len(output.scheduled_new_reqs) == len(requests)
|
assert len(output.scheduled_new_reqs) == len(requests)
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
assert len(output.scheduled_cached_reqs) == 0
|
assert len(output.scheduled_cached_reqs) == 0
|
||||||
|
else:
|
||||||
|
assert output.scheduled_cached_reqs.num_reqs == 0
|
||||||
assert len(output.finished_req_ids) == 0
|
assert len(output.finished_req_ids) == 0
|
||||||
# Verify all requests are scheduled.
|
# Verify all requests are scheduled.
|
||||||
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
for req_id, num_tokens in output.num_scheduled_tokens.items():
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ from vllm.v1.outputs import ModelRunnerOutput
|
|||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
|
|
||||||
class AscendScheduler(Scheduler):
|
class AscendScheduler(Scheduler):
|
||||||
"""This Scheduler extends vllm's original v1 scheduler
|
"""This Scheduler extends vllm's original v1 scheduler
|
||||||
@@ -364,6 +366,7 @@ class AscendScheduler(Scheduler):
|
|||||||
req_to_new_block_ids[req.request_id])
|
req_to_new_block_ids[req.request_id])
|
||||||
for req in scheduled_new_reqs
|
for req in scheduled_new_reqs
|
||||||
]
|
]
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
resumed_reqs_data = [
|
resumed_reqs_data = [
|
||||||
self._make_cached_request_data(
|
self._make_cached_request_data(
|
||||||
req,
|
req,
|
||||||
@@ -382,9 +385,17 @@ class AscendScheduler(Scheduler):
|
|||||||
resumed_from_preemption=False,
|
resumed_from_preemption=False,
|
||||||
) for req in scheduled_running_reqs
|
) for req in scheduled_running_reqs
|
||||||
]
|
]
|
||||||
|
scheduled_cached_reqs = resumed_reqs_data + running_reqs_data
|
||||||
|
else:
|
||||||
|
cached_reqs_data = self._make_cached_request_data(
|
||||||
|
scheduled_running_reqs, scheduled_resumed_reqs,
|
||||||
|
num_scheduled_tokens, scheduled_spec_decode_tokens,
|
||||||
|
req_to_new_block_ids)
|
||||||
|
scheduled_cached_reqs = cached_reqs_data
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(
|
scheduler_output = SchedulerOutput(
|
||||||
scheduled_new_reqs=new_reqs_data,
|
scheduled_new_reqs=new_reqs_data,
|
||||||
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
|
scheduled_cached_reqs=scheduled_cached_reqs,
|
||||||
num_scheduled_tokens=num_scheduled_tokens,
|
num_scheduled_tokens=num_scheduled_tokens,
|
||||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||||
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,
|
||||||
|
|||||||
@@ -456,6 +456,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
req_ids_to_add.append(req_id)
|
req_ids_to_add.append(req_id)
|
||||||
|
|
||||||
# Update the states of the running/resumed requests.
|
# Update the states of the running/resumed requests.
|
||||||
|
if vllm_version_is("0.9.1"):
|
||||||
for req_data in scheduler_output.scheduled_cached_reqs:
|
for req_data in scheduler_output.scheduled_cached_reqs:
|
||||||
req_id = req_data.req_id
|
req_id = req_data.req_id
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
@@ -470,7 +471,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
req_state.num_tokens)
|
req_state.num_tokens)
|
||||||
if num_new_tokens == 1:
|
if num_new_tokens == 1:
|
||||||
# Avoid slicing list in most common case.
|
# Avoid slicing list in most common case.
|
||||||
req_state.output_token_ids.append(req_data.new_token_ids[-1])
|
req_state.output_token_ids.append(
|
||||||
|
req_data.new_token_ids[-1])
|
||||||
elif num_new_tokens > 0:
|
elif num_new_tokens > 0:
|
||||||
req_state.output_token_ids.extend(
|
req_state.output_token_ids.extend(
|
||||||
req_data.new_token_ids[-num_new_tokens:])
|
req_data.new_token_ids[-num_new_tokens:])
|
||||||
@@ -501,15 +503,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
start_index = (len(req_state.block_ids) -
|
start_index = (len(req_state.block_ids) -
|
||||||
len(req_data.new_block_ids))
|
len(req_data.new_block_ids))
|
||||||
self.input_batch.block_table.append_row(req_data.new_block_ids,
|
self.input_batch.block_table.append_row(
|
||||||
req_index)
|
req_data.new_block_ids, req_index)
|
||||||
# Add new_token_ids to token_ids_cpu.
|
# Add new_token_ids to token_ids_cpu.
|
||||||
start_token_index = num_computed_tokens
|
start_token_index = num_computed_tokens
|
||||||
end_token_index = num_computed_tokens + len(req_data.new_token_ids)
|
end_token_index = num_computed_tokens + len(
|
||||||
|
req_data.new_token_ids)
|
||||||
self.input_batch.token_ids_cpu[
|
self.input_batch.token_ids_cpu[
|
||||||
req_index,
|
req_index,
|
||||||
start_token_index:end_token_index] = req_data.new_token_ids
|
start_token_index:end_token_index] = req_data.new_token_ids
|
||||||
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
|
self.input_batch.num_tokens_no_spec[
|
||||||
|
req_index] = end_token_index
|
||||||
# Add spec_token_ids to token_ids_cpu.
|
# Add spec_token_ids to token_ids_cpu.
|
||||||
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||||
req_id, ())
|
req_id, ())
|
||||||
@@ -517,7 +521,72 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
start_index = end_token_index
|
start_index = end_token_index
|
||||||
end_token_index += len(spec_token_ids)
|
end_token_index += len(spec_token_ids)
|
||||||
self.input_batch.token_ids_cpu[
|
self.input_batch.token_ids_cpu[
|
||||||
req_index, start_index:end_token_index] = spec_token_ids
|
req_index,
|
||||||
|
start_index:end_token_index] = spec_token_ids
|
||||||
|
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
|
||||||
|
self.input_batch.num_tokens[req_index] = end_token_index
|
||||||
|
else:
|
||||||
|
req_data = scheduler_output.scheduled_cached_reqs
|
||||||
|
for i, req_id in enumerate(req_data.req_ids):
|
||||||
|
req_state = self.requests[req_id]
|
||||||
|
num_computed_tokens = req_data.num_computed_tokens[i]
|
||||||
|
new_token_ids = req_data.new_token_ids[i]
|
||||||
|
new_block_ids = req_data.new_block_ids[i]
|
||||||
|
resumed_from_preemption = req_data.resumed_from_preemption[i]
|
||||||
|
|
||||||
|
req_state.num_computed_tokens = num_computed_tokens
|
||||||
|
# Add the sampled token(s) from the previous step (if any).
|
||||||
|
# This doesn't include "unverified" tokens like spec decode tokens.
|
||||||
|
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
|
||||||
|
req_state.num_tokens)
|
||||||
|
if num_new_tokens == 1:
|
||||||
|
# Avoid slicing list in most common case.
|
||||||
|
req_state.output_token_ids.append(new_token_ids[-1])
|
||||||
|
elif num_new_tokens > 0:
|
||||||
|
req_state.output_token_ids.extend(
|
||||||
|
new_token_ids[-num_new_tokens:])
|
||||||
|
# Update the block IDs.
|
||||||
|
if not resumed_from_preemption:
|
||||||
|
# Append the new blocks to the existing block IDs.
|
||||||
|
for block_ids, new_ids in zip( # type: ignore[call-overload]
|
||||||
|
req_state.block_ids, new_block_ids):
|
||||||
|
block_ids.extend(new_ids)
|
||||||
|
else:
|
||||||
|
# The request is resumed from preemption.
|
||||||
|
# Replace the existing block IDs with the new ones.
|
||||||
|
req_state.block_ids = new_block_ids
|
||||||
|
|
||||||
|
req_index = self.input_batch.req_id_to_index.get(req_id)
|
||||||
|
if req_index is None:
|
||||||
|
# The request is not in the persistent batch.
|
||||||
|
# The request was either preempted and resumed later, or was not
|
||||||
|
# scheduled in the previous step and needs to be added again.
|
||||||
|
req_ids_to_add.append(req_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Update the persistent batch.
|
||||||
|
self.input_batch.num_computed_tokens_cpu[req_index] = (
|
||||||
|
num_computed_tokens)
|
||||||
|
|
||||||
|
self.input_batch.block_table.append_row(
|
||||||
|
new_block_ids, req_index)
|
||||||
|
# Add new_token_ids to token_ids_cpu.
|
||||||
|
start_token_index = num_computed_tokens
|
||||||
|
end_token_index = num_computed_tokens + len(new_token_ids)
|
||||||
|
self.input_batch.token_ids_cpu[
|
||||||
|
req_index,
|
||||||
|
start_token_index:end_token_index] = new_token_ids
|
||||||
|
self.input_batch.num_tokens_no_spec[
|
||||||
|
req_index] = end_token_index
|
||||||
|
# Add spec_token_ids to token_ids_cpu.
|
||||||
|
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
|
||||||
|
req_id, ())
|
||||||
|
if spec_token_ids:
|
||||||
|
start_index = end_token_index
|
||||||
|
end_token_index += len(spec_token_ids)
|
||||||
|
self.input_batch.token_ids_cpu[
|
||||||
|
req_index,
|
||||||
|
start_index:end_token_index] = spec_token_ids
|
||||||
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
|
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
|
||||||
self.input_batch.num_tokens[req_index] = end_token_index
|
self.input_batch.num_tokens[req_index] = end_token_index
|
||||||
|
|
||||||
@@ -527,7 +596,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
# Add the new or resumed requests to the persistent batch.
|
# Add the new or resumed requests to the persistent batch.
|
||||||
# The smaller empty indices are filled first.
|
# The smaller empty indices are filled first.
|
||||||
removed_req_indices = sorted(removed_req_indices, reverse=True)
|
removed_req_indices.sort(reverse=True)
|
||||||
for req_id in req_ids_to_add:
|
for req_id in req_ids_to_add:
|
||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
if removed_req_indices:
|
if removed_req_indices:
|
||||||
|
|||||||
Reference in New Issue
Block a user