import asyncio
import codecs
import time
from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List,
                    Optional, Tuple, TypedDict, Union, final)

from fastapi import Request
from openai.types.chat import (ChatCompletionContentPartParam,
                               ChatCompletionRole)

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
    get_guided_decoding_logits_processor)
from vllm.outputs import RequestOutput
from vllm.utils import random_uuid

logger = init_logger(__name__)


@final  # So that it should be compatible with Dict[str, str]
class ConversationMessage(TypedDict):
    role: str
    content: str


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine: AsyncLLMEngine,
                 served_model_names: List[str],
                 response_role: str,
                 lora_modules: Optional[List[LoRAModulePath]] = None,
                 chat_template: Optional[str] = None):
        super().__init__(engine=engine,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules,
                         await_post_init=self._load_chat_template(
                             chat_template=chat_template))

        self.response_role = response_role

    def _parse_chat_message_content(
        self,
        role: ChatCompletionRole,
        content: Optional[Union[str,
                                Iterable[ChatCompletionContentPartParam]]],
    ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]:
        if content is None:
            return [], []
        if isinstance(content, str):
            return [ConversationMessage(role=role, content=content)], []

        texts: List[str] = []
        for part in content:
            if part["type"] == "text":
                text = part["text"]
                texts.append(text)
            else:
                raise NotImplementedError(
                    f"Unknown part type: {part['type']}")

        return [ConversationMessage(role=role, content="\n".join(texts))], []

    async def create_chat_completion(
        self, request: ChatCompletionRequest, raw_request: Request
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI
        ChatCompletion API.
        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            conversation: List[ConversationMessage] = []

            for m in request.messages:
                messages, _ = self._parse_chat_message_content(
                    m["role"], m["content"])

                conversation.extend(messages)

            prompt = self.tokenizer.apply_chat_template(
                conversation=conversation,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt,
            )
        except Exception as e:
            logger.error("Error in applying chat template from request: %s",
                         e)
            return self.create_error_response(str(e))

        request_id = f"cmpl-{random_uuid()}"
        try:
            # Tokenize/detokenize depending on prompt format
            # (string/token list)
            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
                request, prompt=prompt)
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            decoding_config = await self.engine.get_decoding_config()
            guided_decoding_backend = request.guided_decoding_backend \
                or decoding_config.guided_decoding_backend
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    guided_decoding_backend, request,
                    await self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator = self.engine.generate(prompt_text, sampling_params,
                                                request_id, prompt_ids,
                                                lora_request)
        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id, conversation)
        else:
            try:
                return await self.chat_completion_full_generator(
                    request, raw_request, result_generator, request_id,
                    conversation)
            except ValueError as e:
                # TODO: Use a vllm-specific Validation Error
                return self.create_error_response(str(e))

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]

    async def chat_completion_stream_generator(
            self, request: ChatCompletionRequest,
            result_generator: AsyncIterator[RequestOutput], request_id: str,
            conversation: List[ConversationMessage]
    ) -> AsyncGenerator[str, None]:
        model_name = self.served_model_names[0]
        created_time = int(time.time())
        chunk_object_type = "chat.completion.chunk"
        first_iteration = True

        # Send response for each token for each request.n (index)
        assert request.n is not None
        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
        finish_reason_sent = [False] * request.n
        try:
            async for res in result_generator:
                # We need to do it here, because if there are exceptions in
                # the result_generator, it needs to be sent as the FIRST
                # response (by the try...catch).
                if first_iteration:
                    # Send first response for each request.n (index) with
                    # the role
                    role = self.get_chat_request_role(request)
                    for i in range(request.n):
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(role=role),
                            logprobs=None,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"

                    # Send response to echo the input portion of the
                    # last message
                    if request.echo:
                        last_msg_content = ""
                        if conversation and conversation[-1].get(
                                "content") and conversation[-1].get(
                                    "role") == role:
                            last_msg_content = conversation[-1]["content"]

                        if last_msg_content:
                            for i in range(request.n):
                                choice_data = (
                                    ChatCompletionResponseStreamChoice(
                                        index=i,
                                        delta=DeltaMessage(
                                            content=last_msg_content),
                                        finish_reason=None))
                                chunk = ChatCompletionStreamResponse(
                                    id=request_id,
                                    object=chunk_object_type,
                                    created=created_time,
                                    choices=[choice_data],
                                    logprobs=None,
                                    model=model_name)
                                data = chunk.model_dump_json(
                                    exclude_unset=True)
                                yield f"data: {data}\n\n"
                    first_iteration = False

                for output in res.outputs:
                    i = output.index

                    if finish_reason_sent[i]:
                        continue

                    delta_token_ids = output.token_ids[
                        previous_num_tokens[i]:]
                    top_logprobs = output.logprobs[
                        previous_num_tokens[i]:] if output.logprobs else None

                    if request.logprobs:
                        logprobs = self._create_logprobs(
                            token_ids=delta_token_ids,
                            top_logprobs=top_logprobs,
                            num_output_top_logprobs=request.logprobs,
                            initial_text_offset=len(previous_texts[i]),
                        )
                    else:
                        logprobs = None

                    delta_text = output.text[len(previous_texts[i]):]
                    previous_texts[i] = output.text
                    previous_num_tokens[i] = len(output.token_ids)
                    if output.finish_reason is None:
                        # Send token-by-token response for each request.n
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(content=delta_text),
                            logprobs=logprobs,
                            finish_reason=None)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        data = chunk.model_dump_json(exclude_unset=True)
                        yield f"data: {data}\n\n"
                    else:
                        # Send the finish response for each request.n only
                        # once
                        prompt_tokens = len(res.prompt_token_ids)
                        final_usage = UsageInfo(
                            prompt_tokens=prompt_tokens,
                            completion_tokens=previous_num_tokens[i],
                            total_tokens=prompt_tokens +
                            previous_num_tokens[i],
                        )
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=i,
                            delta=DeltaMessage(content=delta_text),
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                            stop_reason=output.stop_reason)
                        chunk = ChatCompletionStreamResponse(
                            id=request_id,
                            object=chunk_object_type,
                            created=created_time,
                            choices=[choice_data],
                            model=model_name)
                        if final_usage is not None:
                            chunk.usage = final_usage
                        data = chunk.model_dump_json(exclude_unset=True,
                                                     exclude_none=True)
                        yield f"data: {data}\n\n"
                        finish_reason_sent[i] = True
        except ValueError as e:
            # TODO: Use a vllm-specific Validation Error
            data = self.create_streaming_error_response(str(e))
            yield f"data: {data}\n\n"
        # Send the final done message after all response.n are finished
        yield "data: [DONE]\n\n"

    async def chat_completion_full_generator(
            self, request: ChatCompletionRequest, raw_request: Request,
            result_generator: AsyncIterator[RequestOutput], request_id: str,
            conversation: List[ConversationMessage]
    ) -> Union[ErrorResponse, ChatCompletionResponse]:

        model_name = self.served_model_names[0]
        created_time = int(time.time())
        final_res: Optional[RequestOutput] = None
        async for res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(request_id)
                return self.create_error_response("Client disconnected")
            final_res = res
        assert final_res is not None

        choices = []

        role = self.get_chat_request_role(request)
        for output in final_res.outputs:
            token_ids = output.token_ids
            top_logprobs = output.logprobs

            if request.logprobs:
                logprobs = self._create_logprobs(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                )
            else:
                logprobs = None

            choice_data = ChatCompletionResponseChoice(
                index=output.index,
                message=ChatMessage(role=role, content=output.text),
                logprobs=logprobs,
                finish_reason=output.finish_reason,
                stop_reason=output.stop_reason,
            )
            choices.append(choice_data)

        if request.echo:
            last_msg_content = ""
            if conversation and conversation[-1].get(
                    "content") and conversation[-1].get("role") == role:
                last_msg_content = conversation[-1]["content"]

            for choice in choices:
                full_message = last_msg_content + choice.message.content
                choice.message.content = full_message

        num_prompt_tokens = len(final_res.prompt_token_ids)
        num_generated_tokens = sum(
            len(output.token_ids) for output in final_res.outputs)
        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            completion_tokens=num_generated_tokens,
            total_tokens=num_prompt_tokens + num_generated_tokens,
        )
        response = ChatCompletionResponse(
            id=request_id,
            created=created_time,
            model=model_name,
            choices=choices,
            usage=usage,
        )

        return response

    async def _load_chat_template(self, chat_template: Optional[str]):
        while self.tokenizer is None:
            # Give the parent class time to load the tokenizer
            await asyncio.sleep(0.1)
        tokenizer = self.tokenizer

        if chat_template is not None:
            try:
                with open(chat_template, "r") as f:
                    tokenizer.chat_template = f.read()
            except OSError as e:
                JINJA_CHARS = "{}\n"
                if not any(c in chat_template for c in JINJA_CHARS):
                    msg = (f"The supplied chat template ({chat_template}) "
                           f"looks like a file path, but it failed to be "
                           f"opened. Reason: {e}")
                    raise ValueError(msg) from e

                # If opening the file fails, treat chat_template as the
                # template string itself and decode it so that escape
                # sequences are interpreted correctly.
                tokenizer.chat_template = codecs.decode(
                    chat_template, "unicode_escape")

            logger.info("Using supplied chat template:\n%s",
                        tokenizer.chat_template)
        elif tokenizer.chat_template is not None:
            logger.info("Using default chat template:\n%s",
                        tokenizer.chat_template)
        else:
            logger.warning(
                "No chat template provided. Chat API will not work.")
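

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not part of the original module). It shows
# roughly how an OpenAIServingChat instance is wired into a FastAPI app in
# vLLM's OpenAI-compatible server. The helper function, the served model
# name, and the exact route body below are assumptions for the example, not
# definitive server code.


def _example_chat_app(engine: AsyncLLMEngine):  # hypothetical helper
    from fastapi import FastAPI
    from fastapi.responses import JSONResponse, StreamingResponse

    app = FastAPI()
    serving_chat = OpenAIServingChat(
        engine=engine,
        served_model_names=["example-model"],  # assumed served model name
        response_role="assistant")

    @app.post("/v1/chat/completions")
    async def create_chat_completion(request: ChatCompletionRequest,
                                     raw_request: Request):
        generator = await serving_chat.create_chat_completion(
            request, raw_request)
        if isinstance(generator, ErrorResponse):
            return JSONResponse(content=generator.model_dump(),
                                status_code=generator.code)
        if request.stream:
            # The stream generator yields "data: {json}\n\n" chunks and ends
            # with "data: [DONE]\n\n", i.e. a server-sent-events stream.
            return StreamingResponse(content=generator,
                                     media_type="text/event-stream")
        return JSONResponse(content=generator.model_dump())

    return app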