import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)

import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated

from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams

logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)


class LLM:

    DEPRECATE_LEGACY: ClassVar[bool] = True
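    """A flag to toggle whether to deprecate the legacy `generate`/`encode` API."""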

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
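        """Validate prompts and parameters, then submit them to the engine.

        ``params``, ``lora_request``, and ``priority`` may each be a single
        value applied to every prompt, or a per-prompt sequence with the same
        length as ``prompts``.
        """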
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output.
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        # V1-capable platforms take the whole batch in one bulk submission;
        # others fall back to adding requests one at a time. Note that
        # ``Platform.supports_v1`` takes the model config as its argument.
        if (hasattr(current_platform, 'supports_v1')
                and current_platform.supports_v1(model_config)):
            batch_items = []
            for i, prompt in enumerate(it):
                request_id = str(next(self.request_counter))
                param = params[i] if isinstance(params, Sequence) else params

                tokenization_kwargs: dict[str, Any] = {}
                _validate_truncation_size(model_config.max_model_len,
                                          param.truncate_prompt_tokens,
                                          tokenization_kwargs)

                batch_items.append((
                    request_id,
                    prompt,
                    param,
                    None,  # arrival_time; pass None when it is not needed
                    (lora_request[i] if isinstance(lora_request, Sequence)
                     else lora_request),
                    tokenization_kwargs,
                    None,  # trace_headers; None when no APM/tracing is used
                    (priority[i] if priority else 0),
                ))

            # Submit the whole batch to EngineCore at once via ADD_BULK.
            self.llm_engine.add_requests(batch_items)
        else:
            for i, prompt in enumerate(it):
                param = params[i] if isinstance(params, Sequence) else params

                # Validate truncation settings and collect per-request
                # tokenizer kwargs for this prompt.
                tokenization_kwargs = {}
                _validate_truncation_size(model_config.max_model_len,
                                          param.truncate_prompt_tokens,
                                          tokenization_kwargs)

                self._add_request(
                    prompt,
                    param,
                    tokenization_kwargs=tokenization_kwargs,
                    lora_request=lora_request[i] if isinstance(
                        lora_request, Sequence) else lora_request,
                    priority=priority[i] if priority else 0,
                )
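
# Minimal usage sketch (an assumption for illustration: the public
# ``LLM.generate`` entrypoint funnels prompts through
# ``_validate_and_add_requests``; the model name is only a placeholder):
#
#     llm = LLM(model="facebook/opt-125m")
#     outputs = llm.generate(
#         ["Hello, my name is", "The future of AI is"],
#         SamplingParams(temperature=0.8, max_tokens=32),
#     )
#     for out in outputs:
#         print(out.outputs[0].text)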