# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING

from typing_extensions import deprecated

from vllm._bc_linter import bc_linter_include

if TYPE_CHECKING:
    import numpy as np
    import numpy.typing as npt
    import torch

    from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
    from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
    from vllm.lora.request import LoRARequest
    from vllm.multimodal.inputs import MultiModalFeatureSpec
    from vllm.pooling_params import PoolingParams
    from vllm.sampling_params import SamplingParams
    from vllm.v1.request import Request
else:
    # Runtime placeholders: these names are only needed for type annotations,
    # so we avoid importing the heavyweight modules above outside of type
    # checking.
    ECConnectorMetadata = object
    KVConnectorMetadata = object
    LoRARequest = object
    MultiModalFeatureSpec = object
    PoolingParams = object
    SamplingParams = object
    Request = object


@bc_linter_include
@dataclass
class NewRequestData:
    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    lora_request: LoRARequest | None
    prompt_embeds: "torch.Tensor | None" = None

    # Only used for the v2 model runner.
    prefill_token_ids: list[int] | None = None

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
        prefill_token_ids: list[int] | None = None,
    ) -> "NewRequestData":
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            block_ids=block_ids,
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
            prompt_embeds=request.prompt_embeds,
            prefill_token_ids=prefill_token_ids,
        )
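
    # A minimal usage sketch (hypothetical values; in the scheduler, `request`
    # is a vllm.v1.request.Request and `block_ids` comes from the KV cache
    # manager, one list per KV cache group):
    #
    #   new_req = NewRequestData.from_request(request, block_ids=([0, 1, 2],))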

    def __repr__(self) -> str:
        prompt_embeds_shape = (
            self.prompt_embeds.shape if self.prompt_embeds is not None else None
        )
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids={self.prompt_token_ids},"
            f"prefill_token_ids={self.prefill_token_ids},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )

    # Version of __repr__ with the prompt data obfuscated, suitable for
    # logs that must not contain prompt contents.
    def anon_repr(self) -> str:
        prompt_token_ids_len = (
            len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
        )
        prompt_embeds_shape = (
            self.prompt_embeds.shape if self.prompt_embeds is not None else None
        )
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids_len={prompt_token_ids_len},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )
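
    # Logging sketch (hypothetical caller code): emit anon_repr() instead of
    # repr() wherever prompts must stay out of the logs:
    #
    #   logger.debug("Scheduling new request: %s", new_req.anon_repr())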


@bc_linter_include
@dataclass
class CachedRequestData:
    req_ids: list[str]
    # For request ids not in resumed_req_ids, new_block_ids will be appended
    # to the request's existing block IDs. For request ids in the set,
    # new_block_ids will replace the request's block IDs outright.
    resumed_req_ids: set[str]
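    # Illustration (hypothetical values): suppose req_ids == ["A", "B"],
    # resumed_req_ids == {"B"}, and new_block_ids == [([3],), ([7, 8],)].
    # The worker appends block 3 to "A"'s existing block table, but resets
    # "B"'s block table to exactly [7, 8] (e.g. after "B" was preempted and
    # its old blocks were freed).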
    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
    # When PP is not used, new_token_ids will be empty.
    new_token_ids: list[list[int]]
    # For requests that were not scheduled in the previous step, all of their
    # token ids are propagated to the connector here; requests that were
    # scheduled in the previous step are not included.
    all_token_ids: dict[str, list[int]]
    new_block_ids: list[tuple[list[int], ...] | None]
    num_computed_tokens: list[int]
    num_output_tokens: list[int]
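
    # Note: req_ids, new_token_ids, new_block_ids, num_computed_tokens, and
    # num_output_tokens are parallel lists: index i of each describes the
    # request req_ids[i].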

    @property
    def num_reqs(self) -> int:
        return len(self.req_ids)

    @cached_property
    @deprecated("This will be removed in v0.14, use `resumed_req_ids` instead.")
    def resumed_from_preemption(self) -> list[bool]:
        return [req_id in self.resumed_req_ids for req_id in self.req_ids]

    @cached_property
    @deprecated("This will be removed in v0.14, use `all_token_ids` instead.")
    def resumed_req_token_ids(self) -> list[list[int] | None]:
        return [
            self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
            for req_id in self.req_ids
        ]
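
    # Migration sketch (hypothetical caller code): instead of the deprecated
    # parallel list
    #
    #   for req_id, resumed in zip(data.req_ids, data.resumed_from_preemption):
    #       ...
    #
    # test membership in the set directly:
    #
    #   for req_id in data.req_ids:
    #       resumed = req_id in data.resumed_req_ids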

    @classmethod
    def make_empty(cls) -> "CachedRequestData":
        return cls(
            req_ids=[],
            resumed_req_ids=set(),
            new_token_ids=[],
            all_token_ids={},
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        )


@bc_linter_include
@dataclass
class SchedulerOutput:
    # List of the requests that are scheduled for the first time.
    # We cache each request's data in the worker processes, so that we don't
    # need to re-send it on every scheduling step.
    scheduled_new_reqs: list[NewRequestData]
    # Data for the requests that have been scheduled before.
    # Since these requests' data is already cached in the worker processes,
    # we only send the diff to minimize the communication cost.
    scheduled_cached_reqs: CachedRequestData

    # req_id -> num_scheduled_tokens
    # Number of tokens scheduled for each request.
    num_scheduled_tokens: dict[str, int]
    # Total number of tokens scheduled for all requests.
    # Equal to sum(num_scheduled_tokens.values())
    total_num_scheduled_tokens: int
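    # Example (hypothetical numbers): num_scheduled_tokens == {"A": 512, "B": 1}
    # means request "A" runs a 512-token prefill chunk while "B" decodes a
    # single token, giving total_num_scheduled_tokens == 513.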
    # req_id -> spec_token_ids
    # If a request does not have any spec decode tokens, it will not be
    # included in the dictionary.
    scheduled_spec_decode_tokens: dict[str, list[int]]
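    # Example (hypothetical): {"A": [17, 42]} schedules two draft tokens for
    # request "A" to be verified in this step; requests without draft tokens
    # simply do not appear.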
    # req_id -> encoder input indices that need processing.
    # E.g., if a request has [0, 1], it could mean the vision encoder needs
    # to process the request's 0th and 1st images in the current step.
    scheduled_encoder_inputs: dict[str, list[int]]
    # Number of common prefix blocks for all requests in each KV cache group.
    # This can be used for cascade attention.
    num_common_prefix_blocks: list[int]
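    # Example (hypothetical): with a single KV cache group and block size 16,
    # num_common_prefix_blocks == [4] means every scheduled request shares the
    # same first 4 KV cache blocks (64 tokens), which cascade attention can
    # attend to once for the whole batch.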

    # Request IDs that finished between the previous and the current steps.
    # This is used to notify the workers about the finished requests so that
    # they can free the cached states for those requests.
    finished_req_ids: set[str]
    # List of mm_hash strings associated with the encoder outputs to be
    # freed from the encoder cache.
    free_encoder_mm_hashes: list[str]

    # Request IDs that are preempted in this step.
    # Only used for the v2 model runner.
    preempted_req_ids: set[str] | None = None

    # Whether any scheduled request is still missing output tokens that are
    # needed before its grammar bitmask can be computed.
    pending_structured_output_tokens: bool = False

    # KV Cache Connector metadata.
    kv_connector_metadata: KVConnectorMetadata | None = None

    # EC Cache Connector metadata.
    ec_connector_metadata: ECConnectorMetadata | None = None

    @classmethod
    def make_empty(cls) -> "SchedulerOutput":
        return cls(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=CachedRequestData.make_empty(),
            num_scheduled_tokens={},
            total_num_scheduled_tokens=0,
            scheduled_spec_decode_tokens={},
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=[],
            finished_req_ids=set(),
            free_encoder_mm_hashes=[],
        )
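
    # Usage sketch (hypothetical): a step in which nothing can be scheduled
    # still sends a well-formed, empty output to the workers:
    #
    #   scheduler_output = SchedulerOutput.make_empty()
    #   assert scheduler_output.total_num_scheduled_tokens == 0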


@dataclass
class GrammarOutput:
    # IDs of the structured output requests.
    structured_output_request_ids: list[str]
    # Bitmask rows ordered as structured_output_request_ids.
    grammar_bitmask: "npt.NDArray[np.int32]"
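    # Shape note (an assumption based on typical xgrammar-style bit packing,
    # not guaranteed by this file): grammar_bitmask is usually of shape
    # (len(structured_output_request_ids), ceil(vocab_size / 32)), where row i
    # packs the allowed-token bitmask for structured_output_request_ids[i].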