Files
xc-llm-ascend/vllm_ascend/worker/npu_input_batch.py
Mengqing Cao cea0755b07 [1/N][Refactor] Refactor code to adapt with vllm main (#3612)
### What this PR does / why we need it?
This is step 1 of refactoring the code to adapt to vLLM main. This PR is
aligned with vLLM commit
17c540a993

1. refactor deepseek to the latest code arch as of
17c540a993
 
2. A number of fixes required by upstream vLLM changes
- Fix `AscendScheduler` `__post_init__`, caused by
https://github.com/vllm-project/vllm/pull/25075
- Fix `AscendScheduler` init got an unexpected arg `block_size`, caused
by https://github.com/vllm-project/vllm/pull/26296
- Fix `KVCacheManager` `get_num_common_prefix_blocks` arg, caused by
https://github.com/vllm-project/vllm/pull/23485
- Fix `MLAAttention` import, caused by
https://github.com/vllm-project/vllm/pull/25103
- Fix `SharedFusedMoE` import, caused by
https://github.com/vllm-project/vllm/pull/26145
- Fix `LazyLoader` import, caused by
https://github.com/vllm-project/vllm/pull/27022
- Fix `vllm.utils.swap_dict_values` import, caused by
https://github.com/vllm-project/vllm/pull/26990
- Fix `Backend` enum import, caused by
https://github.com/vllm-project/vllm/pull/25893
- Fix `CompilationLevel` renaming to `CompilationMode` issue introduced
by https://github.com/vllm-project/vllm/pull/26355
- Fix fused_moe ops, caused by
https://github.com/vllm-project/vllm/pull/24097
- Fix bert model because of `inputs_embeds`, caused by
https://github.com/vllm-project/vllm/pull/25922
- Fix MRope because of `get_input_positions_tensor` to
`get_mrope_input_positions`, caused by
https://github.com/vllm-project/vllm/pull/24172
- Fix `splitting_ops` changes introduced by
https://github.com/vllm-project/vllm/pull/25845
- Fix multi-modality changes introduced by
https://github.com/vllm-project/vllm/issues/16229
- Fix lora bias dropping issue introduced by
https://github.com/vllm-project/vllm/pull/25807
- Fix structured output breakage introduced by
https://github.com/vllm-project/vllm/issues/26737

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
CI passed with the existing tests.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: Icey <1790571317@qq.com>
Co-authored-by: Icey <1790571317@qq.com>
2025-10-24 16:55:08 +08:00

852 lines
36 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/vllm/worker/gpu_input_batch.py
#
from dataclasses import dataclass
from typing import Optional, cast
import numpy as np
import torch
from typing_extensions import deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import (MultiModalFeatureSpec,
MultiModalKwargsItem,
MultiModalKwargsItems, PlaceholderRange)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
LogitsProcessors,
MoveDirectionality)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm_ascend.utils import vllm_version_is
from vllm_ascend.worker.block_table import MultiGroupBlockTable
if vllm_version_is("0.11.0"):
from vllm.utils import swap_dict_values
else:
from vllm.utils.collections import swap_dict_values
@dataclass
class CachedRequestState:
    """Inputs and progress of a single request.

    ``num_prompt_tokens`` is derived once from ``prompt_token_ids`` in
    ``__post_init__``; ``num_tokens`` is recomputed on every access so it
    follows later appends to ``output_token_ids``.
    """

    req_id: str
    prompt_token_ids: list[int]
    sampling_params: Optional[SamplingParams]
    pooling_params: Optional[PoolingParams]
    generator: Optional[torch.Generator]

    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    output_token_ids: list[int]

    mrope_positions: Optional[torch.Tensor] = None
    mrope_position_delta: Optional[int] = None

    mm_features: Optional[list[MultiModalFeatureSpec]] = None
    # for back-compatibility, will be removed in next major release
    mm_kwargs: Optional[list[MultiModalKwargsItem]] = None
    mm_positions: Optional[list[PlaceholderRange]] = None
    # Hashes are plain strings (the previous annotation was a copy-paste of
    # the `mm_positions` element type).
    mm_hashes: Optional[list[str]] = None

    lora_request: Optional[LoRARequest] = None

    def __post_init__(self):
        # Cached because the prompt length is read far more often than the
        # prompt itself changes.
        self.num_prompt_tokens = len(self.prompt_token_ids)

    @property
    def num_tokens(self) -> int:
        """Total number of tokens: prompt tokens plus generated tokens."""
        return self.num_prompt_tokens + len(self.output_token_ids)

    # Temporary back-compatibility for plugins that define model runner
    @property
    @deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
                "removed in v0.13. Please use `mm_kwargs` instead.")
    def mm_inputs(self) -> list[MultiModalKwargsItems]:
        """Deprecated view of ``mm_features`` as one kwargs bundle per item.

        Features whose ``data`` is ``None`` are skipped.
        """
        assert self.mm_features is not None
        return [
            MultiModalKwargsItems.from_seq([f.data]) for f in self.mm_features
            if f.data is not None
        ]

    def get_token_id(self, idx: int) -> int:
        """Return the token at position ``idx`` of prompt + output.

        Positions below ``num_prompt_tokens`` index the prompt; the rest
        index ``output_token_ids`` (an out-of-range position raises
        ``IndexError`` from the list lookup).
        """
        if idx < self.num_prompt_tokens:
            return self.prompt_token_ids[idx]
        else:
            return self.output_token_ids[idx - self.num_prompt_tokens]
