[Feature] model_runner refactor (#4764)

### What this PR does / why we need it?
Refactor `npu_modelrunner` so that it stays structurally close to vLLM's `gpu_modelrunner`.

### Does this PR introduce _any_ user-facing change?
No.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
Authored by zhenwenqi2024 on 2025-12-12 17:27:09 +08:00; committed by GitHub.
Commit f708d919f8 (parent 5b12c068f9).
10 changed files with 676 additions and 1815 deletions.


@@ -114,6 +114,7 @@ class InputBatch:
         vocab_size: int,
         block_sizes: list[int],  # The block_size of each kv cache group
         logitsprocs: Optional[LogitsProcessors] = None,
+        logitsprocs_need_output_token_ids: bool = False,
         is_spec_decode: bool = False,
         is_pooling_model: bool = False,
         num_speculative_tokens: int = 0,
@@ -143,10 +144,11 @@ class InputBatch:
             pin_memory=False,
         )
         self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
-        self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
-                                        device="cpu",
-                                        dtype=bool,
-                                        pin_memory=False)
+        self.is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len),
+                                               device="cpu",
+                                               dtype=bool,
+                                               pin_memory=False)
+        self.is_token_ids = self.is_token_ids_tensor.numpy()
         # Store prompt embeddings per request to avoid OOM from large upfront
         # allocation if max_model_len is big.
         # Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
@@ -299,6 +301,11 @@ class InputBatch:
         # Store provided logitsprocs. If none are provided, initialize empty
         # data structure
         self.logitsprocs = logitsprocs or LogitsProcessors()
+        self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
+
+        # Store last speculative tokens for sampler.
+        self.spec_token_ids: list[list[int]] = [[]
+                                                for _ in range(max_num_reqs)]

         # This is updated each time the batch constituents change.
         self.sampling_metadata = self._make_sampling_metadata()
@@ -306,9 +313,14 @@ class InputBatch:
         self.pooling_params: dict[str, PoolingParams] = {}

         # Cached reference to the GPU tensor of previously sampled tokens
-        self.prev_sampled_token_ids: Optional[torch.Tensor] = None
+        self.prev_sampled_token_ids: torch.Tensor | None = None
         self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
-        self.prev_req_id_to_index: Optional[dict[str, int]] = None
+        self.prev_req_id_to_index: dict[str, int] | None = None
+        # These are used to update output_token_ids with the real sampled
+        # ids from the prior step, if required by the current sampling params
+        # (e.g. penalties).
+        self.sampled_token_ids_cpu: torch.Tensor | None = None
+        self.async_copy_ready_event: torch.Event | None = None

     @property
     def req_ids(self) -> list[str]:
@@ -350,9 +362,11 @@ class InputBatch:
         if req_index == len(self._req_ids):
             self._req_ids.append(req_id)
             self.req_output_token_ids.append(request.output_token_ids)
+            self.spec_token_ids.append([])
         else:
             self._req_ids[req_index] = req_id
             self.req_output_token_ids[req_index] = request.output_token_ids
+            self.spec_token_ids[req_index].clear()

         self.req_id_to_index[req_id] = req_index
@@ -496,6 +510,21 @@ class InputBatch:
         self.batch_update_builder.removed_append(req_index)
         self._req_ids[req_index] = None
         self.req_output_token_ids[req_index] = None
+        self.spec_token_ids[req_index].clear()
+
+        # LoRA
+        lora_id = self.request_lora_mapping[req_index]
+        if lora_id != 0:
+            lora_req_ids = self.lora_id_to_request_ids[lora_id]
+            lora_req_ids.discard(req_id)
+            if not lora_req_ids:
+                del self.lora_id_to_request_ids[lora_id]
+                del self.lora_id_to_lora_request[lora_id]
+            self.request_lora_mapping[req_index] = 0
+
+        if self.is_pooling_model:
+            self.pooling_params.pop(req_id, None)
+            return req_index

         self.greedy_reqs.discard(req_id)
         self.random_reqs.discard(req_id)
@@ -510,6 +539,8 @@ class InputBatch:
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
+        if self.prev_req_id_to_index is not None:
+            self.prev_req_id_to_index.pop(req_id, None)

         # LoRA
         lora_id = self.request_lora_mapping[req_index]
         if lora_id != 0:
@@ -538,6 +569,10 @@ class InputBatch:
             self._req_ids[i2], self._req_ids[i1]  # noqa
         self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
             self.req_output_token_ids[i2], self.req_output_token_ids[i1]
+        self.spec_token_ids[i1], self.spec_token_ids[i2] = (
+            self.spec_token_ids[i2],
+            self.spec_token_ids[i1],
+        )
         assert old_id_i1 is not None and old_id_i2 is not None
         self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
             self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
@@ -629,6 +664,7 @@ class InputBatch:
             # The batched states are empty.
             self._req_ids.clear()
             self.req_output_token_ids.clear()
+            self.spec_token_ids.clear()
             return

         # NOTE(woosuk): This function assumes that the empty_req_indices
@@ -662,6 +698,16 @@ class InputBatch:
             self.req_output_token_ids[last_req_index] = None
             self.req_id_to_index[req_id] = empty_index

+            if last_req_index != empty_index:
+                (
+                    self.spec_token_ids[last_req_index],
+                    self.spec_token_ids[empty_index],
+                ) = (
+                    self.spec_token_ids[empty_index],
+                    self.spec_token_ids[last_req_index],
+                )
+                self.spec_token_ids[last_req_index].clear()
+
             num_tokens = self.num_tokens[last_req_index]
             self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
                 last_req_index, :num_tokens]
@@ -714,6 +760,7 @@ class InputBatch:
         # Trim lists to the batch size.
         del self._req_ids[num_reqs:]
         del self.req_output_token_ids[num_reqs:]
+        del self.spec_token_ids[num_reqs:]

     def refresh_metadata(self):
         """Apply any batch updates to sampling metadata."""
@@ -787,6 +834,7 @@ class InputBatch:
             presence_penalties=self.presence_penalties[:num_reqs],
             repetition_penalties=self.repetition_penalties[:num_reqs],
             output_token_ids=cast(list[list[int]], self.req_output_token_ids),
+            spec_token_ids=cast(list[list[int]], self.spec_token_ids),
             no_penalties=self.no_penalties,
             allowed_token_ids_mask=allowed_token_ids_mask,
             bad_words_token_ids=self.bad_words_token_ids,
@@ -848,6 +896,53 @@ class InputBatch:
         return prompt_lora_mapping, token_lora_mapping, active_lora_requests

+    def set_async_sampled_token_ids(
+        self,
+        sampled_token_ids_cpu: torch.Tensor,
+        async_copy_ready_event: torch.Event,
+    ) -> None:
+        """
+        In the async scheduling case, store a ref to the sampled_token_ids_cpu
+        tensor and the corresponding copy-ready event. Used to repair
+        output_token_ids prior to sampling, if needed by logits processors.
+        """
+        if self.sampling_metadata.output_token_ids:
+            self.sampled_token_ids_cpu = sampled_token_ids_cpu
+            self.async_copy_ready_event = async_copy_ready_event
+        else:
+            self.sampled_token_ids_cpu = None
+            self.async_copy_ready_event = None
+
+    def update_async_output_token_ids(self) -> None:
+        """
+        In the async scheduling case, update output_token_ids in the sampling
+        metadata with the prior step's sampled token ids once they have
+        finished copying to the CPU. This is called right before they are
+        needed by the logits processors.
+        """
+        output_token_ids = self.sampling_metadata.output_token_ids
+        if self.sampled_token_ids_cpu is None or not output_token_ids:
+            # Output token ids are not needed, or not async scheduling.
+            return
+        assert self.prev_req_id_to_index is not None
+        sampled_token_ids = None
+        for index, req_id in enumerate(self.req_ids):
+            prev_index = self.prev_req_id_to_index.get(req_id)
+            if prev_index is None:
+                continue
+            req_output_token_ids = output_token_ids[index]
+            if not req_output_token_ids or req_output_token_ids[-1] != -1:
+                # The final output id is not a placeholder; some tokens must
+                # have been discarded after a kv-load failure.
+                continue
+            if sampled_token_ids is None:
+                assert self.async_copy_ready_event is not None
+                self.async_copy_ready_event.synchronize()
+                sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
+                    -1).tolist()
+            # Replace the placeholder token id with the actual sampled id.
+            req_output_token_ids[-1] = sampled_token_ids[prev_index]
+
     @property
     def num_reqs(self) -> int:
         return len(self.req_id_to_index)
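
For context on the `is_token_ids` / `is_token_ids_tensor` rename above: `InputBatch` keeps a CPU torch tensor alongside a NumPy view of the same storage (mirroring the existing `token_ids_cpu_tensor` / `token_ids_cpu` pair), so cheap NumPy-side writes are immediately visible to torch ops with no copies. A minimal standalone sketch of the pattern:

```python
import torch

max_num_reqs, max_model_len = 4, 8

# A CPU tensor plus a NumPy view of the same memory: writes through either
# side are visible to the other, with zero copies.
is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len),
                                  device="cpu",
                                  dtype=bool)
is_token_ids = is_token_ids_tensor.numpy()  # a view, not a copy

is_token_ids[0, :3] = True              # fast NumPy-side write
assert bool(is_token_ids_tensor[0, 2])  # visible on the torch side
```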
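The new `spec_token_ids` list is threaded through `add_request`, `remove_request`, `swap_states`, `condense`, and into `SamplingMetadata`. One plausible reading (an assumption; the diff only shows the plumbing, not the sampler) is that penalty-style logits processors must also see in-flight draft tokens that are not yet in `output_token_ids`:

```python
# Hypothetical helper, not from the diff: the tokens a repetition-style
# penalty should consider for one request under speculative decoding.
def tokens_to_penalize(output_token_ids: list[int],
                       spec_token_ids: list[int]) -> list[int]:
    # Accepted output tokens plus draft tokens still being verified.
    return output_token_ids + spec_token_ids

assert tokens_to_penalize([5, 9], [7, 7]) == [5, 9, 7, 7]
```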
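The two new methods at the bottom implement placeholder repair for async scheduling: step N is built before step N-1's sampled ids finish copying to the CPU, so each request's `output_token_ids` temporarily ends in a `-1` placeholder that gets patched just before the logits processors run. A standalone sketch of that flow (illustrative request ids and shapes; the real code waits on `async_copy_ready_event` rather than using a blocking copy):

```python
import torch

# Step N-1: the sampler produced one token per request on the device; the
# runner copies them to the CPU and records the request -> row mapping.
device_sampled = torch.tensor([[17], [42]])       # shape (num_reqs, 1)
sampled_token_ids_cpu = device_sampled.to("cpu")  # async copy in real code
prev_req_id_to_index = {"req-a": 0, "req-b": 1}

# Step N was scheduled before the copy finished, so each request's last
# output token is a -1 placeholder.
req_ids = ["req-a", "req-b"]
output_token_ids = [[5, -1], [9, -1]]

# Just before the logits processors need the real ids (after the copy-ready
# event fires), patch the placeholders, as update_async_output_token_ids()
# does in the diff above.
sampled = sampled_token_ids_cpu.squeeze(-1).tolist()
for index, req_id in enumerate(req_ids):
    prev_index = prev_req_id_to_index.get(req_id)
    toks = output_token_ids[index]
    if prev_index is not None and toks and toks[-1] == -1:
        toks[-1] = sampled[prev_index]

print(output_token_ids)  # [[5, 17], [9, 42]]
```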