init

vllm_vacc/vllm/entrypoints/openai/serving_engine.py (new file, 191 lines)
@@ -0,0 +1,191 @@
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures.thread import ThreadPoolExecutor
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
                    Optional, Sequence, Tuple, TypedDict, Union)

from fastapi import Request
from pydantic import Field
from starlette.datastructures import Headers
from typing_extensions import Annotated
import torch

from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         ChatTemplateContentFormatOption,
                                         ConversationMessage,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages_futures,
                                         resolve_chat_template_content_format)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingChatRequest,
                                              EmbeddingCompletionRequest,
                                              ErrorResponse, RerankRequest,
                                              ScoreRequest,
                                              TokenizeChatRequest,
                                              TokenizeCompletionRequest)
# yapf: enable
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.entrypoints.openai.serving_engine import AnyRequest, TextTokensPrompt
# from vllm.model_executor.sampling_metadata import _SAMPLING_EPS
from vllm.v1.sample.sampler import _SAMPLING_EPS
import os

logger = init_logger(__name__)

from vllm_vacc.vllm.model_executor.models.vars import LLM_MAX_PREFILL_SEQ_LEN
from vllm_vacc.vllm.model_executor.models.vars import CUT_PREFILL_SEQ_LEN


class EmbedsPrompt(TypedDict):
    prompt_embeds: torch.Tensor
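    # deepstack_input_embeds presumably holds auxiliary per-layer embeddings
    # (DeepStack-style multimodal inputs); this is inferred from the field
    # name and is not documented in this commit.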
    deepstack_input_embeds: Optional[dict]


class OpenAIServing:

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
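        # Validates prompt length against the prefill and model context
        # limits, and rejects sampling features this backend does not support.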
        # Parameters set by the client; if not set, they are read again from
        # generation_config.json.
        if 0 < CUT_PREFILL_SEQ_LEN < len(input_ids):
            cut_before = CUT_PREFILL_SEQ_LEN // 2
            cut_after = CUT_PREFILL_SEQ_LEN - cut_before
            input_ids = input_ids[:cut_before] + input_ids[-cut_after:]
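            # Worked example: with CUT_PREFILL_SEQ_LEN = 4 and
            # input_ids = [1, 2, 3, 4, 5, 6], cut_before = 2 and cut_after = 2,
            # so the prompt is middle-truncated to [1, 2, 5, 6].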
        token_num = len(input_ids)

        if not self.model_config.pooler_config:
            if (request.repetition_penalty is not None
                    and abs(request.repetition_penalty - 1.0) >= _SAMPLING_EPS):
                raise ValueError(
                    f"Unsupported penalty for sampler: "
                    f"request.repetition_penalty = {request.repetition_penalty}. "
                    f"Please remove the penalty parameter in the client and "
                    f"try again.")
            if request.min_p is not None and request.min_p > _SAMPLING_EPS:
                raise ValueError(
                    f"Unsupported min_p {request.min_p} for sampler")
            if request.prompt_logprobs is not None:
                raise ValueError(
                    f"Unsupported prompt_logprobs {request.prompt_logprobs} "
                    f"for sampler")

        # model_type = self.model_config.hf_config.model_type
        # if model_type == "deepseek_v3":
        if token_num > LLM_MAX_PREFILL_SEQ_LEN:
            raise ValueError(
                f"This model's maximum input sequence length is "
                f"{LLM_MAX_PREFILL_SEQ_LEN} tokens. However, you requested "
                f"{token_num} tokens in the input messages. "
                f"Please reduce the length of the input messages.")
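
        # Note: LLM_MAX_PREFILL_SEQ_LEN above bounds the prompt alone (a
        # backend prefill limit), while the max_model_len checks below bound
        # prompt + completion against the model's context window.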
        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
        if isinstance(request,
                      (EmbeddingChatRequest, EmbeddingCompletionRequest,
                       ScoreRequest, RerankRequest)):

            operation = "score" if isinstance(request, ScoreRequest) \
                else "embedding generation"
            if token_num > self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for {operation}. "
                    f"Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest don't have max_tokens
        # and do not require model context length validation
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # The chat completion endpoint supports max_completion_tokens
        if isinstance(request, ChatCompletionRequest):
            # TODO(#9845): remove max_tokens when field dropped from OpenAI API
            max_tokens = request.max_completion_tokens or request.max_tokens
        else:
            max_tokens = request.max_tokens
        if max_tokens is None:
            if token_num >= self.max_model_len:
                raise ValueError(
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages. "
                    f"Please reduce the length of the messages.")
        elif token_num + max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is "
                f"{self.max_model_len} tokens. However, you requested "
                f"{max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _log_inputs(
        self,
        request_id: str,
        inputs,
        params: Optional[Union[SamplingParams, PoolingParams,
                               BeamSearchParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
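        # Besides logging, this method downgrades sampling parameters the
        # backend sampler cannot honor (repetition_penalty, min_p) with a
        # warning instead of failing the request.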
        # Moved below: the early-return on request_logger now happens right
        # before request_logger is used, so the sampling-parameter overrides
        # below still run when logging is disabled.
        # if self.request_logger is None:
        #     return
        # If self.model_config.pooler_config is set, the task is embedding,
        # not generation.
        if self.model_config.pooler_config:
            return
        prompt, prompt_token_ids, prompt_embeds = None, None, None
        if isinstance(inputs, str):
            prompt = inputs
        elif isinstance(inputs, list):
            prompt_token_ids = inputs
        else:
            prompt = getattr(inputs, 'prompt', None)
            prompt_token_ids = getattr(inputs, 'prompt_token_ids', None)

        # Penalty settings read from generation_config: if present, warn and
        # override them.
        if (params.repetition_penalty is not None
                and abs(params.repetition_penalty - 1.0) >= _SAMPLING_EPS):
            logger.warning(
                "\033[93mWARNING\033[0m"
                ": Unsupported penalty for sampler: "
                f"params.repetition_penalty = {params.repetition_penalty}. "
                "Please set extra_body = {'repetition_penalty': 1.0} in the "
                "client.\nNow setting repetition_penalty = 1.0")
            # params.presence_penalty = 0
            # params.frequency_penalty = 0
            params.repetition_penalty = 1
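            # The override mutates the params object shared with the engine;
            # presumably config-sourced defaults are downgraded here (with a
            # warning) while client-set values are rejected in _validate_input.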

        if hasattr(params, "min_p") and params.min_p is not None \
                and params.min_p > _SAMPLING_EPS:
            logger.warning(
                f"\033[93mWARNING\033[0m: Unsupported min_p {params.min_p} "
                f"for sampler")
            params.min_p = 0
        if self.request_logger is None:
            return
        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            prompt_embeds,
            params=params,
            lora_request=lora_request,
        )