class InputBatch:
def __init__(
    self,
    max_num_reqs: int,
    max_model_len: int,
    max_num_batched_tokens: int,
    device: torch.device,
    pin_memory: bool,
    vocab_size: int,
    block_sizes: list[int],  # The block_size of each kv cache group
    logitsprocs: Optional[LogitsProcessors] = None,
    is_spec_decode: bool = False,
    is_pooling_model: bool = False,
    num_speculative_tokens: int = 0,
    kernel_block_sizes: Optional[list[list[int]]] = None,
    cp_kv_cache_interleave_size: int = 1,
):
    """Allocate the persistent batch buffers and bookkeeping tables.

    Per-request arrays have ``max_num_reqs`` entries and are indexed by the
    request's slot in the batch. Each sampling parameter is stored twice:
    as a tensor on ``device`` and as a CPU tensor whose NumPy view (the
    ``*_cpu`` attribute) shares its memory.

    Args:
        max_num_reqs: Number of request slots to allocate.
        max_model_len: Width of the per-request token id buffer.
        max_num_batched_tokens: Stored and forwarded to the block table.
        device: Device holding the device-side sampling tensors.
        pin_memory: Whether the CPU staging tensors are pinned.
        vocab_size: Vocabulary size.
        block_sizes: Block size of each KV cache group.
        logitsprocs: Logits processors; an empty ``LogitsProcessors`` is
            created when this is omitted (or falsy).
        is_spec_decode: Whether speculative decoding is in use.
        is_pooling_model: Whether the batch serves a pooling model.
        num_speculative_tokens: Forwarded to the block table.
        kernel_block_sizes: Forwarded to the block table as
            ``kernel_sizes``.
        cp_kv_cache_interleave_size: Forwarded to the block table.
    """
    per_req_shape = (max_num_reqs, )

    def _on_device(dtype: torch.dtype) -> torch.Tensor:
        # Uninitialised per-request tensor living on the target device.
        return torch.empty(per_req_shape, dtype=dtype, device=device)

    def _on_host(dtype: torch.dtype) -> torch.Tensor:
        # Uninitialised per-request CPU staging tensor.
        return torch.empty(per_req_shape,
                           dtype=dtype,
                           device="cpu",
                           pin_memory=pin_memory)

    # Plain configuration.
    self.max_num_reqs = max_num_reqs
    self.max_model_len = max_model_len
    self.max_num_batched_tokens = max_num_batched_tokens
    self.device = device
    self.pin_memory = pin_memory
    self.vocab_size = vocab_size
    self.is_spec_decode = is_spec_decode
    self.is_pooling_model = is_pooling_model

    # Slot <-> request id maps. `_req_ids` may hold None for a vacated slot.
    self._req_ids: list[Optional[str]] = []
    self.req_id_to_index: dict[str, int] = {}

    # TODO(woosuk): This buffer could be too large if max_model_len is big.
    # Find a way to reduce the CPU memory usage.
    # This buffer is not directly transferred to the NPU, so it does not
    # need to be pinned.
    self.token_ids_cpu_tensor = torch.zeros(
        (max_num_reqs, max_model_len),
        device="cpu",
        dtype=torch.int32,
        pin_memory=False,
    )
    self.token_ids_cpu = self.token_ids_cpu_tensor.numpy()

    # Per-slot token counters.
    self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
    self.num_computed_tokens_cpu_tensor = torch.zeros(
        per_req_shape,
        device="cpu",
        dtype=torch.int32,
        pin_memory=pin_memory,
    )
    self.num_computed_tokens_cpu = (
        self.num_computed_tokens_cpu_tensor.numpy())

    # Block table.
    self.block_table = MultiGroupBlockTable(
        max_num_reqs=max_num_reqs,
        max_model_len=max_model_len,
        max_num_batched_tokens=max_num_batched_tokens,
        pin_memory=pin_memory,
        device=device,
        block_sizes=block_sizes,
        num_speculative_tokens=num_speculative_tokens,
        kernel_sizes=kernel_block_sizes,
        cp_kv_cache_interleave_size=cp_kv_cache_interleave_size,
    )

    # Sampling-related: temperature.
    self.temperature = _on_device(torch.float32)
    self.temperature_cpu_tensor = _on_host(torch.float32)
    self.temperature_cpu = self.temperature_cpu_tensor.numpy()
    self.greedy_reqs: set[str] = set()
    self.random_reqs: set[str] = set()

    # Sampling-related: top-p.
    self.top_p = _on_device(torch.float32)
    self.top_p_cpu_tensor = _on_host(torch.float32)
    self.top_p_cpu = self.top_p_cpu_tensor.numpy()
    self.top_p_reqs: set[str] = set()

    # Sampling-related: top-k.
    self.top_k = _on_device(torch.int32)
    self.top_k_cpu_tensor = _on_host(torch.int32)
    self.top_k_cpu = self.top_k_cpu_tensor.numpy()
    self.top_k_reqs: set[str] = set()

    # IDs of requests which do not support spec decoding
    self.spec_decode_unsupported_reqs: set[str] = set()

    # Frequency penalty related data structures
    self.frequency_penalties = _on_device(torch.float)
    self.frequency_penalties_cpu_tensor = _on_host(torch.float)
    self.frequency_penalties_cpu = (
        self.frequency_penalties_cpu_tensor.numpy())
    self.frequency_penalties_reqs: set[str] = set()

    # Presence penalty related data structures
    self.presence_penalties = _on_device(torch.float)
    self.presence_penalties_cpu_tensor = _on_host(torch.float)
    self.presence_penalties_cpu = (
        self.presence_penalties_cpu_tensor.numpy())
    self.presence_penalties_reqs: set[str] = set()

    # Repetition penalty related data structures
    self.repetition_penalties = _on_device(torch.float)
    self.repetition_penalties_cpu_tensor = _on_host(torch.float)
    self.repetition_penalties_cpu = (
        self.repetition_penalties_cpu_tensor.numpy())
    self.repetition_penalties_reqs: set[str] = set()

    # Speculative decoding: starts at one accepted token per slot.
    self.num_accepted_tokens_cpu_tensor = torch.ones(per_req_shape,
                                                     dtype=torch.int64,
                                                     device="cpu",
                                                     pin_memory=pin_memory)
    self.num_accepted_tokens_cpu = (
        self.num_accepted_tokens_cpu_tensor.numpy())

    # lora related
    self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
                                         dtype=np.int32)
    self.lora_id_to_request_ids: dict[int, set[str]] = {}
    self.lora_id_to_lora_request: dict[int, LoRARequest] = {}

    # req_index -> generator
    # NOTE(woosuk): The indices of the requests that do not have their own
    # generator should not be included in the dictionary.
    self.generators: dict[int, torch.Generator] = {}

    self.num_logprobs: dict[str, int] = {}
    # NOTE(rob): num_prompt_logprobs only includes reqs
    # that are currently in the prefill phase.
    self.num_prompt_logprobs: dict[str, int] = {}

    # To accumulate prompt logprobs tensor chunks across prefill steps.
    self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

    # Internal representation of per-step batch state changes, used for
    # reordering persistent batch and generating logitsprocs batch state
    # updates. Should reset each step.
    self.batch_update_builder = BatchUpdateBuilder()

    # TODO convert this to LogitsProcessor
    self.has_allowed_token_ids: set[str] = set()
    # NOTE(lufang): In the mask tensor, if the corresponding token allowed,
    # the value is False. Since we use masked_fill_ to set -inf.
    self.allowed_token_ids_mask: Optional[torch.Tensor] = None
    self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None

    # req_index -> bad_words_token_ids
    self.bad_words_token_ids: dict[int, list[list[int]]] = {}

    self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
                                                      dtype=bool)

    self.req_output_token_ids: list[Optional[list[int]]] = []

    # Store provided logitsprocs. If none are provided, initialize empty
    # data structure
    self.logitsprocs = logitsprocs or LogitsProcessors()

    # This is updated each time the batch constituents change.
    self.sampling_metadata = self._make_sampling_metadata()

    self.pooling_params: dict[str, PoolingParams] = {}

    # Cached reference to the device tensor of previously sampled tokens
    self.prev_sampled_token_ids: Optional[torch.Tensor] = None
    self.prev_sampled_token_ids_invalid_indices: Optional[set[int]] = None
    self.prev_req_id_to_index: Optional[dict[str, int]] = None
