[ModelRunner] Use shared CachedRequestData cross request to fix ci (#1546)

### What this PR does / why we need it?

This PR (adapted from
2863befce3)
updates the CachedRequestData definition to use a single instance shared
across all requests in a batch, instead of creating a new instance per
request.

Found ci boken by the vllm's model_runner change: `ERROR 07-01 09:53:53
[core.py:521] TypeError: 'CachedRequestData' object is not iterable`,
Modify the model_runner to fix it.


### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
pass ci will verify this.

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: Yikun Jiang <yikunkero@gmail.com>
This commit is contained in:
Pleaplusone
2025-07-02 06:05:21 +08:00
committed by GitHub
parent 6db7dc2c85
commit 0e43813120
4 changed files with 179 additions and 85 deletions

View File

@@ -201,7 +201,10 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
if vllm_version_is("0.9.1"):
assert len(output.scheduled_cached_reqs) == 0
else:
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
@@ -238,7 +241,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 3
assert len(output.scheduled_cached_reqs) == 0
if vllm_version_is("0.9.1"):
assert len(output.scheduled_cached_reqs) == 0
else:
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# The first request is scheduled partially - 400.
@@ -268,7 +274,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output1 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output1.scheduled_new_reqs) == 0
assert len(output1.scheduled_cached_reqs) == 3
if vllm_version_is("0.9.1"):
assert len(output1.scheduled_cached_reqs) == 3
else:
assert output1.scheduled_cached_reqs.num_reqs == 3
assert len(output1.finished_req_ids) == 0
assert output1.num_scheduled_tokens[requests[0].request_id] == 400
assert output1.num_scheduled_tokens[requests[1].request_id] == 400
@@ -292,7 +301,10 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
output2 = scheduler.schedule()
assert len(scheduler.running) == 3
assert len(output2.scheduled_new_reqs) == 0
assert len(output2.scheduled_cached_reqs) == 3
if vllm_version_is("0.9.1"):
assert len(output2.scheduled_cached_reqs) == 3
else:
assert output2.scheduled_cached_reqs.num_reqs == 3
assert len(output2.finished_req_ids) == 0
assert output2.num_scheduled_tokens[requests[0].request_id] == 1
assert output2.num_scheduled_tokens[requests[1].request_id] == 1
@@ -762,7 +774,6 @@ def assert_scheduler_empty(scheduler: AscendScheduler):
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0

View File

@@ -192,7 +192,10 @@ def test_schedule(enable_prefix_caching: Optional[bool],
# Test initial scheduling
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert len(output.scheduled_cached_reqs) == 0
if vllm_version_is("0.9.1"):
assert len(output.scheduled_cached_reqs) == 0
else:
assert output.scheduled_cached_reqs.num_reqs == 0
assert len(output.finished_req_ids) == 0
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():

View File

@@ -32,6 +32,8 @@ from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.utils import vllm_version_is
class AscendScheduler(Scheduler):
"""This Scheduler extends vllm's original v1 scheduler
@@ -364,27 +366,36 @@ class AscendScheduler(Scheduler):
req_to_new_block_ids[req.request_id])
for req in scheduled_new_reqs
]
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
if vllm_version_is("0.9.1"):
resumed_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=True,
) for req in scheduled_resumed_reqs
]
running_reqs_data = [
self._make_cached_request_data(
req,
num_scheduled_tokens[req.request_id],
len(scheduled_spec_decode_tokens.get(req.request_id, ())),
req_to_new_block_ids[req.request_id],
resumed_from_preemption=False,
) for req in scheduled_running_reqs
]
scheduled_cached_reqs = resumed_reqs_data + running_reqs_data
else:
cached_reqs_data = self._make_cached_request_data(
scheduled_running_reqs, scheduled_resumed_reqs,
num_scheduled_tokens, scheduled_spec_decode_tokens,
req_to_new_block_ids)
scheduled_cached_reqs = cached_reqs_data
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
scheduled_cached_reqs=resumed_reqs_data + running_reqs_data,
scheduled_cached_reqs=scheduled_cached_reqs,
num_scheduled_tokens=num_scheduled_tokens,
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens=scheduled_spec_decode_tokens,

View File

@@ -456,70 +456,139 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids_to_add.append(req_id)
# Update the states of the running/resumed requests.
for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
if vllm_version_is("0.9.1"):
for req_data in scheduler_output.scheduled_cached_reqs:
req_id = req_data.req_id
req_state = self.requests[req_id]
# Update the cached states.
num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens +
len(req_data.new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(req_data.new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
req_data.new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip( # type: ignore[call-overload]
req_state.block_ids,
req_data.new_block_ids,
strict=True):
block_ids.extend(new_block_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids
# Update the cached states.
num_computed_tokens = req_data.num_computed_tokens
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens +
len(req_data.new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(
req_data.new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
req_data.new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not req_data.resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_block_ids in zip( # type: ignore[call-overload]
req_state.block_ids,
req_data.new_block_ids,
strict=True):
block_ids.extend(new_block_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = req_data.new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
continue
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
continue
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
start_index = (len(req_state.block_ids) -
len(req_data.new_block_ids))
self.input_batch.block_table.append_row(req_data.new_block_ids,
req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(req_data.new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
self.input_batch.num_tokens_no_spec[req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
start_index = (len(req_state.block_ids) -
len(req_data.new_block_ids))
self.input_batch.block_table.append_row(
req_data.new_block_ids, req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(
req_data.new_token_ids)
self.input_batch.token_ids_cpu[
req_index, start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
req_index,
start_token_index:end_token_index] = req_data.new_token_ids
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
else:
req_data = scheduler_output.scheduled_cached_reqs
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
new_token_ids = req_data.new_token_ids[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
req_state.num_computed_tokens = num_computed_tokens
# Add the sampled token(s) from the previous step (if any).
# This doesn't include "unverified" tokens like spec decode tokens.
num_new_tokens = (num_computed_tokens + len(new_token_ids) -
req_state.num_tokens)
if num_new_tokens == 1:
# Avoid slicing list in most common case.
req_state.output_token_ids.append(new_token_ids[-1])
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
# Update the block IDs.
if not resumed_from_preemption:
# Append the new blocks to the existing block IDs.
for block_ids, new_ids in zip( # type: ignore[call-overload]
req_state.block_ids, new_block_ids):
block_ids.extend(new_ids)
else:
# The request is resumed from preemption.
# Replace the existing block IDs with the new ones.
req_state.block_ids = new_block_ids
req_index = self.input_batch.req_id_to_index.get(req_id)
if req_index is None:
# The request is not in the persistent batch.
# The request was either preempted and resumed later, or was not
# scheduled in the previous step and needs to be added again.
req_ids_to_add.append(req_id)
continue
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.block_table.append_row(
new_block_ids, req_index)
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
# Add spec_token_ids to token_ids_cpu.
spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
req_id, ())
if spec_token_ids:
start_index = end_token_index
end_token_index += len(spec_token_ids)
self.input_batch.token_ids_cpu[
req_index,
start_index:end_token_index] = spec_token_ids
# NOTE(woosuk): `num_tokens` here may include spec decode tokens.
self.input_batch.num_tokens[req_index] = end_token_index
# Check if the batch has changed. If not, we can skip copying the
# sampling metadata from CPU to GPU.
@@ -527,7 +596,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Add the new or resumed requests to the persistent batch.
# The smaller empty indices are filled first.
removed_req_indices = sorted(removed_req_indices, reverse=True)
removed_req_indices.sort(reverse=True)
for req_id in req_ids_to_add:
req_state = self.requests[req_id]
if removed_req_indices: