# SPDX-License-Identifier: Apache-2.0
"""vacc-specific overrides for vLLM's OpenAI serving layer.

Adds hardware-specific input validation (prefill length limits, unsupported
sampler features) on top of the stock ``OpenAIServing`` logic.
"""
import json
import os
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)

import torch
from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
from vllm.v1.sample.sampler import _SAMPLING_EPS

logger = init_logger(__name__)

from vllm_vacc.vllm.model_executor.models.vars import (CUT_PREFILL_SEQ_LEN,
                                                       LLM_MAX_PREFILL_SEQ_LEN)


class EmbedsPrompt(TypedDict):
    """Prompt supplied as pre-computed embeddings instead of token ids."""
    # Pre-computed prompt embeddings passed straight to the model.
    prompt_embeds: torch.Tensor
    # Extra per-layer embeddings; None when the model does not use them.
    # NOTE(review): exact schema of this dict is not visible here — confirm
    # against the model implementation that consumes it.
    deepstack_input_embeds: Optional[dict]


class OpenAIServing:

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Validate (and possibly truncate) a tokenized prompt for *request*.

        Enforces the vacc prefill length cap, rejects sampler features the
        backend does not support, and checks the request against the model's
        context length.

        Args:
            request: The incoming OpenAI-style request object.
            input_ids: Tokenized prompt.
            input_text: Original prompt text (returned unchanged).

        Returns:
            A ``TextTokensPrompt`` with the (possibly truncated) token ids.

        Raises:
            ValueError: If an unsupported sampling parameter is set or the
                prompt/completion exceeds a length limit.
        """
        # Client-set cap; if unset it is also read from generation_config.json.
        # When the prompt is longer than the cap, keep the head and the tail
        # and drop the middle (head/tail split of the budget).
        if 0 < CUT_PREFILL_SEQ_LEN < len(input_ids):
            cut_before = CUT_PREFILL_SEQ_LEN // 2
            cut_after = CUT_PREFILL_SEQ_LEN - cut_before
            input_ids = input_ids[:cut_before] + input_ids[-cut_after:]
        token_num = len(input_ids)

        # Sampler-feature checks only apply to generation (non-pooling) models.
        if not self.model_config.pooler_config:
            if (request.repetition_penalty is not None
                    and abs(request.repetition_penalty - 1.0) >= _SAMPLING_EPS):
                raise ValueError(
                    f"unsupport penalty for sampler"
                    f"request.repetition_penalty: {request.repetition_penalty}; "
                    f"Please remove penalty parameter in client and try again."
                )
            if request.min_p is not None and request.min_p > _SAMPLING_EPS:
                raise ValueError(f"unsupport min_p {request.min_p} for sampler")
            if request.prompt_logprobs is not None:
                raise ValueError(
                    f"unsupport prompt_logprobs {request.prompt_logprobs} for sampler")

        # model_type = self.model_config.hf_config.model_type
        # if model_type == "deepseek_v3":
        if token_num > LLM_MAX_PREFILL_SEQ_LEN:
            raise ValueError(
                f"This model's maximum input seq length limit is "
                f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
                f"({token_num} in the input messages, "
                f"Please reduce the length of the input messages.")

        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):
            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens

        if max_tokens is None:
            # No completion budget requested: the prompt alone must leave at
            # least one token of headroom within the context window.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _log_inputs(
        self,
        request_id: str,
        inputs,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Sanitize unsupported sampling params, then log the request inputs.

        Values coming from generation_config.json that the vacc sampler does
        not support (repetition penalty, min_p) are reset in-place with a
        warning instead of failing the request.

        Args:
            request_id: Identifier used when logging.
            inputs: Prompt as str, token-id list, or an object exposing
                ``prompt`` / ``prompt_token_ids`` attributes.
            params: Sampling/pooling/beam-search params; mutated in place.
            lora_request: Optional LoRA adapter info, forwarded to the logger.
        """
        # move to position where before use request_logger
        # if self.request_logger is None:
        #     return
        # A non-None pooler_config means this is an embedding task, not a
        # generation task — nothing to sanitize or log here.
        if self.model_config.pooler_config:
            return

        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        else:
            prompt = getattr(inputs, 'prompt', None)
            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

        # Penalty values read from generation_config: warn and reset if set.
        # Guarded with hasattr/None checks because `params` may be None or a
        # PoolingParams/BeamSearchParams without these attributes (this
        # mirrors the existing guard on `min_p`).
        if params is not None:
            if (hasattr(params, "repetition_penalty")
                    and params.repetition_penalty is not None
                    and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
                logger.warning(
                    "\033[93mWARNING \033[0m"
                    ": Unsupport penalty for sampler"
                    f"params.repetition_penalty: {params.repetition_penalty} and "
                    "Please set attrs: extra_body = {\'repetition_penalty\': 1.0}\n"
                    "Now set: repetition_penalty: 1.0"
                )
                # params.presence_penalty = 0
                # params.frequency_penalty = 0
                params.repetition_penalty = 1
            if (hasattr(params, "min_p") and params.min_p is not None
                    and params.min_p > _SAMPLING_EPS):
                logger.warning(
                    f"\033[93mWARNING \033[0m : unsupport min_p {params.min_p} for sampler")
                params.min_p = 0

        if self.request_logger is None:
            return
        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )