2025-08-07 07:25:16 +00:00
commit ae2c299b3a
117 changed files with 29475 additions and 0 deletions


@@ -0,0 +1,105 @@
"""
NOTE: This API server exists only to demonstrate usage of AsyncEngine and to run simple performance benchmarks.
It is not intended for production use; for that, we recommend our OpenAI-compatible server.
We will not accept PRs that modify this file; please change `vllm/entrypoints/openai/api_server.py` instead.
"""
import argparse
import json
from typing import AsyncGenerator
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
TIMEOUT_KEEP_ALIVE = 5 # seconds.
app = FastAPI()
engine = None
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.post("/generate")
async def generate(request: Request) -> Response:
"""Generate completion for the request.
The request should be a JSON object with the following fields:
- prompt: the prompt to use for the generation.
- stream: whether to stream the results or not.
- other fields: the sampling parameters (See `SamplingParams` for details).
"""
request_dict = await request.json()
prompt = request_dict.pop("prompt")
prefix_pos = request_dict.pop("prefix_pos", None)
stream = request_dict.pop("stream", False)
sampling_params = SamplingParams(**request_dict)
request_id = random_uuid()
results_generator = engine.generate(prompt,
sampling_params,
request_id,
prefix_pos=prefix_pos)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
if stream:
return StreamingResponse(stream_results())
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
prompt = final_output.prompt
text_outputs = [prompt + output.text for output in final_output.outputs]
ret = {"text": text_outputs}
return JSONResponse(ret)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--ssl-keyfile", type=str, default=None)
parser.add_argument("--ssl-certfile", type=str, default=None)
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="debug",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile)
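A minimal client sketch for the demo server above, assuming it is running locally on the default port; the prompt and sampling values are illustrative only.
# Hypothetical client call (server assumed to be running on localhost:8000).
import json
import requests

payload = {
    "prompt": "San Francisco is a",
    "stream": False,
    # Remaining fields are passed through to SamplingParams.
    "max_tokens": 32,
    "temperature": 0.8,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(json.dumps(resp.json(), indent=2))  # {"text": ["San Francisco is a ..."]}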

vllm/entrypoints/llm.py
@@ -0,0 +1,220 @@
from typing import List, Optional, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
This class includes a tokenizer, a language model (possibly distributed
across multiple GPUs), and GPU memory space allocated for intermediate
states (aka KV cache). Given a batch of prompts and sampling parameters,
this class generates texts from the model, using an intelligent batching
mechanism and efficient memory management.
NOTE: This class is intended to be used for offline inference. For online
serving, use the `AsyncLLMEngine` class instead.
NOTE: For the comprehensive list of arguments, see `EngineArgs`.
Args:
model: The name or path of a HuggingFace Transformers model.
tokenizer: The name or path of a HuggingFace Transformers tokenizer.
tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
if available, and "slow" will always use the slow tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
tensor_parallel_size: The number of GPUs to use for distributed
execution with tensor parallelism.
dtype: The data type for the model weights and activations. Currently,
we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
the `torch_dtype` attribute specified in the model config file.
However, if the `torch_dtype` in the config is `float32`, we will
use `float16` instead.
quantization: The method used to quantize the model weights. Currently,
we support "awq", "gptq" and "squeezellm". If None, we first check
the `quantization_config` attribute in the model config file. If
that is None, we assume the model weights are not quantized and use
`dtype` to determine the data type of the weights.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id.
seed: The seed to initialize the random number generator for sampling.
gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
reserve for the model weights, activations, and KV cache. Higher
values will increase the KV cache size and thus improve the model's
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
This can be used for temporarily storing the states of the requests
when their `best_of` sampling parameters are larger than 1. If all
requests will have `best_of=1`, you can safely set this to 0.
Otherwise, too small values may cause out-of-memory (OOM) errors.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode.
disable_custom_all_reduce: See ParallelConfig
"""
def __init__(
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: int = 0,
gpu_memory_utilization: float = 0.9,
swap_space: int = 4,
enforce_eager: bool = False,
max_context_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
**kwargs,
) -> None:
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True
engine_args = EngineArgs(
model=model,
tokenizer=tokenizer,
tokenizer_mode=tokenizer_mode,
trust_remote_code=trust_remote_code,
tensor_parallel_size=tensor_parallel_size,
dtype=dtype,
quantization=quantization,
revision=revision,
tokenizer_revision=tokenizer_revision,
seed=seed,
gpu_memory_utilization=gpu_memory_utilization,
swap_space=swap_space,
enforce_eager=enforce_eager,
max_context_len_to_capture=max_context_len_to_capture,
disable_custom_all_reduce=disable_custom_all_reduce,
**kwargs,
)
self.llm_engine = LLMEngine.from_engine_args(engine_args)
self.request_counter = Counter()
def get_tokenizer(
self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
return self.llm_engine.tokenizer.tokenizer
def set_tokenizer(
self,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer
def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
sampling_params: Optional[SamplingParams] = None,
prompt_token_ids: Optional[List[List[int]]] = None,
prefix_pos: Optional[Union[int, List[int]]] = None,
use_tqdm: bool = True,
lora_request: Optional[LoRARequest] = None,
) -> List[RequestOutput]:
"""Generates the completions for the input prompts.
NOTE: This class automatically batches the given prompts, considering
the memory constraint. For the best performance, put all of your prompts
into a single list and pass it to this method.
Args:
prompts: A list of prompts to generate completions for.
sampling_params: The sampling parameters for text generation. If
None, we use the default sampling parameters.
prompt_token_ids: A list of token IDs for the prompts. If None, we
use the tokenizer to convert the prompts to token IDs.
prefix_pos: If not None, we use the given position as the prefix
position for each prompt. We will cache the prefix's KV
cache and reuse it for the next request with the same prefix.
This is an experimental feature, and may be replaced with
automatic prefix caching in the future.
use_tqdm: Whether to use tqdm to display the progress bar.
lora_request: LoRA request to use for generation, if any.
Returns:
A list of `RequestOutput` objects containing the generated
completions in the same order as the input prompts.
"""
if prompts is None and prompt_token_ids is None:
raise ValueError("Either prompts or prompt_token_ids must be "
"provided.")
if isinstance(prompts, str):
# Convert a single prompt to a list.
prompts = [prompts]
if (prompts is not None and prompt_token_ids is not None
and len(prompts) != len(prompt_token_ids)):
raise ValueError("The lengths of prompts and prompt_token_ids "
"must be the same.")
if sampling_params is None:
# Use default sampling params.
sampling_params = SamplingParams()
# Add requests to the engine.
num_requests = len(prompts) if prompts is not None else len(
prompt_token_ids)
for i in range(num_requests):
prompt = prompts[i] if prompts is not None else None
prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
token_ids = None if prompt_token_ids is None else prompt_token_ids[
i]
self._add_request(prompt,
sampling_params,
token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos_i)
return self._run_engine(use_tqdm)
def _add_request(
self,
prompt: Optional[str],
sampling_params: SamplingParams,
prompt_token_ids: Optional[List[int]],
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
) -> None:
request_id = str(next(self.request_counter))
self.llm_engine.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids,
lora_request=lora_request,
prefix_pos=prefix_pos)
def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
# Initialize tqdm.
if use_tqdm:
num_requests = self.llm_engine.get_num_unfinished_requests()
pbar = tqdm(total=num_requests, desc="Processed prompts")
# Run the engine.
outputs: List[RequestOutput] = []
while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step()
for output in step_outputs:
if output.finished:
outputs.append(output)
if use_tqdm:
pbar.update(1)
if use_tqdm:
pbar.close()
# Sort the outputs by request ID.
# This is necessary because some requests may finish earlier than
# requests that were submitted before them.
outputs = sorted(outputs, key=lambda x: int(x.request_id))
return outputs
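A minimal offline-inference sketch using the LLM class above; the model identifier is a placeholder and the sampling values are illustrative.
# Sketch only: the model name below is a placeholder for any HF causal LM.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)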


@@ -0,0 +1,251 @@
import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect
from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA
TIMEOUT_KEEP_ALIVE = 5 # seconds
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
if not engine_args.disable_log_stats:
asyncio.create_task(_force_log())
yield
app = fastapi.FastAPI(lifespan=lifespan)
class LoRAParserAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
lora_list = []
for item in values:
name, path = item.split('=')
lora_list.append(LoRA(name, path))
setattr(namespace, self.dest, lora_list)
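# Example (hypothetical names and paths):
#   --lora-modules sql-lora=/path/to/sql_lora chat-lora=/path/to/chat_lora
# parses into [LoRA(name='sql-lora', local_path='/path/to/sql_lora'),
#              LoRA(name='chat-lora', local_path='/path/to/chat_lora')].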
def parse_args():
parser = argparse.ArgumentParser(
description="vLLM OpenAI-Compatible RESTful API server.")
parser.add_argument("--host", type=str, default=None, help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument("--allow-credentials",
action="store_true",
help="allow credentials")
parser.add_argument("--allowed-origins",
type=json.loads,
default=["*"],
help="allowed origins")
parser.add_argument("--allowed-methods",
type=json.loads,
default=["*"],
help="allowed methods")
parser.add_argument("--allowed-headers",
type=json.loads,
default=["*"],
help="allowed headers")
parser.add_argument(
"--api-key",
type=str,
default=None,
help=
"If provided, the server will require this key to be presented in the header."
)
parser.add_argument("--served-model-name",
type=str,
default=None,
help="The model name used in the API. If not "
"specified, the model name will be the same as "
"the huggingface name.")
parser.add_argument(
"--lora-modules",
type=str,
default=None,
nargs='+',
action=LoRAParserAction,
help=
"LoRA module configurations in the format name=path. Multiple modules can be specified."
)
parser.add_argument("--chat-template",
type=str,
default=None,
help="The file path to the chat template, "
"or the template in single-line form "
"for the specified model")
parser.add_argument("--response-role",
type=str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
parser.add_argument("--ssl-keyfile",
type=str,
default=None,
help="The file path to the SSL key file")
parser.add_argument("--ssl-certfile",
type=str,
default=None,
help="The file path to the SSL cert file")
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
parser.add_argument(
"--middleware",
type=str,
action="append",
default=[],
help="Additional ASGI middleware to apply to the app. "
"We accept multiple --middleware arguments. "
"The value should be an import path. "
"If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
"If a class is provided, vLLM will add it to the server using app.add_middleware(). "
)
parser = AsyncEngineArgs.add_cli_args(parser)
return parser.parse_args()
# Add prometheus asgi middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
@app.get("/health")
async def health() -> Response:
"""Health check."""
return Response(status_code=200)
@app.get("/v1/models")
async def show_available_models():
models = await openai_serving_chat.show_available_models()
return JSONResponse(content=models.model_dump())
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator,
media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
if __name__ == "__main__":
args = parse_args()
app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)
if token := os.environ.get("VLLM_API_KEY") or args.api_key:
@app.middleware("http")
async def authentication(request: Request, call_next):
if not request.url.path.startswith("/v1"):
return await call_next(request)
if request.headers.get("Authorization") != "Bearer " + token:
return JSONResponse(content={"error": "Unauthorized"},
status_code=401)
return await call_next(request)
for middleware in args.middleware:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
raise ValueError(
f"Invalid middleware {middleware}. Must be a function or a class."
)
logger.info(f"args: {args}")
if args.served_model_name is not None:
served_model = args.served_model_name
else:
served_model = args.model
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
openai_serving_chat = OpenAIServingChat(engine, served_model,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_completion = OpenAIServingCompletion(
engine, served_model, args.lora_modules)
app.root_path = args.root_path
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
ssl_keyfile=args.ssl_keyfile,
ssl_certfile=args.ssl_certfile)
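A client-side sketch against the server above, assuming the `openai` v1 Python client is installed; the base URL, API key, and model name are placeholders.
# Sketch only: base_url, api_key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="my-served-model",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    max_tokens=32,
)
print(chat.choices[0].message.content)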


@@ -0,0 +1,323 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, model_validator
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
import torch
class ErrorResponse(BaseModel):
object: str = "error"
message: str
type: str
param: Optional[str] = None
code: int
class ModelPermission(BaseModel):
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
object: str = "model_permission"
created: int = Field(default_factory=lambda: int(time.time()))
allow_create_engine: bool = False
allow_sampling: bool = True
allow_logprobs: bool = True
allow_search_indices: bool = False
allow_view: bool = True
allow_fine_tuning: bool = False
organization: str = "*"
group: Optional[str] = None
is_blocking: bool = False
class ModelCard(BaseModel):
id: str
object: str = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "vllm"
root: Optional[str] = None
parent: Optional[str] = None
permission: List[ModelPermission] = Field(default_factory=list)
class ModelList(BaseModel):
object: str = "list"
data: List[ModelCard] = Field(default_factory=list)
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class ChatCompletionRequest(BaseModel):
model: str
messages: List[Dict[str, str]]
temperature: Optional[float] = 0.7
top_p: Optional[float] = 1.0
n: Optional[int] = 1
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
stream: Optional[bool] = False
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
best_of: Optional[int] = None
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
add_generation_prompt: Optional[bool] = True
echo: Optional[bool] = False
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self) -> SamplingParams:
if self.logprobs and not self.top_logprobs:
raise ValueError("Top logprobs must be set when logprobs is.")
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class CompletionRequest(BaseModel):
model: str
# a string, array of strings, array of tokens, or array of token arrays
prompt: Union[List[int], List[List[int]], str, List[str]]
suffix: Optional[str] = None
max_tokens: Optional[int] = 16
temperature: Optional[float] = 1.0
top_p: Optional[float] = 1.0
n: Optional[int] = 1
stream: Optional[bool] = False
logprobs: Optional[int] = None
echo: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
seed: Optional[int] = None
presence_penalty: Optional[float] = 0.0
frequency_penalty: Optional[float] = 0.0
best_of: Optional[int] = None
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Additional parameters supported by vLLM
top_k: Optional[int] = -1
ignore_eos: Optional[bool] = False
use_beam_search: Optional[bool] = False
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
repetition_penalty: Optional[float] = 1.0
min_p: Optional[float] = 0.0
include_stop_str_in_output: Optional[bool] = False
length_penalty: Optional[float] = 1.0
guided_json: Optional[Union[str, dict, BaseModel]] = None
guided_regex: Optional[str] = None
guided_choice: Optional[List[str]] = None
def to_sampling_params(self):
echo_without_generation = self.echo and self.max_tokens == 0
logits_processors = None
if self.logit_bias:
def logit_bias_logits_processor(
token_ids: List[int],
logits: torch.Tensor) -> torch.Tensor:
for token_id, bias in self.logit_bias.items():
# Clamp the bias between -100 and 100 per OpenAI API spec
bias = min(100, max(-100, bias))
logits[int(token_id)] += bias
return logits
logits_processors = [logit_bias_logits_processor]
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
)
@model_validator(mode="before")
@classmethod
def check_guided_decoding_count(cls, data):
guide_count = sum([
"guided_json" in data and data["guided_json"] is not None,
"guided_regex" in data and data["guided_regex"] is not None,
"guided_choice" in data and data["guided_choice"] is not None
])
if guide_count > 1:
raise ValueError(
"You can only use one kind of guided decoding "
"('guided_json', 'guided_regex' or 'guided_choice').")
return data
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[Optional[float]] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
class CompletionResponseChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseChoice]
usage: UsageInfo
class CompletionResponseStreamChoice(BaseModel):
index: int
text: str
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class CompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "text_completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
class ChatMessage(BaseModel):
role: str
content: str
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length"]] = None
class ChatCompletionStreamResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
object: str = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
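A small sketch of how a request body maps onto SamplingParams through the models above; field values are illustrative, and the logit_bias entry is clamped to +100 by the logits processor.
# Sketch only: field values are illustrative.
req = CompletionRequest(
    model="my-served-model",
    prompt="Once upon a time",
    max_tokens=16,
    temperature=0.7,
    logit_bias={"42": 250.0},  # clamped to +100 by the processor above
)
params = req.to_sampling_params()
print(params.max_tokens, params.temperature)   # 16 0.7
print(len(params.logits_processors or []))     # 1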


@@ -0,0 +1,307 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest, ChatCompletionResponse,
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
response_role: str,
lora_modules: Optional[List[LoRA]] = None,
chat_template=None):
super().__init__(engine=engine,
served_model=served_model,
lora_modules=lora_modules)
self.response_role = response_role
self._load_chat_template(chat_template)
async def create_chat_completion(
self, request: ChatCompletionRequest, raw_request: Request
) -> Union[ErrorResponse, AsyncGenerator[str, None],
ChatCompletionResponse]:
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/chat/create
for the API specification. This API mimics the OpenAI ChatCompletion API.
NOTE: Currently we do not support the following feature:
- function_call (Users should implement this by themselves)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
try:
prompt = self.tokenizer.apply_chat_template(
conversation=request.messages,
tokenize=False,
add_generation_prompt=request.add_generation_prompt)
except Exception as e:
logger.error(
f"Error in applying chat template from request: {str(e)}")
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
try:
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer()))
if guided_decode_logits_processor:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
except ValueError as e:
return self.create_error_response(str(e))
result_generator = self.engine.generate(prompt, sampling_params,
request_id, token_ids,
lora_request)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
request, result_generator, request_id)
else:
return await self.chat_completion_full_generator(
request, raw_request, result_generator, request_id)
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
if request.add_generation_prompt:
return self.response_role
else:
return request.messages[-1]["role"]
async def chat_completion_stream_generator(
self, request: ChatCompletionRequest,
result_generator: AsyncIterator[RequestOutput], request_id: str
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
model_name = request.model
created_time = int(time.time())
chunk_object_type = "chat.completion.chunk"
# Send first response for each request.n (index) with the role
role = self.get_chat_request_role(request)
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
logprobs=None,
finish_reason=None)
chunk = ChatCompletionStreamResponse(id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the last message
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
if last_msg_content:
for i in range(request.n):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=last_msg_content),
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
# Send response for each token for each request.n (index)
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
if finish_reason_sent[i]:
continue
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
delta_text = output.text[len(previous_texts[i]):]
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
if output.finish_reason is None:
# Send token-by-token response for each request.n
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=None)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
data = chunk.model_dump_json(exclude_unset=True)
yield f"data: {data}\n\n"
else:
# Send the finish response for each request.n only once
prompt_tokens = len(res.prompt_token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=previous_num_tokens[i],
total_tokens=prompt_tokens + previous_num_tokens[i],
)
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(content=delta_text),
logprobs=logprobs,
finish_reason=output.finish_reason)
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
model=model_name)
if final_usage is not None:
chunk.usage = final_usage
data = chunk.model_dump_json(exclude_unset=True,
exclude_none=True)
yield f"data: {data}\n\n"
finish_reason_sent[i] = True
# Send the final done message after all response.n are finished
yield "data: [DONE]\n\n"
async def chat_completion_full_generator(
self, request: ChatCompletionRequest, raw_request: Request,
result_generator: AsyncIterator[RequestOutput],
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
model_name = request.model
created_time = int(time.time())
final_res: Optional[RequestOutput] = None
async for res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(request_id)
return self.create_error_response("Client disconnected")
final_res = res
assert final_res is not None
choices = []
role = self.get_chat_request_role(request)
for output in final_res.outputs:
token_ids = output.token_ids
top_logprobs = output.logprobs
if request.logprobs:
logprobs = self._create_logprobs(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = ChatCompletionResponseChoice(
index=output.index,
message=ChatMessage(role=role, content=output.text),
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
if request.echo:
last_msg_content = ""
if request.messages and isinstance(
request.messages, list) and request.messages[-1].get(
"content") and request.messages[-1].get(
"role") == role:
last_msg_content = request.messages[-1]["content"]
for choice in choices:
full_message = last_msg_content + choice.message.content
choice.message.content = full_message
num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = ChatCompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
return response
def _load_chat_template(self, chat_template):
if chat_template is not None:
try:
with open(chat_template, "r") as f:
self.tokenizer.chat_template = f.read()
except OSError:
# If opening the file fails, treat the argument as a literal template
# string and decode it so escape sequences are interpreted correctly.
self.tokenizer.chat_template = codecs.decode(
chat_template, "unicode_escape")
logger.info(
f"Using supplied chat template:\n{self.tokenizer.chat_template}"
)
elif self.tokenizer.chat_template is not None:
logger.info(
f"Using default chat template:\n{self.tokenizer.chat_template}"
)
else:
logger.warning(
"No chat template provided. Chat API will not work.")


@@ -0,0 +1,361 @@
import asyncio
import time
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
LogProbs,
UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
logger = init_logger(__name__)
TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]
async def completion_stream_generator(
request: CompletionRequest,
raw_request: Request,
on_abort,
result_generator: AsyncIterator[Tuple[int, RequestOutput]],
create_logprobs_fn: TypeCreateLogProbsFn,
request_id: str,
created_time: int,
model_name: str,
num_prompts: int,
) -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
async for prompt_idx, res in result_generator:
# Abort the request if the client disconnects.
if await raw_request.is_disconnected():
await on_abort(f"{request_id}-{prompt_idx}")
raise StopAsyncIteration()
for output in res.outputs:
i = output.index + prompt_idx * request.n
# TODO(simon): optimize the performance by avoiding full text O(n^2) sending.
if request.echo and request.max_tokens == 0:
# only return the prompt
delta_text = res.prompt
delta_token_ids = res.prompt_token_ids
top_logprobs = res.prompt_logprobs
has_echoed[i] = True
elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
# echo the prompt and first token
delta_text = res.prompt + output.text
delta_token_ids = res.prompt_token_ids + output.token_ids
top_logprobs = res.prompt_logprobs + (output.logprobs or [])
has_echoed[i] = True
else:
# return just the delta
delta_text = output.text[len(previous_texts[i]):]
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
top_logprobs = output.logprobs[
previous_num_tokens[i]:] if output.logprobs else None
if request.logprobs is not None:
assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
logprobs = create_logprobs_fn(
token_ids=delta_token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
initial_text_offset=len(previous_texts[i]),
)
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
finish_reason = output.finish_reason
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text=delta_text,
logprobs=logprobs,
finish_reason=finish_reason,
)
]).model_dump_json()
yield f"data: {response_json}\n\n"
if output.finish_reason is not None: # return final usage
logprobs = LogProbs() if request.logprobs is not None else None
prompt_tokens = len(res.prompt_token_ids)
completion_tokens = len(output.token_ids)
final_usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
response_json = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[
CompletionResponseStreamChoice(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
],
usage=final_usage,
).model_dump_json()
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError(
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
)
return prompt_is_tokens, prompts
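# parse_prompt_format examples (illustrative values):
#   "hello"            -> (False, ["hello"])           # a single string
#   ["hi", "bye"]      -> (False, ["hi", "bye"])       # an array of strings
#   [1, 2, 3]          -> (True,  [[1, 2, 3]])         # an array of tokens
#   [[1, 2], [3, 4]]   -> (True,  [[1, 2], [3, 4]])    # an array of token arrays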
def request_output_to_completion_response(
final_res_batch: List[RequestOutput],
request: CompletionRequest,
create_logprobs_fn: TypeCreateLogProbsFn,
request_id: str,
created_time: int,
model_name: str,
) -> CompletionResponse:
choices = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt
for output in final_res.outputs:
if request.echo and request.max_tokens == 0:
token_ids = prompt_token_ids
top_logprobs = prompt_logprobs
output_text = prompt_text
elif request.echo and request.max_tokens > 0:
token_ids = prompt_token_ids + output.token_ids
top_logprobs = prompt_logprobs + output.logprobs
output_text = prompt_text + output.text
else:
token_ids = output.token_ids
top_logprobs = output.logprobs
output_text = output.text
if request.logprobs is not None:
logprobs = create_logprobs_fn(
token_ids=token_ids,
top_logprobs=top_logprobs,
num_output_top_logprobs=request.logprobs,
)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=len(choices),
text=output_text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)
num_prompt_tokens += len(prompt_token_ids)
num_generated_tokens += sum(
len(output.token_ids) for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
return CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)
def merge_async_iterators(*iterators):
"""Merge multiple asynchronous iterators into a single iterator.
This method handles the case where some iterators finish before others.
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue = asyncio.Queue()
finished = [False] * len(iterators)
async def producer(i, iterator):
async for item in iterator:
await queue.put((i, item))
finished[i] = True
_tasks = [
asyncio.create_task(producer(i, iterator))
for i, iterator in enumerate(iterators)
]
async def consumer():
while not all(finished) or not queue.empty():
item = await queue.get()
yield item
await asyncio.gather(*_tasks)
return consumer()
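# consumer() yields (iterator_index, item) pairs as items become available,
# so outputs for different prompts may interleave.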
class OpenAIServingCompletion(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
lora_modules: Optional[List[LoRA]] = None):
super().__init__(engine=engine,
served_model=served_model,
lora_modules=lora_modules)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
"""Completion API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/completions/create
for the API specification. This API mimics the OpenAI Completion API.
NOTE: Currently we do not support the following feature:
- suffix (the language models we currently support do not support
suffix)
"""
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
# Return error for unsupported features.
if request.suffix is not None:
return self.create_error_response(
"suffix is not currently supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
created_time = int(time.time())
# Schedule the request and get the result generator.
generators = []
try:
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
guided_decode_logit_processor = (
await get_guided_decoding_logits_processor(
request, self.engine.get_tokenizer()))
if guided_decode_logit_processor is not None:
if sampling_params.logits_processors is None:
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request, prompt_ids=prompt)
else:
input_ids = self._validate_prompt_and_tokenize(
request, prompt=prompt)
generators.append(
self.engine.generate(prompt,
sampling_params,
f"{request_id}-{i}",
prompt_token_ids=input_ids,
lora_request=lora_request))
except ValueError as e:
return self.create_error_response(str(e))
result_generator: AsyncIterator[Tuple[
int, RequestOutput]] = merge_async_iterators(*generators)
# Similar to the OpenAI API, when n != best_of, we do not stream the
# results. In addition, we do not stream the results when use beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)
# Streaming response
if stream:
return completion_stream_generator(request,
raw_request,
self.engine.abort,
result_generator,
self._create_logprobs,
request_id,
created_time,
model_name,
num_prompts=len(prompts))
# Non-streaming response
final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
async for i, res in result_generator:
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
response = request_output_to_completion_response(
final_res_batch, request, self._create_logprobs, request_id,
created_time, model_name)
# When user requests streaming but we don't stream, we still need to
# return a streaming response with a single event.
if request.stream:
response_json = response.model_dump_json()
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return fake_stream_generator()
return response
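A client-side sketch for the completions endpoint above with batched prompts, echo, and logprobs, assuming the `openai` v1 Python client; base URL, model name, and prompts are placeholders.
# Sketch only: base_url, model name, and prompts are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="my-served-model",
    prompt=["Hello, my name is", "The weather today is"],  # batched prompts
    max_tokens=16,
    logprobs=1,   # ask for OpenAI-style logprobs (see _create_logprobs)
    echo=True,    # prepend the prompt to each returned text
)
for choice in completion.choices:
    print(choice.index, choice.text)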


@@ -0,0 +1,172 @@
import asyncio
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
ChatCompletionRequest,
ErrorResponse, LogProbs,
ModelCard, ModelList,
ModelPermission)
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
@dataclass
class LoRA:
name: str
local_path: str
class OpenAIServing:
def __init__(self,
engine: AsyncLLMEngine,
served_model: str,
lora_modules: Optional[List[LoRA]] = None):
self.engine = engine
self.served_model = served_model
if lora_modules is None:
self.lora_requests = []
else:
self.lora_requests = [
LoRARequest(
lora_name=lora.name,
lora_int_id=i,
lora_local_path=lora.local_path,
) for i, lora in enumerate(lora_modules, start=1)
]
self.max_model_len = 0
self.tokenizer = None
try:
event_loop = asyncio.get_running_loop()
except RuntimeError:
event_loop = None
if event_loop is not None and event_loop.is_running(
): # If we are already inside a running event loop (e.g. launched by Ray Serve), schedule _post_init on it
event_loop.create_task(self._post_init())
else: # When using single vLLM without engine_use_ray
asyncio.run(self._post_init())
async def _post_init(self):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len
# A separate tokenizer to map token IDs to strings.
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code)
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=self.served_model,
root=self.served_model,
permission=[ModelPermission()])
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=self.served_model,
permission=[ModelPermission()])
for lora in self.lora_requests
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
def _create_logprobs(
self,
token_ids: List[int],
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
num_output_top_logprobs: Optional[int] = None,
initial_text_offset: int = 0,
) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
if num_output_top_logprobs:
logprobs.top_logprobs = []
for i, token_id in enumerate(token_ids):
step_top_logprobs = top_logprobs[i]
if step_top_logprobs is not None:
token_logprob = step_top_logprobs[token_id]
else:
token_logprob = None
token = self.tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(token_logprob)
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] +
last_token_len)
last_token_len = len(token)
if num_output_top_logprobs:
logprobs.top_logprobs.append({
self.tokenizer.convert_ids_to_tokens(i): p
for i, p in step_top_logprobs.items()
} if step_top_logprobs else None)
return logprobs
def create_error_response(
self,
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)
async def _check_model(self, request) -> Optional[ErrorResponse]:
if request.model == self.served_model:
return
if request.model in [lora.lora_name for lora in self.lora_requests]:
return
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
if request.model == self.served_model:
return
for lora in self.lora_requests:
if request.model == lora.lora_name:
return lora
# if _check_model has been called earlier, this will be unreachable
raise ValueError("The model `{request.model}` does not exist.")
def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
prompt).input_ids
token_num = len(input_ids)
if request.max_tokens is None:
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
raise ValueError(
f"This model's maximum context length is {self.max_model_len} tokens. "
f"However, you requested {request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
else:
return input_ids
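A small worked sketch of the context-length check in `_validate_prompt_and_tokenize` above, with illustrative numbers.
# Illustrative numbers only: suppose the model's context window is 4096 tokens.
max_model_len = 4096
token_num = 3900            # tokens in the prompt / messages
requested_max_tokens = 300  # completion budget requested by the client

# 3900 + 300 = 4200 > 4096, so this request would be rejected with the
# "maximum context length" error. Omitting max_tokens instead defaults it to
# max_model_len - token_num = 196, which always fits.
assert token_num + requested_max_tokens > max_model_len
assert max_model_len - token_num == 196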