@property
def req_ids(self) -> list[str]:
# None elements should only be present transiently
# while performing state updates to the batch.
return cast(list[str], self._req_ids)
def _register_add_request(self, request: "CachedRequestState") -> int:
    """Pick the slot for a new request and record the addition.

    A previously vacated slot (as reported by the batch update builder) is
    reused when one exists; otherwise the request goes to the end of the
    batch. The addition is appended to the builder's ``added`` list so
    logits processors can be updated. Not applicable to pooling models.

    Returns:
        The slot index assigned to the request.
    """
    # Detailed added request metadata is only required for non-pooling
    # models, to support logitsprocs
    params = request.sampling_params
    assert params

    builder = self.batch_update_builder
    slot = builder.pop_removed()
    if slot is None:
        # No hole to fill: append to the end.
        slot = self.num_reqs
    assert slot < self.max_num_reqs

    builder.added.append(
        (slot, params, request.prompt_token_ids, request.output_token_ids))
    return slot
def add_request(
    self,
    request: "CachedRequestState",
) -> int:
    """Insert a request into the batch and populate its slot.

    The slot comes from `_register_add_request` for non-pooling models and
    is simply the current request count for pooling models. All per-slot
    buffers (token ids, counters, block table row, sampling or pooling
    parameters, LoRA mapping) are then written for that slot.

    Args:
        request: The request state to copy into the batch.

    Returns:
        The slot index the request now occupies.

    Raises:
        NotImplementedError: If the request has neither sampling params
            nor pooling params.
    """
    if not self.is_pooling_model:
        # New request index bookkeeping for autoregressive models.
        req_index = self._register_add_request(request)
    else:
        req_index = self.num_reqs

    req_id = request.req_id
    # Either grow the id lists by one or overwrite a vacated slot.
    if req_index == len(self._req_ids):
        self._req_ids.append(req_id)
        self.req_output_token_ids.append(request.output_token_ids)
    else:
        self._req_ids[req_index] = req_id
        self.req_output_token_ids[req_index] = request.output_token_ids

    self.req_id_to_index[req_id] = req_index

    # Copy the prompt token ids and output token ids.
    num_prompt_tokens = len(request.prompt_token_ids)
    self.num_prompt_tokens[req_index] = num_prompt_tokens
    self.token_ids_cpu[
        req_index, :num_prompt_tokens] = request.prompt_token_ids
    start_idx = num_prompt_tokens
    end_idx = start_idx + len(request.output_token_ids)
    self.token_ids_cpu[req_index,
                       start_idx:end_idx] = request.output_token_ids
    # Number of token ids in token_ids_cpu.
    # NOTE(woosuk): This may include spec decode tokens.
    self.num_tokens[req_index] = request.num_tokens
    # Number of tokens without spec decode tokens.
    self.num_tokens_no_spec[req_index] = request.num_tokens

    self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
    self.block_table.add_row(request.block_ids, req_index)

    if sampling_params := request.sampling_params:
        if (self.is_spec_decode
                and is_spec_decode_unsupported(sampling_params)):
            self.spec_decode_unsupported_reqs.add(req_id)
        if sampling_params.sampling_type == SamplingType.GREEDY:
            # Avoid later division by zero.
            self.temperature_cpu[req_index] = -1.0
            self.greedy_reqs.add(req_id)
        else:
            self.temperature_cpu[req_index] = sampling_params.temperature
            self.random_reqs.add(req_id)

        self.top_p_cpu[req_index] = sampling_params.top_p
        if sampling_params.top_p < 1:
            self.top_p_reqs.add(req_id)
        top_k = sampling_params.top_k
        # A top_k outside (0, vocab_size) is stored as vocab_size and the
        # request is not tracked in top_k_reqs.
        if 0 < top_k < self.vocab_size:
            self.top_k_reqs.add(req_id)
        else:
            top_k = self.vocab_size
        self.top_k_cpu[req_index] = top_k
        self.frequency_penalties_cpu[
            req_index] = sampling_params.frequency_penalty
        if sampling_params.frequency_penalty != 0.0:
            self.frequency_penalties_reqs.add(req_id)
        self.presence_penalties_cpu[
            req_index] = sampling_params.presence_penalty
        if sampling_params.presence_penalty != 0.0:
            self.presence_penalties_reqs.add(req_id)
        self.repetition_penalties_cpu[
            req_index] = sampling_params.repetition_penalty
        if sampling_params.repetition_penalty != 1.0:
            self.repetition_penalties_reqs.add(req_id)

        # NOTE(woosuk): self.generators should not include the requests that
        # do not have their own generator.
        if request.generator is not None:
            self.generators[req_index] = request.generator

        # logprobs == -1 is stored as vocab_size.
        if sampling_params.logprobs is not None:
            self.num_logprobs[req_id] = (self.vocab_size
                                         if sampling_params.logprobs == -1
                                         else sampling_params.logprobs)
        if sampling_params.prompt_logprobs is not None:
            self.num_prompt_logprobs[
                req_id] = sampling_params.prompt_logprobs

        if sampling_params.allowed_token_ids:
            self.has_allowed_token_ids.add(req_id)
            if self.allowed_token_ids_mask_cpu_tensor is None:
                # Lazy allocation for this tensor, which can be large.
                # False means we don't fill with -inf.
                self.allowed_token_ids_mask = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device=self.device)
                self.allowed_token_ids_mask_cpu_tensor = torch.zeros(
                    self.max_num_reqs,
                    self.vocab_size,
                    dtype=torch.bool,
                    device="cpu")
            # Mask every token first, then unmask the allowed ones.
            self.allowed_token_ids_mask_cpu_tensor[req_index] = True
            # False means we don't fill with -inf.
            self.allowed_token_ids_mask_cpu_tensor[req_index][
                sampling_params.allowed_token_ids] = False

        if sampling_params.bad_words_token_ids:
            self.bad_words_token_ids[
                req_index] = sampling_params.bad_words_token_ids
    elif pooling_params := request.pooling_params:
        self.pooling_params[req_id] = pooling_params
        self.logits_processing_needs_token_ids[req_index] = (
            pooling_params.requires_token_ids)
    else:
        raise NotImplementedError(request)

    # Speculative decoding: by default 1 token is generated.
    self.num_accepted_tokens_cpu[req_index] = 1

    # Add request lora ID
    if request.lora_request:
        lora_id = request.lora_request.lora_int_id
        if lora_id not in self.lora_id_to_request_ids:
            self.lora_id_to_request_ids[lora_id] = set()

        self.request_lora_mapping[req_index] = lora_id
        self.lora_id_to_request_ids[lora_id].add(request.req_id)
        self.lora_id_to_lora_request[lora_id] = request.lora_request
    else:
        # No LoRA
        self.request_lora_mapping[req_index] = 0

    return req_index
