# enginex-vastai-va16-vllm/vllm_vacc/vllm/entrypoints/openai/serving_engine.py
# SPDX-License-Identifier: Apache-2.0
import json
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)
from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
import torch
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
from vllm.v1.sample.sampler import _SAMPLING_EPS
import os
from vllm_vacc.vllm.model_executor.models.vars import CUT_PREFILL_SEQ_LEN
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN

logger = init_logger(__name__)
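

# This module overrides OpenAIServing._validate_input and ._log_inputs for the
# vllm_vacc (Vastai VA16) backend:
#   * optionally truncates long prompts head/tail to CUT_PREFILL_SEQ_LEN tokens,
#   * rejects prompts longer than LLM_MAX_PREFILL_SEQ_LEN,
#   * rejects or clamps sampling parameters the backend sampler does not
#     support (repetition_penalty, min_p, prompt_logprobs).


# EmbedsPrompt carries a precomputed prompt-embedding tensor instead of token
# ids; deepstack_input_embeds is an optional dict of extra embeddings (the name
# suggests a Qwen-VL-style "deepstack" feature, but that is an assumption).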
class EmbedsPrompt(TypedDict):
    prompt_embeds: torch.Tensor
    deepstack_input_embeds: Optional[dict]


class OpenAIServing:

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        # Parameters set by the client are validated here; if the client did
        # not set them, they may also be picked up from generation_config.json
        # (those are handled in _log_inputs below).
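        # Example: with CUT_PREFILL_SEQ_LEN = 1000, a 5000-token prompt is cut
        # to its first 500 and last 500 tokens before any further checks.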
        if 0 < CUT_PREFILL_SEQ_LEN < len(input_ids):
            cut_before = CUT_PREFILL_SEQ_LEN // 2
            cut_after = CUT_PREFILL_SEQ_LEN - cut_before
            input_ids = input_ids[:cut_before] + input_ids[-cut_after:]
        token_num = len(input_ids)

        if not self.model_config.pooler_config:
            # Not every request type defines these sampling fields, so read
            # them defensively with getattr.
            repetition_penalty = getattr(request, "repetition_penalty", None)
            if (repetition_penalty is not None
                    and abs(repetition_penalty - 1.0) >= _SAMPLING_EPS):
                raise ValueError(
                    "Unsupported sampler parameter repetition_penalty="
                    f"{repetition_penalty}. Please remove the penalty "
                    "parameter on the client and try again.")
            min_p = getattr(request, "min_p", None)
            if min_p is not None and min_p > _SAMPLING_EPS:
                raise ValueError(f"Unsupported sampler parameter min_p={min_p}.")
            prompt_logprobs = getattr(request, "prompt_logprobs", None)
            if prompt_logprobs is not None:
                raise ValueError("Unsupported sampler parameter "
                                 f"prompt_logprobs={prompt_logprobs}.")
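
        # LLM_MAX_PREFILL_SEQ_LEN (from the vllm_vacc vars module) caps how
        # many prompt tokens the prefill path accepts; presumably a
        # device-side limit, it is enforced here independently of
        # max_model_len (which is checked further below).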
        # model_type = self.model_config.hf_config.model_type
        # if model_type == "deepseek_v3":
        if token_num > LLM_MAX_PREFILL_SEQ_LEN:
            raise ValueError(
                f"This model's maximum input sequence length is "
                f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
                f"{token_num} tokens in the input messages. "
                f"Please reduce the length of the input messages.")

        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):
            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
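
        # Example: with max_model_len = 4096, a 4000-token prompt plus
        # max_tokens = 200 requests 4200 tokens in total, so it is rejected
        # below.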
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _log_inputs(
        self,
        request_id: str,
        inputs,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        # The `if self.request_logger is None: return` early-exit was moved
        # further down (right before request_logger is used) so the parameter
        # clamping below still runs even when request logging is disabled.
        # If self.model_config.pooler_config is set, this is an embedding task
        # rather than a generation task, so skip the sampling-param clamping.
        if self.model_config.pooler_config:
            return

        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        elif isinstance(inputs, dict):
            # TextTokensPrompt / EmbedsPrompt are TypedDicts (plain dicts), so
            # attribute access would silently return None; use key lookup.
            prompt = inputs.get('prompt')
            prompt_token_ids = inputs.get('prompt_token_ids')
            prompt_embeds = inputs.get('prompt_embeds')
        else:
            prompt = getattr(inputs, 'prompt', None)
            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

        # Penalty settings picked up from generation_config.json: if present,
        # warn and override them.
        if (params is not None
                and getattr(params, "repetition_penalty", None) is not None
                and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
            logger.warning(
                "\033[93mWARNING\033[0m: unsupported sampler parameter "
                "repetition_penalty=%s; please set "
                "extra_body={'repetition_penalty': 1.0} on the client. "
                "Forcing repetition_penalty to 1.0.",
                params.repetition_penalty)
            # params.presence_penalty = 0
            # params.frequency_penalty = 0
            params.repetition_penalty = 1.0
        if (hasattr(params, "min_p") and params.min_p is not None
                and params.min_p > _SAMPLING_EPS):
            logger.warning("\033[93mWARNING\033[0m: unsupported sampler "
                           "parameter min_p=%s; forcing min_p to 0.",
                           params.min_p)
            params.min_p = 0
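
        # Note: both clamps above mutate the caller's params object in place,
        # presumably so the adjusted values reach the engine rather than only
        # affecting what is logged.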
        if self.request_logger is None:
            return
        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )
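

# Client-side sketch (not part of this module; assumes the standard `openai`
# Python client talking to this server): penalties can be disabled explicitly
# via extra_body so _validate_input does not reject the request and
# _log_inputs does not have to clamp it, e.g.
#
#   client.chat.completions.create(
#       model="...",
#       messages=[{"role": "user", "content": "hello"}],
#       extra_body={"repetition_penalty": 1.0, "min_p": 0.0},
#   )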