Fix some ci issue and refactor modelrunner (#2445)
### What this PR does / why we need it?
Fix some ci issue and refactor modelrunner
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
CI passed with existing test.
- vLLM version: v0.10.0
- vLLM main:
4d9c61993a
---------
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Signed-off-by: weiguihua2 <weiguihua2@huawei.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
Co-authored-by: weiguihua2 <weiguihua2@huawei.com>
This commit is contained in:
@@ -22,28 +22,30 @@ from typing import Optional, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import deprecated
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.inputs import (MultiModalKwargs, MultiModalKwargsItem,
|
||||
PlaceholderRange)
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import init_builtin_logitsprocs
|
||||
from vllm.v1.sample.logits_processor import (BatchUpdateBuilder,
|
||||
LogitsProcessors,
|
||||
MoveDirectionality)
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
_SAMPLING_EPS = 1e-5
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
|
||||
req_id: str
|
||||
prompt_token_ids: list[int]
|
||||
mm_kwargs: list[MultiModalKwargs]
|
||||
mm_kwargs: list[MultiModalKwargsItem]
|
||||
mm_positions: list[PlaceholderRange]
|
||||
sampling_params: Optional[SamplingParams]
|
||||
pooling_params: Optional[PoolingParams]
|
||||
@@ -65,6 +67,13 @@ class CachedRequestState:
|
||||
def num_tokens(self) -> int:
|
||||
return self.num_prompt_tokens + len(self.output_token_ids)
|
||||
|
||||
# Temporary back-compatibility for plugins that define model runner
|
||||
@property
|
||||
@deprecated("`mm_inputs` is superseded by `mm_kwargs` and will be "
|
||||
"removed in v0.13. Please use `mm_kwargs` instead.")
|
||||
def mm_inputs(self) -> list[MultiModalKwargs]:
|
||||
return [MultiModalKwargs([item]) for item in self.mm_kwargs]
|
||||
|
||||
def get_token_id(self, idx: int) -> int:
|
||||
if idx < self.num_prompt_tokens:
|
||||
return self.prompt_token_ids[idx]
|
||||
@@ -83,8 +92,11 @@ class InputBatch:
|
||||
pin_memory: bool,
|
||||
vocab_size: int,
|
||||
block_sizes: list[int], # The block_size of each kv cache group
|
||||
logitsprocs: Optional[LogitsProcessors] = None,
|
||||
is_spec_decode: bool = False,
|
||||
is_pooling_model: bool = False,
|
||||
):
|
||||
self.is_pooling_model = is_pooling_model
|
||||
self.is_spec_decode = is_spec_decode
|
||||
self.max_num_reqs = max_num_reqs
|
||||
self.max_model_len = max_model_len
|
||||
@@ -164,16 +176,6 @@ class InputBatch:
|
||||
# IDs of requests which do not support spec decoding
|
||||
self.spec_decode_unsupported_reqs: set[str] = set()
|
||||
|
||||
self.min_p = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device=device)
|
||||
self.min_p_cpu_tensor = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float32,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory)
|
||||
self.min_p_cpu = self.min_p_cpu_tensor.numpy()
|
||||
self.min_p_reqs: set[str] = set()
|
||||
|
||||
# Frequency penalty related data structures
|
||||
self.frequency_penalties = torch.empty((max_num_reqs, ),
|
||||
dtype=torch.float,
|
||||
@@ -212,9 +214,6 @@ class InputBatch:
|
||||
self.repetition_penalties_cpu_tensor.numpy()
|
||||
self.repetition_penalties_reqs: set[str] = set()
|
||||
|
||||
# req_index -> (min_tokens, stop_token_ids)
|
||||
self.min_tokens: dict[int, tuple[int, set[int]]] = {}
|
||||
|
||||
# lora related
|
||||
self.request_lora_mapping = np.zeros((self.max_num_reqs, ),
|
||||
dtype=np.int32)
|
||||
@@ -234,8 +233,12 @@ class InputBatch:
|
||||
# To accumulate prompt logprobs tensor chunks across prefill steps.
|
||||
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
|
||||
|
||||
self.logit_bias: list[Optional[dict[int,
|
||||
float]]] = [None] * max_num_reqs
|
||||
# Internal representation of per-step batch state changes, used for
|
||||
# reordering persistent batch and generating logitsprocs batch state
|
||||
# updates. Should reset each step.
|
||||
self.batch_update_builder = BatchUpdateBuilder()
|
||||
|
||||
# TODO convert this to LogitsProcessor
|
||||
self.has_allowed_token_ids: set[str] = set()
|
||||
# NOTE(lufang): In the mask tensor, if the corresponding token allowed,
|
||||
# the value is False. Since we use masked_fill_ to set -inf.
|
||||
@@ -244,18 +247,15 @@ class InputBatch:
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.logits_processing_needs_token_ids = np.zeros(max_num_reqs,
|
||||
dtype=bool)
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
# Define logits processors.
|
||||
# TODO(andy): logits processor list should be extensible via engine
|
||||
# constructor argument; for now the list is fixed.
|
||||
self.logitsprocs = init_builtin_logitsprocs(
|
||||
pin_memory_available=pin_memory,
|
||||
max_num_reqs=max_num_reqs + 1,
|
||||
device=device)
|
||||
# Store provided logitsprocs. If none are provided, initialize empty
|
||||
# data structure
|
||||
self.logitsprocs = logitsprocs or LogitsProcessors()
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
@@ -268,14 +268,35 @@ class InputBatch:
|
||||
# while performing state updates to the batch.
|
||||
return cast(list[str], self._req_ids)
|
||||
|
||||
def _register_add_request(self, request: "CachedRequestState") -> int:
|
||||
"""Track add-request operations for logits processors.
|
||||
Not applicable to pooling models.
|
||||
"""
|
||||
|
||||
# Detailed added request metadata is only required for non-pooling
|
||||
# models, to support logitsprocs
|
||||
assert request.sampling_params
|
||||
|
||||
# Fill the next empty index if there is one.
|
||||
if (new_req_index := self.batch_update_builder.pop_removed()) is None:
|
||||
# Append to end otherwise.
|
||||
new_req_index = self.num_reqs
|
||||
|
||||
assert new_req_index < self.max_num_reqs
|
||||
self.batch_update_builder.added.append(
|
||||
(new_req_index, request.sampling_params, request.prompt_token_ids,
|
||||
request.output_token_ids))
|
||||
return new_req_index
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
request: "CachedRequestState",
|
||||
req_index: Optional[int] = None,
|
||||
) -> None:
|
||||
if req_index is None:
|
||||
) -> int:
|
||||
if not self.is_pooling_model:
|
||||
# New request index bookkeeping for autoregressive models.
|
||||
req_index = self._register_add_request(request)
|
||||
else:
|
||||
req_index = self.num_reqs
|
||||
assert req_index < self.max_num_reqs
|
||||
|
||||
req_id = request.req_id
|
||||
if req_index == len(self._req_ids):
|
||||
@@ -306,8 +327,8 @@ class InputBatch:
|
||||
self.block_table.add_row(request.block_ids, req_index)
|
||||
|
||||
if sampling_params := request.sampling_params:
|
||||
if self.is_spec_decode and is_spec_decode_unsupported(
|
||||
sampling_params):
|
||||
if (self.is_spec_decode
|
||||
and is_spec_decode_unsupported(sampling_params)):
|
||||
self.spec_decode_unsupported_reqs.add(req_id)
|
||||
if sampling_params.sampling_type == SamplingType.GREEDY:
|
||||
# Avoid later division by zero.
|
||||
@@ -326,11 +347,8 @@ class InputBatch:
|
||||
else:
|
||||
top_k = self.vocab_size
|
||||
self.top_k_cpu[req_index] = top_k
|
||||
self.min_p_cpu[req_index] = sampling_params.min_p
|
||||
self.frequency_penalties_cpu[
|
||||
req_index] = sampling_params.frequency_penalty
|
||||
if sampling_params.min_p > _SAMPLING_EPS:
|
||||
self.min_p_reqs.add(req_id)
|
||||
if sampling_params.frequency_penalty != 0.0:
|
||||
self.frequency_penalties_reqs.add(req_id)
|
||||
self.presence_penalties_cpu[
|
||||
@@ -341,10 +359,6 @@ class InputBatch:
|
||||
req_index] = sampling_params.repetition_penalty
|
||||
if sampling_params.repetition_penalty != 1.0:
|
||||
self.repetition_penalties_reqs.add(req_id)
|
||||
if sampling_params.min_tokens:
|
||||
self.min_tokens[req_index] = (
|
||||
sampling_params.min_tokens,
|
||||
sampling_params.all_stop_token_ids)
|
||||
|
||||
# NOTE(woosuk): self.generators should not include the requests that
|
||||
# do not have their own generator.
|
||||
@@ -352,12 +366,12 @@ class InputBatch:
|
||||
self.generators[req_index] = request.generator
|
||||
|
||||
if sampling_params.logprobs is not None:
|
||||
self.num_logprobs[req_id] = sampling_params.logprobs
|
||||
self.num_logprobs[req_id] = (self.vocab_size
|
||||
if sampling_params.logprobs == -1
|
||||
else sampling_params.logprobs)
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
self.num_prompt_logprobs[
|
||||
req_id] = sampling_params.prompt_logprobs
|
||||
if sampling_params.logit_bias is not None:
|
||||
self.logit_bias[req_index] = sampling_params.logit_bias
|
||||
|
||||
if sampling_params.allowed_token_ids:
|
||||
self.has_allowed_token_ids.add(req_id)
|
||||
@@ -402,12 +416,25 @@ class InputBatch:
|
||||
# No LoRA
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
return req_index
|
||||
|
||||
def remove_request(self, req_id: str) -> Optional[int]:
|
||||
"""This method must always be followed by a call to condense()."""
|
||||
"""This method must always be followed by a call to condense().
|
||||
|
||||
Args:
|
||||
req_id: request to remove
|
||||
|
||||
Returns:
|
||||
Removed request index, or `None` if `req_id` not recognized
|
||||
"""
|
||||
|
||||
req_index = self.req_id_to_index.pop(req_id, None)
|
||||
if req_index is None:
|
||||
return None
|
||||
if not self.is_pooling_model:
|
||||
# Autoregressive models require bookkeeping of removed requests to
|
||||
# support logitsprocs.
|
||||
self.batch_update_builder.removed_append(req_index)
|
||||
self._req_ids[req_index] = None
|
||||
self.req_output_token_ids[req_index] = None
|
||||
|
||||
@@ -415,12 +442,10 @@ class InputBatch:
|
||||
self.random_reqs.discard(req_id)
|
||||
self.top_p_reqs.discard(req_id)
|
||||
self.top_k_reqs.discard(req_id)
|
||||
self.min_p_reqs.discard(req_id)
|
||||
self.min_tokens.pop(req_index, None)
|
||||
self.spec_decode_unsupported_reqs.discard(req_id)
|
||||
self.frequency_penalties_reqs.discard(req_id)
|
||||
self.presence_penalties_reqs.discard(req_id)
|
||||
self.repetition_penalties_reqs.discard(req_id)
|
||||
self.spec_decode_unsupported_reqs.discard(req_id)
|
||||
self.generators.pop(req_index, None)
|
||||
self.num_logprobs.pop(req_id, None)
|
||||
self.num_prompt_logprobs.pop(req_id, None)
|
||||
@@ -435,7 +460,6 @@ class InputBatch:
|
||||
self.lora_id_to_lora_request.pop(lora_id)
|
||||
self.request_lora_mapping[req_index] = 0
|
||||
|
||||
self.logit_bias[req_index] = None
|
||||
self.has_allowed_token_ids.discard(req_id)
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
@@ -445,6 +469,10 @@ class InputBatch:
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
# For autoregressive models, track detailed request reordering info
|
||||
# to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(i1, i2, MoveDirectionality.SWAP))
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
@@ -474,8 +502,6 @@ class InputBatch:
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
@@ -487,13 +513,10 @@ class InputBatch:
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
@@ -502,13 +525,31 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
"""Move non-empty requests down into lower, empty indices.
|
||||
|
||||
def condense(self) -> None:
|
||||
"""Slide non-empty requests down into lower, empty indices.
|
||||
|
||||
Any consecutive empty indices at the very end of the list are not
|
||||
filled.
|
||||
|
||||
Args:
|
||||
empty_req_indices: empty batch indices, sorted descending.
|
||||
empty_req_indices: empty indices which may be filled.
|
||||
|
||||
Returns:
|
||||
swaps: list of (from,to) swap tuples for moved requests
|
||||
empty_req_indices: indices not filled by condensation
|
||||
"""
|
||||
num_reqs = self.num_reqs
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Will be contiguous in pooling case, just trim the lists.
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
return
|
||||
|
||||
if not (empty_req_indices := self.batch_update_builder.removed):
|
||||
# All removed requests were replaced by added requests, or else no
|
||||
# requests were removed at all. No condense() needed
|
||||
return
|
||||
if num_reqs == 0:
|
||||
# The batched states are empty.
|
||||
self._req_ids.clear()
|
||||
@@ -524,11 +565,19 @@ class InputBatch:
|
||||
last_req_index -= 1
|
||||
|
||||
# Find the smallest empty index.
|
||||
empty_index = empty_req_indices.pop()
|
||||
empty_index = self.batch_update_builder.peek_removed()
|
||||
assert empty_index is not None
|
||||
if empty_index >= last_req_index:
|
||||
break
|
||||
|
||||
# Swap the states.
|
||||
# Move active request down into empty request
|
||||
# index.
|
||||
self.batch_update_builder.pop_removed()
|
||||
# Autoregressive models require detailed tracking of condense
|
||||
# operations to support logitsprocs
|
||||
self.batch_update_builder.moved.append(
|
||||
(last_req_index, empty_index,
|
||||
MoveDirectionality.UNIDIRECTIONAL))
|
||||
req_id = self._req_ids[last_req_index]
|
||||
output_token_ids = self.req_output_token_ids[last_req_index]
|
||||
assert req_id is not None
|
||||
@@ -559,20 +608,14 @@ class InputBatch:
|
||||
empty_index] = self.presence_penalties_cpu[last_req_index]
|
||||
self.repetition_penalties_cpu[
|
||||
empty_index] = self.repetition_penalties_cpu[last_req_index]
|
||||
self.min_p_cpu[empty_index] = self.min_p_cpu[last_req_index]
|
||||
generator = self.generators.pop(last_req_index, None)
|
||||
if generator is not None:
|
||||
self.generators[empty_index] = generator
|
||||
|
||||
min_token = self.min_tokens.pop(last_req_index, None)
|
||||
if min_token is not None:
|
||||
self.min_tokens[empty_index] = min_token
|
||||
|
||||
self.request_lora_mapping[empty_index] = self.request_lora_mapping[
|
||||
last_req_index]
|
||||
|
||||
self.logit_bias[empty_index] = self.logit_bias[last_req_index]
|
||||
|
||||
# TODO convert these to LogitsProcessors
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[
|
||||
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||
@@ -582,15 +625,30 @@ class InputBatch:
|
||||
last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
# Trim lists to the batch size.
|
||||
del self._req_ids[self.num_reqs:]
|
||||
del self.req_output_token_ids[self.num_reqs:]
|
||||
del self._req_ids[num_reqs:]
|
||||
del self.req_output_token_ids[num_reqs:]
|
||||
|
||||
def refresh_sampling_metadata(self):
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
def refresh_metadata(self):
|
||||
"""Apply any batch updates to sampling metadata."""
|
||||
|
||||
if self.is_pooling_model:
|
||||
# Batch changes every step for pooling models.
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
return
|
||||
|
||||
# For non-pooling models - generate and apply logitsprocs update;
|
||||
# reset batch update tracking.
|
||||
# Update sampling metadata if batch state is changed.
|
||||
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
|
||||
for logit_proc in self.logitsprocs.all:
|
||||
logit_proc.update_state(batch_update)
|
||||
if batch_update:
|
||||
self.sampling_metadata = self._make_sampling_metadata()
|
||||
|
||||
def _make_sampling_metadata(self) -> SamplingMetadata:
|
||||
num_reqs = self.num_reqs
|
||||
@@ -603,8 +661,6 @@ class InputBatch:
|
||||
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
|
||||
if not self.no_top_k:
|
||||
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
|
||||
if not self.no_min_p:
|
||||
copy_slice(self.min_p_cpu_tensor, self.min_p, num_reqs)
|
||||
|
||||
if not self.no_penalties:
|
||||
# Since syncing these tensors is expensive only copy them
|
||||
@@ -735,10 +791,6 @@ class InputBatch:
|
||||
def no_top_k(self) -> bool:
|
||||
return len(self.top_k_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_min_p(self) -> bool:
|
||||
return len(self.min_p_reqs) == 0
|
||||
|
||||
@property
|
||||
def no_penalties(self) -> bool:
|
||||
return (len(self.presence_penalties_reqs) == 0
|
||||
|
||||
Reference in New Issue
Block a user