def remove_request(self, req_id: str) -> Optional[int]:
"""This method must always be followed by a call to condense().
Args:
req_id: request to remove
Returns:
Removed request index, or `None` if `req_id` not recognized
"""
req_index = self.req_id_to_index.pop(req_id, None)
if req_index is None:
return None
if not self.is_pooling_model:
# Autoregressive models require bookkeeping of removed requests to
# support logitsprocs.
self.batch_update_builder.removed_append(req_index)
self._req_ids[req_index] = None
self.req_output_token_ids[req_index] = None
self.greedy_reqs.discard(req_id)
self.random_reqs.discard(req_id)
self.top_p_reqs.discard(req_id)
self.top_k_reqs.discard(req_id)
self.spec_decode_unsupported_reqs.discard(req_id)
self.frequency_penalties_reqs.discard(req_id)
self.presence_penalties_reqs.discard(req_id)
self.repetition_penalties_reqs.discard(req_id)
self.generators.pop(req_index, None)
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
# LoRA
lora_id = self.request_lora_mapping[req_index]
if lora_id != 0:
self.lora_id_to_request_ids[lora_id].discard(req_id)
if len(self.lora_id_to_request_ids[lora_id]) == 0:
self.lora_id_to_request_ids.pop(lora_id)
self.lora_id_to_lora_request.pop(lora_id)
self.request_lora_mapping[req_index] = 0
self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None:
# False means we don't fill with -inf.
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
self.bad_words_token_ids.pop(req_index, None)
self.pooling_params.pop(req_id, None)
return req_index
def swap_states(self, i1: int, i2: int) -> None:
    """Exchange all per-request state between batch slots ``i1`` and ``i2``.

    The swap is recorded in ``batch_update_builder.moved`` so logits
    processors can follow the reordering. Both slots must currently hold a
    request (asserted).

    Args:
        i1: First slot index.
        i2: Second slot index.
    """
    # For autoregressive models, track detailed request reordering info
    # to support logitsprocs
    self.batch_update_builder.moved.append(
        (i1, i2, MoveDirectionality.SWAP))

    old_id_i1 = self._req_ids[i1]
    old_id_i2 = self._req_ids[i2]
    self._req_ids[i1], self._req_ids[i2] =\
        self._req_ids[i2], self._req_ids[i1] # noqa
    self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
        self.req_output_token_ids[i2], self.req_output_token_ids[i1]
    assert old_id_i1 is not None and old_id_i2 is not None
    self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
        self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]

    # One-dimensional NumPy arrays: integer indexing yields scalar copies,
    # so a tuple swap is safe here.
    for per_slot in (
            self.num_tokens,
            self.num_tokens_no_spec,
            self.num_prompt_tokens,
            self.num_computed_tokens_cpu,
            self.temperature_cpu,
            self.top_p_cpu,
            self.top_k_cpu,
            self.frequency_penalties_cpu,
            self.presence_penalties_cpu,
            self.repetition_penalties_cpu,
            self.num_accepted_tokens_cpu,
    ):
        per_slot[i1], per_slot[i2] = per_slot[i2], per_slot[i1]

    # NOTE: the following is unsafe
    # self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
    #     self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
    # instead, we need to temporarily copy the data for one of the indices
    # TODO(lucas): optimize this by only copying valid indices
    tmp = self.token_ids_cpu[i1, ...].copy()
    self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
    self.token_ids_cpu[i2, ...] = tmp

    swap_dict_values(self.generators, i1, i2)
    swap_dict_values(self.bad_words_token_ids, i1, i2)

    self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
        self.request_lora_mapping[i2], self.request_lora_mapping[i1]

    mask = self.allowed_token_ids_mask_cpu_tensor
    if mask is not None:
        # Indexing a row of a 2-D tensor returns a view, so a tuple swap
        # would overwrite row i1 before it is read and leave both rows
        # equal to the old row i2. Take an explicit copy of one row, as is
        # done for token_ids_cpu above.
        row_i1 = mask[i1].clone()
        mask[i1] = mask[i2]
        mask[i2] = row_i1

    self.block_table.swap_row(i1, i2)
