# SPDX-License-Identifier: Apache-2.0
|
|
|
|
import json
|
|
from concurrent.futures.thread import ThreadPoolExecutor
|
|
from http import HTTPStatus
|
|
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
|
|
Optional, Sequence, Tuple, TypedDict, Union)
|
|
|
|
from fastapi import Request
|
|
from pydantic import Field
|
|
from starlette.datastructures import Headers
|
|
from typing_extensions import Annotated
|
|
import torch
|
|
|
|
from vllm.config import ModelConfig
|
|
from vllm.engine.protocol import EngineClient
|
|
# yapf conflicts with isort for this block
|
|
# yapf: disable
|
|
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
|
|
ChatTemplateContentFormatOption,
|
|
ConversationMessage,
|
|
apply_hf_chat_template,
|
|
apply_mistral_chat_template,
|
|
parse_chat_messages_futures,
|
|
resolve_chat_template_content_format)
|
|
from vllm.entrypoints.logger import RequestLogger
|
|
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
|
CompletionRequest,
|
|
DetokenizeRequest,
|
|
EmbeddingChatRequest,
|
|
EmbeddingCompletionRequest,
|
|
ErrorResponse, RerankRequest,
|
|
ScoreRequest,
|
|
TokenizeChatRequest,
|
|
TokenizeCompletionRequest)
|
|
# yapf: enable
|
|
from vllm.logger import init_logger
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.pooling_params import PoolingParams
|
|
from vllm.sampling_params import BeamSearchParams, SamplingParams
|
|
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
|
|
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
|
|
from vllm.v1.sample.sampler import _SAMPLING_EPS
|
|
import os
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
|
|
from vllm_vacc.vllm.model_executor.models.vars import CUT_PREFILL_SEQ_LEN
|
|
|
|
class EmbedsPrompt(TypedDict):
    """Prompt supplied as precomputed embeddings rather than token ids."""

    # Precomputed prompt embeddings fed directly to the model.
    # NOTE(review): shape is not established here — presumably
    # (num_tokens, hidden_dim); confirm against callers.
    prompt_embeds: torch.Tensor
    # Optional extra embeddings for deepstack-style models; None when unused.
    # NOTE(review): dict schema is not visible in this file — verify.
    deepstack_input_embeds: Optional[dict]
|
|
|
|
class OpenAIServing:
    """Input validation and request logging for the OpenAI-compatible
    endpoints, patched for the VACC backend.

    Adds backend-specific behavior on top of stock vLLM:

    * optional middle-truncation of over-long prompts
      (``CUT_PREFILL_SEQ_LEN``),
    * rejection / sanitization of sampling parameters the VACC sampler
      does not support (``repetition_penalty``, ``min_p``,
      ``prompt_logprobs``),
    * a hard prefill length limit (``LLM_MAX_PREFILL_SEQ_LEN``).
    """

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Validate a tokenized prompt against backend limits.

        Args:
            request: The incoming OpenAI-style request object.
            input_ids: Token ids of the prompt (may be truncated here).
            input_text: The original prompt text.

        Returns:
            A ``TextTokensPrompt`` carrying the (possibly truncated)
            token ids and the original text.

        Raises:
            ValueError: If an unsupported sampling parameter is set, or
                a prefill / context-length limit is exceeded.
        """
        # Client-configured truncation knob (falls back to
        # generation_config.json when unset).  When the prompt is longer
        # than the limit, keep the head and the tail and drop the middle.
        if 0 < CUT_PREFILL_SEQ_LEN < len(input_ids):
            cut_before = CUT_PREFILL_SEQ_LEN // 2
            cut_after = CUT_PREFILL_SEQ_LEN - cut_before
            input_ids = input_ids[:cut_before] + input_ids[-cut_after:]
        token_num = len(input_ids)

        # Sampling-parameter checks apply to generation models only;
        # pooling (embedding/score) models never sample.
        if not self.model_config.pooler_config:
            if (request.repetition_penalty is not None
                    and abs(request.repetition_penalty - 1.0) >= _SAMPLING_EPS):
                raise ValueError(
                    f"unsupport penalty for sampler"
                    f"request.repetition_penalty: {request.repetition_penalty}; "
                    f"Please remove penalty parameter in client and try again."
                )
            if request.min_p is not None and request.min_p > _SAMPLING_EPS:
                raise ValueError(
                    f"unsupport min_p {request.min_p} for sampler")
            if request.prompt_logprobs is not None:
                raise ValueError(
                    f"unsupport prompt_logprobs "
                    f"{request.prompt_logprobs} for sampler")
            # NOTE: the original code repeated the min_p and
            # prompt_logprobs checks verbatim a second time; the
            # duplicates were removed.

        # Hard prefill limit of the VACC backend, independent of the
        # model's context length.
        if token_num > LLM_MAX_PREFILL_SEQ_LEN:
            raise ValueError(
                f"This model's maximum input seq length limit is "
                f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
                f"{token_num} tokens in the input messages. "
                f"Please reduce the length of the input messages.")

        # Note: EmbeddingRequest and ScoreRequest doesn't have max_tokens;
        # only the context length applies.
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            # No completion budget requested: the prompt alone must fit.
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            # Prompt plus requested completion must fit the context window.
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _log_inputs(
        self,
        request_id: str,
        inputs,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Sanitize sampling params the VACC sampler cannot honor, then
        log the request inputs through ``self.request_logger``.

        Unlike :meth:`_validate_input`, unsupported values reaching this
        point (typically picked up from ``generation_config.json``) are
        warned about and reset in place rather than rejected.

        Args:
            request_id: Identifier of the request being logged.
            inputs: The prompt — a string, a list of token ids, or an
                object exposing ``prompt`` / ``prompt_token_ids``.
            params: Sampling/pooling parameters; mutated in place when
                unsupported values are found.
            lora_request: Optional LoRA adapter associated with the request.
        """
        # Pooling (embedding) tasks: nothing to sanitize or log here.
        # (pooler_config set means this is not a generation task.)
        if self.model_config.pooler_config:
            return

        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        else:
            prompt = getattr(inputs, 'prompt', None)
            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

        # Penalty values read from generation_config.json: warn and reset
        # to the neutral value instead of failing the request.
        # (Guard against params=None — the signature allows it.)
        if (params is not None
                and params.repetition_penalty is not None
                and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
            logger.warning(
                "\033[93mWARNING \033[0m"
                ": Unsupport penalty for sampler"
                f"params.repetition_penalty: {params.repetition_penalty} and "
                "Please set attrs: extra_body = {\'repetition_penalty\': 1.0}\n"
                "Now set: repetition_penalty: 1.0"
            )
            params.repetition_penalty = 1

        if (hasattr(params, "min_p") and params.min_p is not None
                and params.min_p > _SAMPLING_EPS):
            logger.warning(
                f"\033[93mWARNING \033[0m : "
                f"unsupport min_p {params.min_p} for sampler")
            params.min_p = 0

        # Logging is optional; sanitization above must still run, which is
        # why this check sits after it rather than at the top.
        if self.request_logger is None:
            return
        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )