[Feature] model_runner refactor (#4764)
### What this PR does / why we need it?
refactor npu_modelrunner, we should be close to gpu_modelrunner
### Does this PR introduce _any_ user-facing change?
NO
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Signed-off-by: zhenwenqi2024 <155598497+zhenwenqi2024@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import numpy as np
|
||||
import torch
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
|
||||
|
||||
class BlockTable:
|
||||
@@ -76,32 +77,14 @@ class BlockTable:
|
||||
duplicate_size = 1
|
||||
if self.pcp_world_size > 1:
|
||||
duplicate_size += num_speculative_tokens
|
||||
self.block_table = torch.zeros(
|
||||
(max_num_reqs * duplicate_size, logical_table_size),
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
self.block_table_cpu = torch.zeros(
|
||||
(max_num_reqs * duplicate_size, logical_table_size),
|
||||
device="cpu",
|
||||
dtype=torch.int32,
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
self.block_table_np = self.block_table_cpu.numpy()
|
||||
self.block_table = self._make_buffer(max_num_reqs * duplicate_size,
|
||||
logical_table_size,
|
||||
dtype=torch.int32)
|
||||
self.num_blocks_per_row = np.zeros(max_num_reqs, dtype=np.int32)
|
||||
|
||||
self.slot_mapping_cpu = torch.zeros(
|
||||
self.slot_mapping = self._make_buffer(
|
||||
self.max_num_batched_tokens +
|
||||
2 * self.pcp_world_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
pin_memory=self.pin_memory)
|
||||
self.slot_mapping_np = self.slot_mapping_cpu.numpy()
|
||||
self.slot_mapping = torch.zeros(
|
||||
self.max_num_batched_tokens +
|
||||
2 * self.pcp_world_size * self.max_num_reqs,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
dtype=torch.int32)
|
||||
|
||||
self.kernel_sizes = kernel_sizes
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
@@ -120,7 +103,7 @@ class BlockTable:
|
||||
num_blocks = len(block_ids)
|
||||
start = self.num_blocks_per_row[row_idx]
|
||||
|
||||
self.block_table_np[row_idx, start:start + num_blocks] = block_ids
|
||||
self.block_table.np[row_idx, start:start + num_blocks] = block_ids
|
||||
self.num_blocks_per_row[row_idx] += num_blocks
|
||||
|
||||
def add_row(self, block_ids: list[int], row_idx: int) -> None:
|
||||
@@ -129,7 +112,7 @@ class BlockTable:
|
||||
|
||||
def move_row(self, src: int, tgt: int) -> None:
|
||||
num_blocks = self.num_blocks_per_row[src]
|
||||
self.block_table_np[tgt, :num_blocks] = self.block_table_np[
|
||||
self.block_table.np[tgt, :num_blocks] = self.block_table.np[
|
||||
src, :num_blocks]
|
||||
self.num_blocks_per_row[tgt] = num_blocks
|
||||
|
||||
@@ -139,7 +122,7 @@ class BlockTable:
|
||||
self.num_blocks_per_row[src] = num_blocks_tgt
|
||||
self.num_blocks_per_row[tgt] = num_blocks_src
|
||||
|
||||
self.block_table_np[[src, tgt]] = self.block_table_np[[tgt, src]]
|
||||
self.block_table.np[[src, tgt]] = self.block_table.np[[tgt, src]]
|
||||
|
||||
def compute_slot_mapping(self, req_indices: np.ndarray,
|
||||
positions: np.ndarray) -> None:
|
||||
@@ -171,7 +154,7 @@ class BlockTable:
|
||||
self.blocks_per_phys_block +
|
||||
logical_block_idx)
|
||||
|
||||
block_numbers = self.block_table_np.ravel()[block_table_indices]
|
||||
block_numbers = self.block_table.np.ravel()[block_table_indices]
|
||||
# Use virtual_block_size for mask calculation, which marks local
|
||||
# tokens.
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
@@ -186,7 +169,7 @@ class BlockTable:
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
# Write final slots, use -1 for not-local
|
||||
self.slot_mapping_np[:req_indices.shape[0]] = np.where(
|
||||
self.slot_mapping.np[:req_indices.shape[0]] = np.where(
|
||||
mask, slot_mapping, -1)
|
||||
else:
|
||||
assert self.kernel_sizes is not None
|
||||
@@ -203,24 +186,22 @@ class BlockTable:
|
||||
req_indices * self.max_num_blocks_per_req *
|
||||
self.blocks_per_phys_block + logical_block_idx)
|
||||
|
||||
block_numbers = self.block_table_np.ravel(
|
||||
block_numbers = self.block_table.np.ravel(
|
||||
)[block_table_indices]
|
||||
block_offsets = positions % self.block_size
|
||||
np.add(block_numbers * self.block_size,
|
||||
block_offsets,
|
||||
out=self.slot_mapping_np[:req_indices.shape[0]])
|
||||
out=self.slot_mapping.np[:req_indices.shape[0]])
|
||||
|
||||
def commit_block_table(self, num_reqs: int) -> None:
|
||||
self.block_table[:num_reqs].copy_(self.block_table_cpu[:num_reqs],
|
||||
non_blocking=True)
|
||||
self.block_table.copy_to_gpu(num_reqs)
|
||||
|
||||
def commit_slot_mapping(self, num_tokens: int) -> None:
|
||||
self.slot_mapping[:num_tokens].copy_(
|
||||
self.slot_mapping_cpu[:num_tokens], non_blocking=True)
|
||||
self.slot_mapping.copy_to_gpu(num_tokens)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.block_table.fill_(0)
|
||||
self.block_table_cpu.fill_(0)
|
||||
self.block_table.cpu.fill_(0)
|
||||
|
||||
def _convert_physical_to_logical_blocks(
|
||||
self, physical_blocks: np.ndarray) -> np.ndarray:
|
||||
@@ -243,15 +224,22 @@ class BlockTable:
|
||||
|
||||
def get_device_tensor(self) -> torch.Tensor:
|
||||
"""Returns the device tensor of the block table."""
|
||||
return self.block_table
|
||||
return self.block_table.gpu
|
||||
|
||||
def get_cpu_tensor(self) -> torch.Tensor:
|
||||
"""Returns the CPU tensor of the block table."""
|
||||
return self.block_table_cpu
|
||||
return self.block_table.cpu
|
||||
|
||||
def get_numpy_array(self) -> np.ndarray:
|
||||
"""Returns the numpy array of the block table."""
|
||||
return self.block_table_np
|
||||
return self.block_table.np
|
||||
|
||||
def _make_buffer(self, *size: int | torch.SymInt,
|
||||
dtype: torch.dtype) -> CpuGpuBuffer:
|
||||
return CpuGpuBuffer(*size,
|
||||
dtype=dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory)
|
||||
|
||||
|
||||
class MultiGroupBlockTable:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -114,6 +114,7 @@ class InputBatch:
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
logitsprocs_need_output_token_ids: bool = False,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
num_speculative_tokens: int = 0,
|
||||
@@ -143,10 +144,11 @@ class InputBatch:
|
||||
pin_memory=False,
|
||||
)
|
||||
self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()
|
||||
self.is_token_ids = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
self.is_token_ids_tensor = torch.zeros((max_num_reqs, max_model_len),
|
||||
device="cpu",
|
||||
dtype=bool,
|
||||
pin_memory=False)
|
||||
self.is_token_ids = self.is_token_ids_tensor.numpy()
|
||||
# Store prompt embeddings per request to avoid OOM from large upfront
|
||||
# allocation if max_model_len is big.
|
||||
# Maps req_index -> tensor of shape (num_prompt_tokens, hidden_size)
|
||||
@@ -299,6 +301,11 @@ class InputBatch:
|
||||
# Store provided logitsprocs. If none are provided, initialize empty
|
||||
# data structure
|
||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||
self.logitsprocs_need_output_token_ids = logitsprocs_need_output_token_ids
|
||||
|
||||
# Store last speculative tokens for sampler.
|
||||
self.spec_token_ids: list[list[int]] = [[]
|
||||
for _ in range(max_num_reqs)]
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
@@ -306,9 +313,14 @@ class InputBatch:
|
||||
self.pooling_params: dict[str, PoolingParams] = {}
|
||||
|
||||
# Cached reference to the GPU tensor of previously sampled tokens
|
||||
self.prev_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
self.prev_sampled_token_ids: torch.Tensor | None = None
|
||||
self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
|
||||
self.prev_req_id_to_index: Optional[dict[str, int]] = None
|
||||
self.prev_req_id_to_index: dict[str, int] | None = None
|
||||
# These are used to update output_token_ids with real sampled
|
||||
# ids from prior step, if required by current sampling params
|
||||
# (e.g. penalties).
|
||||
self.sampled_token_ids_cpu: torch.Tensor | None = None
|
||||
self.async_copy_ready_event: torch.Event | None = None
|
||||
|
||||
@property
|
||||
def req_ids(self) -> list[str]:
|
||||
@@ -350,9 +362,11 @@ class InputBatch:
|
||||
if req_index == len(self._req_ids):
|
||||
self._req_ids.append(req_id)
|
||||
self.req_output_token_ids.append(request.output_token_ids)
|
||||
self.spec_token_ids.append([])
|
||||
else:
|
||||
self._req_ids[req_index] = req_id
|
||||
self.req_output_token_ids[req_index] = request.output_token_ids
|
||||
self.spec_token_ids[req_index].clear()
|
||||
|
||||
self.req_id_to_index[req_id] = req_index
|
||||
|
||||
@@ -496,6 +510,21 @@ class InputBatch:
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
self.spec_token_ids[req_index].clear()
|
||||
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
lora_req_ids = self.lora_id_to_request_ids[lora_id]
|
||||
lora_req_ids.discard(req_id)
|
||||
if not lora_req_ids:
|
||||
del self.lora_id_to_request_ids[lora_id]
|
||||
del self.lora_id_to_lora_request[lora_id]
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
if self.is_pooling_model:
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
self.greedy_reqs.discard(req_id)
|
||||
self.random_reqs.discard(req_id)
|
||||
@@ -510,6 +539,8 @@ class InputBatch:
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
|
||||
|
||||
if self.prev_req_id_to_index is not None:
|
||||
self.prev_req_id_to_index.pop(req_id, None)
|
||||
# LoRA
|
||||
lora_id = self.request_lora_mapping[req_index]
|
||||
if lora_id != 0:
|
||||
@@ -538,6 +569,10 @@ class InputBatch:
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
self.spec_token_ids[i1], self.spec_token_ids[i2] = (
|
||||
self.spec_token_ids[i2],
|
||||
self.spec_token_ids[i1],
|
||||
)
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
@@ -629,6 +664,7 @@ class InputBatch:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
self.req_output_token_ids.clear()
|
||||
self.spec_token_ids.clear()
|
||||
return
|
||||
|
||||
# NOTE(woosuk): This function assumes that the empty_req_indices
|
||||
@@ -662,6 +698,16 @@ class InputBatch:
|
||||
self.req_output_token_ids[last_req_index] = None
|
||||
self.req_id_to_index[req_id] = empty_index
|
||||
|
||||
if last_req_index != empty_index:
|
||||
(
|
||||
self.spec_token_ids[last_req_index],
|
||||
self.spec_token_ids[empty_index],
|
||||
) = (
|
||||
self.spec_token_ids[empty_index],
|
||||
self.spec_token_ids[last_req_index],
|
||||
)
|
||||
self.spec_token_ids[last_req_index].clear()
|
||||
|
||||
num_tokens = self.num_tokens[last_req_index]
|
||||
self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
|
||||
last_req_index, :num_tokens]
|
||||
@@ -714,6 +760,7 @@ class InputBatch:
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
del self.spec_token_ids[num_reqs:]
|
||||
|
||||
def refresh_metadata(self):
|
||||
"""Apply any batch updates to sampling metadata."""
|
||||
@@ -787,6 +834,7 @@ class InputBatch:
|
||||
presence_penalties=self.presence_penalties[:num_reqs],
|
||||
repetition_penalties=self.repetition_penalties[:num_reqs],
|
||||
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
|
||||
spec_token_ids=cast(list[list[int]], self.spec_token_ids),
|
||||
no_penalties=self.no_penalties,
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
@@ -848,6 +896,53 @@ class InputBatch:
|
||||
|
||||
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
|
||||
|
||||
def set_async_sampled_token_ids(
|
||||
self,
|
||||
sampled_token_ids_cpu: torch.Tensor,
|
||||
async_copy_ready_event: torch.Event,
|
||||
) -> None:
|
||||
"""
|
||||
In async scheduling case, store ref to sampled_token_ids_cpu
|
||||
tensor and corresponding copy-ready event. Used to repair
|
||||
output_token_ids prior to sampling, if needed by logits processors.
|
||||
"""
|
||||
if self.sampling_metadata.output_token_ids:
|
||||
self.sampled_token_ids_cpu = sampled_token_ids_cpu
|
||||
self.async_copy_ready_event = async_copy_ready_event
|
||||
else:
|
||||
self.sampled_token_ids_cpu = None
|
||||
self.async_copy_ready_event = None
|
||||
|
||||
def update_async_output_token_ids(self) -> None:
|
||||
"""
|
||||
In async scheduling case, update output_token_ids in sampling metadata
|
||||
from prior steps sampled token ids once they've finished copying to CPU.
|
||||
This is called right before they are needed by the logits processors.
|
||||
"""
|
||||
output_token_ids = self.sampling_metadata.output_token_ids
|
||||
if self.sampled_token_ids_cpu is None or not output_token_ids:
|
||||
# Output token ids not needed or not async scheduling.
|
||||
return
|
||||
|
||||
assert self.prev_req_id_to_index is not None
|
||||
sampled_token_ids = None
|
||||
for index, req_id in enumerate(self.req_ids):
|
||||
prev_index = self.prev_req_id_to_index.get(req_id)
|
||||
if prev_index is None:
|
||||
continue
|
||||
req_output_token_ids = output_token_ids[index]
|
||||
if not req_output_token_ids or req_output_token_ids[-1] != -1:
|
||||
# Final output id is not a placeholder, some tokens must have
|
||||
# been discarded after a kv-load failure.
|
||||
continue
|
||||
if sampled_token_ids is None:
|
||||
assert self.async_copy_ready_event is not None
|
||||
self.async_copy_ready_event.synchronize()
|
||||
sampled_token_ids = self.sampled_token_ids_cpu.squeeze(
|
||||
-1).tolist()
|
||||
# Replace placeholder token id with actual sampled id.
|
||||
req_output_token_ids[-1] = sampled_token_ids[prev_index]
|
||||
|
||||
@property
|
||||
def num_reqs(self) -> int:
|
||||
return len(self.req_id_to_index)
|
||||
|
||||
Reference in New Issue
Block a user