def condense(self) -> None:
    """Slide non-empty requests down into lower, empty indices.

    Any consecutive empty indices at the very end of the list are not
    filled.

    The empty indices are read from ``self.batch_update_builder.removed``
    (this method takes no arguments). Each relocation is recorded in
    ``self.batch_update_builder.moved`` as a
    ``(from, to, MoveDirectionality.UNIDIRECTIONAL)`` tuple. Nothing is
    returned; afterwards ``_req_ids`` and ``req_output_token_ids`` are
    trimmed to the current number of requests.
    """
    num_reqs = self.num_reqs

    if self.is_pooling_model:
        # Will be contiguous in pooling case, just trim the lists.
        del self._req_ids[num_reqs:]
        del self.req_output_token_ids[num_reqs:]
        return

    if not (empty_req_indices := self.batch_update_builder.removed):
        # All removed requests were replaced by added requests, or else no
        # requests were removed at all. No condense() needed
        return
    if num_reqs == 0:
        # The batched states are empty.
        self._req_ids.clear()
        self.req_output_token_ids.clear()
        return

    # NOTE(woosuk): This function assumes that the empty_req_indices
    # is sorted in descending order.
    # Highest slot that can still hold a request: live requests plus holes.
    last_req_index = num_reqs + len(empty_req_indices) - 1
    while empty_req_indices:
        # Find the largest non-empty index.
        while last_req_index in empty_req_indices:
            last_req_index -= 1

        # Find the smallest empty index.
        empty_index = self.batch_update_builder.peek_removed()
        assert empty_index is not None
        # Remaining holes are all above the last live request: done.
        if empty_index >= last_req_index:
            break

        # Move active request down into empty request
        # index.
        self.batch_update_builder.pop_removed()
        # Autoregressive models require detailed tracking of condense
        # operations to support logitsprocs
        self.batch_update_builder.moved.append(
            (last_req_index, empty_index,
             MoveDirectionality.UNIDIRECTIONAL))
        req_id = self._req_ids[last_req_index]
        output_token_ids = self.req_output_token_ids[last_req_index]
        assert req_id is not None
        self._req_ids[empty_index] = req_id
        self._req_ids[last_req_index] = None
        self.req_output_token_ids[empty_index] = output_token_ids
        self.req_output_token_ids[last_req_index] = None
        self.req_id_to_index[req_id] = empty_index

        # Only the populated prefix of the token row needs to be copied.
        num_tokens = self.num_tokens[last_req_index]
        self.token_ids_cpu[empty_index, :num_tokens] = self.token_ids_cpu[
            last_req_index, :num_tokens]
        self.num_tokens[empty_index] = num_tokens
        self.num_tokens_no_spec[empty_index] = self.num_tokens_no_spec[
            last_req_index]
        self.num_prompt_tokens[empty_index] = self.num_prompt_tokens[
            last_req_index]
        self.num_computed_tokens_cpu[
            empty_index] = self.num_computed_tokens_cpu[last_req_index]
        self.block_table.move_row(last_req_index, empty_index)
        self.temperature_cpu[empty_index] = self.temperature_cpu[
            last_req_index]
        self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index]
        self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index]
        self.frequency_penalties_cpu[
            empty_index] = self.frequency_penalties_cpu[last_req_index]
        self.presence_penalties_cpu[
            empty_index] = self.presence_penalties_cpu[last_req_index]
        self.repetition_penalties_cpu[
            empty_index] = self.repetition_penalties_cpu[last_req_index]
        self.num_accepted_tokens_cpu[
            empty_index] = self.num_accepted_tokens_cpu[last_req_index]
        generator = self.generators.pop(last_req_index, None)
        if generator is not None:
            self.generators[empty_index] = generator

        self.request_lora_mapping[empty_index] = self.request_lora_mapping[
            last_req_index]

        # TODO convert these to LogitsProcessors
        # NOTE(review): the source row keeps its mask contents after this
        # copy - confirm the slot is reset before it is reused by a request
        # without allowed_token_ids.
        if self.allowed_token_ids_mask_cpu_tensor is not None:
            self.allowed_token_ids_mask_cpu_tensor[
                empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                    last_req_index]

        bad_words_token_ids = self.bad_words_token_ids.pop(
            last_req_index, None)
        if bad_words_token_ids is not None:
            self.bad_words_token_ids[empty_index] = bad_words_token_ids

        # Decrement last_req_index since it is now empty.
        last_req_index -= 1

    # Trim lists to the batch size.
    del self._req_ids[num_reqs:]
    del self.req_output_token_ids[num_reqs:]
