import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)

import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated

from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams

logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)


class LLM:

    DEPRECATE_LEGACY: ClassVar[bool] = True

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate prompts against their per-request parameters and
        submit them to the engine, in bulk where the platform supports it."""
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output.
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

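        # Two submission paths follow: platforms whose `supports_v1()` check
        # passes get the whole batch in a single EngineCore call (the ADD_BULK
        # path this patch adds), while everything else falls back to
        # submitting one request at a time via `_add_request()`.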
        model_config = self.llm_engine.model_config
        if (hasattr(current_platform, 'supports_v1')
                and current_platform.supports_v1(model_config)):
            batch_items = []
            for i, prompt in enumerate(it):
                request_id = str(next(self.request_counter))
                param = params[i] if isinstance(params, Sequence) else params
                tokenization_kwargs: dict[str, Any] = {}
                _validate_truncation_size(model_config.max_model_len,
                                          param.truncate_prompt_tokens,
                                          tokenization_kwargs)

                batch_items.append((
                    request_id,
                    prompt,
                    param,
                    None,  # arrival_time; pass None when it is not needed
                    (lora_request[i] if isinstance(lora_request, Sequence)
                     else lora_request),
                    tokenization_kwargs,
                    None,  # trace_headers; None when no APM/tracing is set up
                    (priority[i] if priority else 0),
                ))
            # Hand the whole batch to EngineCore in a single call
            # (the ADD_BULK path) instead of one call per request.
            self.llm_engine.add_requests(batch_items)
        else:
            for i, prompt in enumerate(it):
                param = params[i] if isinstance(params, Sequence) else params
                tokenization_kwargs: dict[str, Any] = {}
                _validate_truncation_size(model_config.max_model_len,
                                          param.truncate_prompt_tokens,
                                          tokenization_kwargs)
                self._add_request(
                    prompt,
                    param,
                    tokenization_kwargs=tokenization_kwargs,
                    lora_request=lora_request[i] if isinstance(
                        lora_request, Sequence) else lora_request,
                    priority=priority[i] if priority else 0,
                )
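

# Usage sketch (illustrative, not part of the original file): assuming the
# surrounding class is the complete `vllm.entrypoints.llm.LLM`, the helper
# above is reached through the public `generate()`/`encode()` entry points,
# which exercise the bulk path on v1-capable platforms and the per-request
# path otherwise. Model name and sampling values are arbitrary examples.
if __name__ == "__main__":
    llm = LLM(model="facebook/opt-125m")
    sampling = SamplingParams(temperature=0.8, max_tokens=32)
    outputs = llm.generate(
        ["Hello, my name is", "The capital of France is"],
        sampling,
    )
    for output in outputs:
        print(output.outputs[0].text)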