enginex-vastai-va16-vllm/vllm_vacc/vllm/entrypoints/llm.py
import itertools
import warnings
from collections.abc import Sequence
from contextlib import contextmanager
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)

import cloudpickle
import torch.nn as nn
from pydantic import ValidationError
from tqdm.auto import tqdm
from typing_extensions import TypeVar, deprecated

from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import RequestOutputKind, SamplingParams

logger = init_logger(__name__)

_R = TypeVar("_R", default=Any)

class LLM:
    DEPRECATE_LEGACY: ClassVar[bool] = True
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output.
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")
        # V1 bulk path: on platforms that support the V1 engine, collect all
        # requests first and hand them to EngineCore in a single ADD_BULK
        # call rather than one round trip per request.
        model_config = self.llm_engine.model_config
        if (hasattr(current_platform, 'supports_v1')
                and current_platform.supports_v1(model_config)):
            batch_items = []
            for i, prompt in enumerate(it):
                request_id = str(next(self.request_counter))
                param = params[i] if isinstance(params, Sequence) else params
                tokenization_kwargs: dict[str, Any] = {}
                _validate_truncation_size(model_config.max_model_len,
                                          param.truncate_prompt_tokens,
                                          tokenization_kwargs)
                batch_items.append((
                    request_id,
                    prompt,
                    param,
                    None,  # arrival_time: pass None if unused
                    (lora_request[i] if isinstance(lora_request, Sequence)
                     else lora_request),
                    tokenization_kwargs,
                    None,  # trace_headers: None when no APM/tracing is attached
                    (priority[i] if priority else 0),
                ))
            # Dispatch the whole batch to EngineCore at once via ADD_BULK.
            self.llm_engine.add_requests(batch_items)
        else:
            for i, prompt in enumerate(it):
                param = params[i] if isinstance(params, Sequence) else params
                tokenization_kwargs: dict[str, Any] = {}
                _validate_truncation_size(model_config.max_model_len,
                                          param.truncate_prompt_tokens,
                                          tokenization_kwargs)
                self._add_request(
                    prompt,
                    param,
                    tokenization_kwargs=tokenization_kwargs,
                    lora_request=lora_request[i] if isinstance(
                        lora_request, Sequence) else lora_request,
                    priority=priority[i] if priority else 0,
                )
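
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). It assumes the
# full upstream LLM class, where generate() routes prompts and sampling params
# through _validate_and_add_requests(); the model name below is a placeholder.
if __name__ == "__main__":
    from vllm import SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    sampling = SamplingParams(temperature=0.8, max_tokens=32)
    # On platforms whose supports_v1() returns True, both prompts are queued
    # as a single ADD_BULK batch; otherwise they are added one at a time.
    outputs = llm.generate(
        ["Hello, my name is", "The capital of France is"], sampling)
    for out in outputs:
        print(out.outputs[0].text)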