Refector prepare_inputs in model_runner_v1.py (#2750)

### What this PR does / why we need it?
Refector prepare_inputs in model_runner_v1.py for more easy read.

### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
PASS CI

- vLLM version: v0.10.1.1
- vLLM main:
e599e2c65e

---------

Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
This commit is contained in:
TaoYu Chen
2025-09-08 10:45:23 +08:00
committed by GitHub
parent c735bb0941
commit dd087effcc

View File

@@ -880,6 +880,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
mm_embeds.append(mm_embeds_item)
return mm_embeds
def _get_cumsum_and_arange(
self,
num_tokens: np.ndarray,
cumsum_dtype: Optional[np.dtype] = None,
) -> tuple[np.ndarray, np.ndarray]:
"""Get the cumulative sum and batched arange of the given array.
# E.g., [2, 5, 3] -> ([2, 7, 10], [0, 1, 0, 1, 2, 3, 4, 0, 1, 2])
# Equivalent to but faster than:
# np.concatenate([np.arange(n) for n in num_tokens])
"""
# Step 1. [2, 5, 3] -> [2, 7, 10]
cu_num_tokens = np.cumsum(num_tokens, dtype=cumsum_dtype)
total_num_tokens = cu_num_tokens[-1]
# Step 2. [2, 7, 10] -> [0, 0, 2, 2, 2, 2, 2, 7, 7, 7]
cumsums_offsets = np.repeat(cu_num_tokens - num_tokens, num_tokens)
# Step 3. [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
arange = self.arange_np[:total_num_tokens] - cumsums_offsets
return cu_num_tokens, arange
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
@@ -901,17 +921,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.block_table.commit_block_table(num_reqs)
# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = np.empty(num_reqs, dtype=np.int32)
num_valid_tokens = np.empty(num_reqs, dtype=np.int32)
max_num_scheduled_tokens = 0
for i, req_id in enumerate(self.input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens[i] = num_tokens
num_valid_tokens[i] = num_tokens - \
len(scheduler_output.scheduled_spec_decode_tokens.get(req_id, []))
max_num_scheduled_tokens = max(max_num_scheduled_tokens,
num_tokens)
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
for num_tokens, i in zip(tokens, req_ids)
],
dtype=np.int32)
if (self.use_aclgraph and total_num_scheduled_tokens
<= self.aclgraph_batch_sizes[-1]):
@@ -952,13 +971,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.lora_config:
self.set_active_loras(self.input_batch, num_scheduled_tokens)
# Prepare positions
# Get request indices.
# E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
req_indices = np.repeat(self.arange_np[:num_reqs],
num_scheduled_tokens)
cu_num_tokens = np.cumsum(num_scheduled_tokens)
cumsums_offsets = np.repeat(cu_num_tokens - num_scheduled_tokens,
num_scheduled_tokens)
arange = self.arange_np[:total_num_scheduled_tokens] - cumsums_offsets
# cu_num_tokens: [2, 5, 3] -> [2, 7, 10]
# arange: [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
cu_num_tokens, arange = self._get_cumsum_and_arange(
num_scheduled_tokens)
positions_np = self.positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
@@ -975,50 +996,73 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.mrope_positions_cpu[:, :total_num_scheduled_tokens],
non_blocking=True)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Get token indices.
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
# where M is the max_model_len.
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
# Prepare input_ids.
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
block_table_indices = (req_indices * self.max_num_blocks_per_req +
positions_np // self.block_size)
block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
block_offsets = positions_np % self.block_size
np.add(block_numbers * self.block_size,
block_offsets,
out=self.slot_mapping_np[:total_num_scheduled_tokens])
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
# Prepare some information for building Attention-Metadata
# Compute and commit slot mapping
self.input_batch.block_table.compute_slot_mapping(
req_indices, positions_np)
self.input_batch.block_table.commit_slot_mapping(
total_num_scheduled_tokens)
self.slot_mapping_cpu[:total_num_scheduled_tokens].copy_(
self.input_batch.block_table[0].
slot_mapping_cpu[:total_num_scheduled_tokens])
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
self.query_start_loc[:num_reqs + 1].copy_(
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
self.query_start_loc[num_reqs + 1:].fill_(-1)
self.seq_lens[num_reqs:].fill_(0)
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Copy the tensors to the NPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
# Make Attention metadata
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
self.attn_state = attn_state # type: ignore
self.with_prefill = with_prefill
self.num_tokens_across_dp = num_tokens_across_dp
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
# Make AscendCommonAttentionMetadata
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=self.query_start_loc[:num_reqs + 1],
query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
@@ -1044,19 +1088,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.vllm_config.model_config.use_mla:
attn_metadata.num_input_tokens = num_input_tokens
# Prepare input_ids
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Copy the tensors to the NPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
# _prepare_inputs may reorder the batch, so we must gather multi
# modal outputs after that to ensure the correct order
# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)