def refresh_metadata(self):
    """Bring ``sampling_metadata`` in line with the current batch.

    Pooling models always rebuild the metadata. Other models first hand the
    accumulated batch update to every logits processor (which also resets
    the update tracking) and rebuild only when that update is truthy.
    """
    if not self.is_pooling_model:
        # Generate and apply the logitsprocs update; reset batch update
        # tracking.
        pending = self.batch_update_builder.get_and_reset(self.num_reqs)
        for processor in self.logitsprocs.all:
            processor.update_state(pending)
        if not pending:
            # Batch state is unchanged: keep the existing metadata.
            return

    # Pooling batches change every step; otherwise the batch was updated.
    self.sampling_metadata = self._make_sampling_metadata()
def _make_sampling_metadata(self) -> SamplingMetadata:
    """Build a ``SamplingMetadata`` for the first ``num_reqs`` slots.

    CPU staging tensors are copied to their device counterparts only for
    the features that at least one request in the batch actually uses.
    """
    count = self.num_reqs
    greedy_only = self.all_greedy
    skip_top_p = self.no_top_p
    skip_top_k = self.no_top_k
    skip_penalties = self.no_penalties

    temperature = None
    if not greedy_only:
        temperature = copy_slice(self.temperature_cpu_tensor,
                                 self.temperature, count)
    if not skip_top_p:
        copy_slice(self.top_p_cpu_tensor, self.top_p, count)
    if not skip_top_k:
        copy_slice(self.top_k_cpu_tensor, self.top_k, count)

    if not skip_penalties:
        # Since syncing these tensors is expensive only copy them
        # if necessary i.e. if there are requests which require
        # penalties to be applied during sampling.
        for staged, on_device in (
            (self.frequency_penalties_cpu_tensor, self.frequency_penalties),
            (self.presence_penalties_cpu_tensor, self.presence_penalties),
            (self.repetition_penalties_cpu_tensor,
             self.repetition_penalties),
        ):
            copy_slice(staged, on_device, count)

    # The prompt tokens are used only for applying penalties or
    # step pooling during the sampling/pooling process.
    # Hence copy these tensors only when there are requests which
    # need penalties/step_pooler to be applied.
    if (skip_penalties
            and not self.logits_processing_needs_token_ids[:count].any()):
        prompt_token_ids = None
    else:
        prompt_token_ids = self._make_prompt_token_ids_tensor()

    allowed_token_ids_mask: Optional[torch.Tensor] = None
    if not self.no_allowed_token_ids:
        assert self.allowed_token_ids_mask is not None
        copy_slice(self.allowed_token_ids_mask_cpu_tensor,
                   self.allowed_token_ids_mask, count)
        allowed_token_ids_mask = self.allowed_token_ids_mask[:count]

    return SamplingMetadata(
        temperature=temperature,
        all_greedy=greedy_only,
        all_random=self.all_random,
        top_p=None if skip_top_p else self.top_p[:count],
        top_k=None if skip_top_k else self.top_k[:count],
        generators=self.generators,
        max_num_logprobs=self.max_num_logprobs,
        prompt_token_ids=prompt_token_ids,
        frequency_penalties=self.frequency_penalties[:count],
        presence_penalties=self.presence_penalties[:count],
        repetition_penalties=self.repetition_penalties[:count],
        output_token_ids=cast(list[list[int]], self.req_output_token_ids),
        no_penalties=skip_penalties,
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=self.bad_words_token_ids,
        logitsprocs=self.logitsprocs,
    )
