[Feature] model_runner refactor (#4764)
### What this PR does / why we need it?
Refactor npu_modelrunner so that its structure stays close to the upstream gpu_modelrunner.
### Does this PR introduce _any_ user-facing change?
NO
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
This commit is contained in:
@@ -114,6 +114,7 @@ class InputBatch:
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
logitsprocs_need_output_token_ids: bool = False,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
@@ -143,10 +144,11 @@ class InputBatch:
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
self.is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
self.is_token_ids = self.is_token_ids_tensor.numpy()
|
||||
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||
# allocation if max_model_len is big.
|
||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||
@@ -299,6 +301,11 @@ class InputBatch:
|
||||
# Store provided logitsprocs. If none are provided, initialize empty
|
||||
# data structure
|
||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
|
||||
|
||||
# Store last speculative tokens for sampler.
|
||||
self.spec_token_ids: list[list[int]] = [[]
|
||||
for _ in range(max_num_reqs)]
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
@@ -306,9 +313,14 @@ class InputBatch:
|
||||
self.pooling_params: dict[str, PoolingParams] = {}
|
||||
|
||||
# Cached reference to the GPU tensor of previously sampled tokens
|
||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
self.prev_sampled_token_ids: torch.Tensor | None = None
|
||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||
self.prev_req_id_to_index: dict[str, int] | None = None
|
||||
# These are used to update output_token_ids with real sampled
|
||||
# ids from prior step, if required by current sampling params
|
||||
# (e.g. penalties).
|
||||
self.sampled_token_ids_cpu: torch.Tensor | None = None
|
||||
self.async_copy_ready_event: torch.Event | None = None
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
@@ -350,9 +362,11 @@ class InputBatch:
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
self.spec_token_ids.append([])
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
self.spec_token_ids[req_index].clear()
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
@@ -496,6 +510,21 @@ class InputBatch:
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
self.spec_token_ids[req_index].clear()
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
||||
lora_req_ids.discard(req_id)
|
||||
if not lora_req_ids:
|
||||
del self.lora_id_to_request_ids[lora_id]
|
||||
del self.lora_id_to_lora_request[lora_id]
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
if self.is_pooling_model:
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
@@ -510,6 +539,8 @@ class InputBatch:
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
if self.prev_req_id_to_index is not None:
|
||||
self.prev_req_id_to_index.pop(req_id, None)
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
@@ -538,6 +569,10 @@ class InputBatch:
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
|
||||
self.spec_token_ids[i2],
|
||||
self.spec_token_ids[i1],
|
||||
)
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
@@ -629,6 +664,7 @@ class InputBatch:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
self.spec_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
@@ -662,6 +698,16 @@ class InputBatch:
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
if last_req_index != empty_index:
|
||||
(
|
||||
self.spec_token_ids[last_req_index],
|
||||
self.spec_token_ids[empty_index],
|
||||
) = (
|
||||
self.spec_token_ids[empty_index],
|
||||
self.spec_token_ids[last_req_index],
|
||||
)
|
||||
self.spec_token_ids[last_req_index].clear()
|
||||
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
@@ -714,6 +760,7 @@ class InputBatch:
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
del self.spec_token_ids[num_reqs:]
|
||||
|
||||
def refresh_metadata(self):
|
||||
"""Apply any batch updates to sampling metadata."""
|
||||
@@ -787,6 +834,7 @@ class InputBatch:
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
|
||||
no_penalties=self.no_penalties,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
@@ -848,6 +896,53 @@ class InputBatch:
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
def set_async_sampled_token_ids(
|
||||
self,
|
||||
sampled_token_ids_cpu: torch.Tensor,
|
||||
async_copy_ready_event: torch.Event,
|
||||
) -> None:
|
||||
"""
|
||||
In async scheduling case, store ref to sampled_token_ids_cpu
|
||||
tensor and corresponding copy-ready event. Used to repair
|
||||
output_token_ids prior to sampling, if needed by logits processors.
|
||||
"""
|
||||
if self.sampling_metadata.output_token_ids:
|
||||
self.sampled_token_ids_cpu = sampled_token_ids_cpu
|
||||
self.async_copy_ready_event = async_copy_ready_event
|
||||
else:
|
||||
self.sampled_token_ids_cpu = None
|
||||
self.async_copy_ready_event = None
|
||||
|
||||
def update_async_output_token_ids(self) -> None:
    """
    Patch placeholder entries in the sampling metadata's output_token_ids
    with the token ids sampled in the prior step (async scheduling case).

    For every current request that also appears in ``prev_req_id_to_index``
    and whose output list ends with the placeholder ``-1``, the last entry
    is overwritten with the value found at the request's previous index in
    ``sampled_token_ids_cpu``. The copy-ready event is synchronized at most
    once, and only when at least one request actually needs patching.
    """
    all_output_ids = self.sampling_metadata.output_token_ids
    if not (self.sampled_token_ids_cpu is not None and all_output_ids):
        # Output token ids not needed or not async scheduling.
        return

    prev_index_of = self.prev_req_id_to_index
    assert prev_index_of is not None

    # Materialized lazily so the event wait is skipped when nothing matches.
    synced_ids = None
    for cur_index, req_id in enumerate(self.req_ids):
        prev_index = prev_index_of.get(req_id)
        if prev_index is None:
            # Request was not part of the previous step.
            continue

        req_ids_out = all_output_ids[cur_index]
        ends_with_placeholder = bool(req_ids_out) and req_ids_out[-1] == -1
        if not ends_with_placeholder:
            # Final output id is not a placeholder, some tokens must have
            # been discarded after a kv-load failure.
            continue

        if synced_ids is None:
            ready_event = self.async_copy_ready_event
            assert ready_event is not None
            ready_event.synchronize()
            synced_ids = self.sampled_token_ids_cpu.squeeze(-1).tolist()

        # Replace placeholder token id with actual sampled id.
        req_ids_out[-1] = synced_ids[prev_index]
|
||||
|
||||
@property
def num_reqs(self) -> int:
    """Return how many entries ``req_id_to_index`` currently holds."""
    index_by_req_id = self.req_id_to_index
    return len(index_by_req_id)
|
||||
|
||||
Reference in New Issue
Block a user