@property
def pooling_metadata(self) -> PoolingMetadata:
    """Build a ``PoolingMetadata`` for the requests currently in the batch.

    Pooling params are listed in ``req_ids`` order, or as an empty list
    when no pooling params are registered.
    """
    registered = self.pooling_params
    if registered:
        # Note, for now this assumes that all request in the batch
        # are either sampling or pooling requests
        assert len(self.req_ids) == len(registered)
        ordered_params = [registered[req_id] for req_id in self.req_ids]
    else:
        ordered_params = []

    prompt_lens = torch.from_numpy(self.num_prompt_tokens[:self.num_reqs])
    return PoolingMetadata(
        prompt_lens=prompt_lens,
        prompt_token_ids=self.sampling_metadata.prompt_token_ids,
        pooling_params=ordered_params,
    )
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
prompt_token_ids_cpu_tensor = torch.empty(
(self.num_reqs, max_prompt_len),
device="cpu",
dtype=torch.int64,
pin_memory=self.pin_memory,
)
prompt_token_ids = prompt_token_ids_cpu_tensor.numpy()
prompt_token_ids[:] = self.token_ids_cpu[:self.
num_reqs, :max_prompt_len]
# Use the value of vocab_size as a pad since we don't have a
# token_id of this value.
for i in range(self.num_reqs):
prompt_token_ids[i, self.num_prompt_tokens[i]:] = self.vocab_size
return prompt_token_ids_cpu_tensor.to(device=self.device,
non_blocking=True)
def make_lora_inputs(
self, num_scheduled_tokens: np.ndarray
) -> tuple[tuple[int, ...], tuple[int, ...], set[LoRARequest]]:
"""
Given the num_scheduled_tokens for each request in the batch, return
datastructures used to activate the current LoRAs.
Returns:
1. prompt_lora_mapping: A tuple of size self.num_reqs where,
prompt_lora_mapping[i] is the LoRA id to use for the ith prompt.
2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens)
where, token_lora_mapping[i] is the LoRA id to use for ith token.
3. lora_requests: Set of relevant LoRA requests.
"""
req_lora_mapping = self.request_lora_mapping[:self.num_reqs]
prompt_lora_mapping = tuple(req_lora_mapping)
token_lora_mapping = tuple(
req_lora_mapping.repeat(num_scheduled_tokens))
active_lora_requests: set[LoRARequest] = set(
self.lora_id_to_lora_request.values())
return prompt_lora_mapping, token_lora_mapping, active_lora_requests
@property
def num_reqs(self) -> int:
return len(self.req_id_to_index)
@property
def all_greedy(self) -> bool:
return len(self.random_reqs) == 0
@property
def all_random(self) -> bool:
return len(self.greedy_reqs) == 0
@property
def no_top_p(self) -> bool:
return len(self.top_p_reqs) == 0
@property
def no_top_k(self) -> bool:
return len(self.top_k_reqs) == 0
@property
def no_penalties(self) -> bool:
return (len(self.presence_penalties_reqs) == 0
and len(self.frequency_penalties_reqs) == 0
and len(self.repetition_penalties_reqs) == 0)
@property
def max_num_logprobs(self) -> Optional[int]:
return max(self.num_logprobs.values()) if self.num_logprobs else None
@property
def no_prompt_logprob(self) -> bool:
return not self.num_prompt_logprobs
@property
def no_allowed_token_ids(self) -> bool:
return len(self.has_allowed_token_ids) == 0