diff --git a/README.md b/README.md
index eb0bb485f..3b3fceb36 100644
--- a/README.md
+++ b/README.md
@@ -248,6 +248,8 @@ In addition, the server supports an experimental OpenAI-compatible API.
 import openai
 client = openai.Client(
     base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
+
+# Text completion
 response = client.completions.create(
     model="default",
     prompt="The capital of France is",
@@ -255,6 +257,46 @@ response = client.completions.create(
     max_tokens=32,
 )
 print(response)
+
+# Chat completion
+response = client.chat.completions.create(
+    model="default",
+    messages=[
+        {"role": "system", "content": "You are a helpful AI assistant"},
+        {"role": "user", "content": "List 3 countries and their capitals."},
+    ],
+    temperature=0,
+    max_tokens=64,
+)
+print(response)
+```
+
+In the above example, the server uses the chat template specified in the model tokenizer.
+You can override the chat template if needed when launching the server:
+
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 \
+--chat-template llama-2
+```
+
+If the chat template you are looking for is missing, you are welcome to contribute it.
+Meanwhile, you can also temporarily register your chat template as follows:
+
+```json
+{
+  "name": "my_model",
+  "system": "<|im_start|>system",
+  "user": "<|im_start|>user",
+  "assistant": "<|im_start|>assistant",
+  "sep_style": "CHATML",
+  "sep": "<|im_end|>",
+  "stop_str": ["<|im_end|>", "<|im_start|>"]
+}
+```
+
+```
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 \
+--chat-template ./my_model_template.json
 ```
 
 ### Additional Arguments
diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
new file mode 100644
index 000000000..41d153fd4
--- /dev/null
+++ b/python/sglang/srt/conversation.py
@@ -0,0 +1,381 @@
+# Adapted from
+# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
+import dataclasses
+from enum import IntEnum, auto
+from typing import Dict, List, Tuple, Union
+
+from sglang.srt.managers.openai_protocol import ChatCompletionRequest
+
+
+class SeparatorStyle(IntEnum):
+    """Separator styles."""
+
+    ADD_COLON_SINGLE = auto()
+    ADD_COLON_TWO = auto()
+    ADD_COLON_SPACE_SINGLE = auto()
+    NO_COLON_SINGLE = auto()
+    NO_COLON_TWO = auto()
+    ADD_NEW_LINE_SINGLE = auto()
+    LLAMA2 = auto()
+    CHATGLM = auto()
+    CHATML = auto()
+    CHATINTERN = auto()
+    DOLLY = auto()
+    RWKV = auto()
+    PHOENIX = auto()
+    ROBIN = auto()
+    FALCON_CHAT = auto()
+    CHATGLM3 = auto()
+    DEEPSEEK_CHAT = auto()
+    METAMATH = auto()
+
+
+@dataclasses.dataclass
+class Conversation:
+    """A class that manages prompt templates and keeps all conversation history."""
+
+    # The name of this template
+    name: str
+    # The template of the system prompt
+    system_template: str = "{system_message}"
+    # The system message
+    system_message: str = ""
+    # The names of two roles
+    roles: Tuple[str] = ("USER", "ASSISTANT")
+    # All messages. Each item is (role, message).
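+    # e.g., [["USER", "Hello"], ["ASSISTANT", None]]; a None message marks a
+    # reply that has not been generated yet.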
+    messages: List[List[str]] = ()
+    # The number of few shot examples
+    offset: int = 0
+    # The separator style and configurations
+    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
+    sep: str = "\n"
+    sep2: str = None
+    # Stop criteria (the default one is EOS token)
+    stop_str: Union[str, List[str]] = None
+
+    def get_prompt(self) -> str:
+        """Get the prompt for generation."""
+        system_prompt = self.system_template.format(system_message=self.system_message)
+        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt + seps[0]
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ": "  # must end with a space
+            return ret
+        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + message + self.sep
+                else:
+                    ret += role + "\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
+            ret = system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + message + self.sep
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.NO_COLON_TWO:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + message + seps[i % 2]
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.RWKV:
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message.replace("\r\n", "\n").replace("\n\n", "\n")
+                    ret += "\n\n"
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.LLAMA2:
+            seps = [self.sep, self.sep2]
+            if self.system_message:
+                ret = system_prompt
+            else:
+                ret = "[INST] "
+            for i, (role, message) in enumerate(self.messages):
+                tag = self.roles[i % 2]
+                if message:
+                    if i == 0:
+                        ret += message + " "
+                    else:
+                        ret += tag + " " + message + seps[i % 2]
+                else:
+                    ret += tag
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATGLM:
+            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
+            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
+            round_add_n = 1 if self.name == "chatglm2" else 0
+            if system_prompt:
+                ret = system_prompt + self.sep
+            else:
+                ret = ""
+
+            for i, (role, message) in enumerate(self.messages):
+                if i % 2 == 0:
+                    ret += f"[Round {i//2 + round_add_n}]{self.sep}"
+
+                if message:
+                    ret += f"{role}:{message}{self.sep}"
+                else:
+                    ret += f"{role}:"
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATML:
+            ret = "" if system_prompt == "" else system_prompt + self.sep + "\n"
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + message + self.sep + "\n"
+                else:
+                    ret += role + "\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATGLM3:
+            ret = ""
+            if self.system_message:
+                ret += system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + "\n" + message
+                else:
+                    ret += role
+            return ret
+        elif self.sep_style == SeparatorStyle.CHATINTERN:
+            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if i % 2 == 0:
+                    ret += "<s>"
+                if message:
+                    ret += role + ":" + message + seps[i % 2] + "\n"
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.DOLLY:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ":\n" + message + seps[i % 2]
+                    if i % 2 == 1:
+                        ret += "\n\n"
+                else:
+                    ret += role + ":\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.PHOENIX:
+            ret = system_prompt
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + "<s>" + message + "</s>"
+                else:
+                    ret += role + ": " + "<s>"
+            return ret
+        elif self.sep_style == SeparatorStyle.ROBIN:
+            ret = system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ":\n" + message + self.sep
+                else:
+                    ret += role + ":\n"
+            return ret
+        elif self.sep_style == SeparatorStyle.FALCON_CHAT:
+            ret = ""
+            if self.system_message:
+                ret += system_prompt + self.sep
+            for role, message in self.messages:
+                if message:
+                    ret += role + ": " + message + self.sep
+                else:
+                    ret += role + ":"
+            return ret
+        elif self.sep_style == SeparatorStyle.METAMATH:
+            ret = "" if system_prompt == "" else system_prompt + self.sep
+            for i, (role, message) in enumerate(self.messages):
+                # For MetaMath, sep2 is used to prefix the message.
+                starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2
+                ending_sep = self.sep if i % 2 == 0 else ""
+                if message:
+                    ret += role + starting_sep + message + ending_sep
+                else:
+                    ret += role + starting_sep
+            return ret
+        elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT:
+            seps = [self.sep, self.sep2]
+            ret = system_prompt
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += role + ": " + message + seps[i % 2]
+                else:
+                    ret += role + ":"
+            return ret
+        else:
+            raise ValueError(f"Invalid style: {self.sep_style}")
+
+    def set_system_message(self, system_message: str):
+        """Set the system message."""
+        self.system_message = system_message
+
+    def append_message(self, role: str, message: str):
+        """Append a new message."""
+        self.messages.append([role, message])
+
+    def update_last_message(self, message: str):
+        """Update the last output.
+
+        The last message is typically set to be None when constructing the prompt,
+        so we need to update it in-place after getting the response from a model.
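+
+        Typical flow: append a (role, None) message, build the prompt with
+        get_prompt(), then call update_last_message() with the model output.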
+ """ + self.messages[-1][1] = message + + def to_gradio_chatbot(self): + """Convert the conversation to gradio chatbot format.""" + ret = [] + for i, (role, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append([msg, None]) + else: + ret[-1][-1] = msg + return ret + + def to_openai_api_messages(self): + """Convert the conversation to OpenAI chat completion format.""" + if self.system_message == "": + ret = [] + else: + ret = [{"role": "system", "content": self.system_message}] + + for i, (_, msg) in enumerate(self.messages[self.offset :]): + if i % 2 == 0: + ret.append({"role": "user", "content": msg}) + else: + if msg is not None: + ret.append({"role": "assistant", "content": msg}) + return ret + + def copy(self): + return Conversation( + name=self.name, + system_template=self.system_template, + system_message=self.system_message, + roles=self.roles, + messages=[[x, y] for x, y in self.messages], + offset=self.offset, + sep_style=self.sep_style, + sep=self.sep, + sep2=self.sep2, + stop_str=self.stop_str, + ) + + def dict(self): + return { + "template_name": self.name, + "system_message": self.system_message, + "roles": self.roles, + "messages": self.messages, + "offset": self.offset, + } + + +# A global registry for all conversation templates +chat_templates: Dict[str, Conversation] = {} + + +def register_conv_template(template: Conversation, override: bool = False): + """Register a new conversation template.""" + if not override: + assert template.name not in chat_templates, f"{template.name} has been registered." + + chat_templates[template.name] = template + + +def chat_template_exists(template_name: str) -> bool: + return template_name in chat_templates + + +def generate_chat_conv(request: ChatCompletionRequest, template_name: str) -> Conversation: + conv = chat_templates[template_name].copy() + conv = Conversation( + name=conv.name, + system_template=conv.system_template, + system_message=conv.system_message, + roles=conv.roles, + messages=list(conv.messages), # prevent in-place modification + offset=conv.offset, + sep_style=SeparatorStyle(conv.sep_style), + sep=conv.sep, + sep2=conv.sep2, + stop_str=conv.stop_str, + ) + + if isinstance(request.messages, str): + raise ValueError("The messages should be a list of dict.") + for message in request.messages: + msg_role = message["role"] + if msg_role == "system": + conv.system_message = message["content"] + elif msg_role == "user": + conv.append_message(conv.roles[0], message["content"]) + elif msg_role == "assistant": + conv.append_message(conv.roles[1], message["content"]) + else: + raise ValueError(f"Unknown role: {msg_role}") + + # Add a blank message for the assistant. 
+    conv.append_message(conv.roles[1], None)
+
+    return conv
+
+
+# llama2 template
+# reference: https://huggingface.co/blog/codellama#conversational-instructions
+# reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
+register_conv_template(
+    Conversation(
+        name="llama-2",
+        system_template="[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_str=["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"],
+    )
+)
+
+register_conv_template(
+    Conversation(
+        name="chatml",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are an AI assistant.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep_style=SeparatorStyle.CHATML,
+        sep="<|im_end|>",
+        stop_str=["<|endoftext|>", "<|im_end|>"],
+    )
+)
diff --git a/python/sglang/srt/managers/openai_protocol.py b/python/sglang/srt/managers/openai_protocol.py
index e80b1441c..974e38a91 100644
--- a/python/sglang/srt/managers/openai_protocol.py
+++ b/python/sglang/srt/managers/openai_protocol.py
@@ -65,3 +65,59 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+    usage: UsageInfo
+
+
+class ChatCompletionRequest(BaseModel):
+    model: str
+    messages: Union[str, List[Dict[str, str]]]
+    temperature: Optional[float] = 0.7
+    top_p: Optional[float] = 1.0
+    n: Optional[int] = 1
+    max_tokens: Optional[int] = 16
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
+    stream: Optional[bool] = False
+    presence_penalty: Optional[float] = 0.0
+    frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
+    user: Optional[str] = None
+    best_of: Optional[int] = None
+
+
+class ChatMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index 5a171aa9a..320c0e86a 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -2,6 +2,7 @@ import asyncio
 import json
 import multiprocessing as mp
+import os
 import sys
 import threading
 import time
@@ -17,15 +18,29 @@ import uvloop
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    generate_chat_conv,
+    register_conv_template,
+)
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.openai_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    ChatMessage,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
     CompletionResponseStreamChoice,
     CompletionStreamResponse,
-    UsageInfo
+    DeltaMessage,
+    UsageInfo,
 )
 from sglang.srt.managers.router.manager import start_router_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -37,6 +52,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
 app = FastAPI()
 tokenizer_manager = None
+chat_template_name = None
 
 
 @app.get("/get_model_info")
@@ -46,6 +62,7 @@ async def get_model_info():
     }
     return result
 
+
 async def stream_generator(obj):
     async for out in tokenizer_manager.generate_request(obj):
         yield out
@@ -61,7 +78,7 @@ async def generate_request(obj: GenerateReqInput):
             async for out in stream_generator(obj):
                 yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
             yield "data: [DONE]\n\n"
-    
+
         return StreamingResponse(stream_results(), media_type="text/event-stream")
 
     ret = await tokenizer_manager.generate_request(obj).__anext__()
@@ -91,11 +108,15 @@ async def v1_completions(raw_request: Request):
     adapted_request.post_init()
 
     if adapted_request.stream:
+
         async def gnerate_stream_resp():
             stream_buffer = ""
             async for content in stream_generator(adapted_request):
                 text = content["text"]
-                delta = text[len(stream_buffer):]
+                prompt_tokens = content["meta_info"]["prompt_tokens"]
+                completion_tokens = content["meta_info"]["completion_tokens"]
+
+                delta = text[len(stream_buffer) :]
                 stream_buffer = text
                 choice_data = CompletionResponseStreamChoice(
                     index=0,
@@ -108,12 +129,17 @@ async def v1_completions(raw_request: Request):
                 object="text_completion",
                 choices=[choice_data],
                 model=request.model,
+                usage=UsageInfo(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=prompt_tokens + completion_tokens,
+                ),
             )
             yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+        yield "data: [DONE]\n\n"
 
         return StreamingResponse(gnerate_stream_resp(), media_type="text/event-stream")
-
     # Non-streaming response.
     ret = await generate_request(adapted_request)
 
@@ -121,7 +147,7 @@ async def v1_completions(raw_request: Request):
         index=0,
         text=ret["text"],
         logprobs=None,
-        finish_reason=None, # TODO(comaniac): Add finish reason.
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
     )
 
     prompt_tokens = ret["meta_info"]["prompt_tokens"]
@@ -139,8 +165,108 @@ async def v1_completions(raw_request: Request):
     return response
 
 
+@app.post("/v1/chat/completions")
+async def v1_chat_completions(raw_request: Request):
+    request_json = await raw_request.json()
+    request = ChatCompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    if not isinstance(request.messages, str):
+        # Apply chat template and its stop strings.
+        if chat_template_name is None:
+            prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                request.messages, tokenize=False, add_generation_prompt=True
+            )
+            stop = request.stop
+        else:
+            conv = generate_chat_conv(request, chat_template_name)
+            prompt = conv.get_prompt()
+            stop = conv.stop_str or []
+            if request.stop:
+                if isinstance(request.stop, str):
+                    stop.append(request.stop)
+                else:
+                    stop.extend(request.stop)
+    else:
+        # Use the raw prompt and stop strings if the messages field is already a string.
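+        # (A plain-string `messages` value is a convenience extension here;
+        # the official OpenAI API only accepts a list of message dicts.)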
+        prompt = request.messages
+        stop = request.stop
+
+    adapted_request = GenerateReqInput(
+        text=prompt,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+        },
+        stream=request.stream,
+    )
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            is_first = True
+
+            stream_buffer = ""
+            async for content in stream_generator(adapted_request):
+                if is_first:
+                    # First chunk with role
+                    is_first = False
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(role="assistant"),
+                        finish_reason=None,
+                    )
+                    chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"], choices=[choice_data], model=request.model
+                    )
+                    yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+                text = content["text"]
+                delta = text[len(stream_buffer) :]
+                stream_buffer = text
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"], choices=[choice_data], model=request.model
+                )
+                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await generate_request(adapted_request)
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=ret["text"]),
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+    response = ChatCompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
 def launch_server(server_args, pipe_finish_writer):
     global tokenizer_manager
+    global chat_template_name
 
     # Allocate ports
     can_use_ports = alloc_usable_network_port(
@@ -154,6 +280,36 @@ def launch_server(server_args, pipe_finish_writer):
         model_rpc_ports=can_use_ports[4:],
     )
 
+    # Load chat template if needed
+    if server_args.chat_template is not None:
+        if not chat_template_exists(server_args.chat_template):
+            if not os.path.exists(server_args.chat_template):
+                raise RuntimeError(
+                    f"Chat template {server_args.chat_template} is not a built-in template name "
+                    "or a valid chat template file path."
+                )
+            with open(server_args.chat_template, "r") as filep:
+                template = json.load(filep)
+                try:
+                    sep_style = SeparatorStyle[template["sep_style"]]
+                except KeyError:
+                    raise ValueError(f"Unknown separator style: {template['sep_style']}") from None
+                register_conv_template(
+                    Conversation(
+                        name=template["name"],
+                        system_template=template["system"] + "\n{system_message}",
+                        system_message=template.get("system_message", ""),
+                        roles=(template["user"], template["assistant"]),
+                        sep_style=sep_style,
+                        sep=template.get("sep", "\n"),
+                        stop_str=template["stop_str"],
+                    ),
+                    override=True,
+                )
+            chat_template_name = template["name"]
+        else:
+            chat_template_name = server_args.chat_template
+
     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args)
     pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e0d1c236d..8348cf0bb 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -11,6 +11,7 @@ class ServerArgs:
     port: int = 30000
     load_format: str = "auto"
     tokenizer_mode: str = "auto"
+    chat_template: Optional[str] = None
     trust_remote_code: bool = True
     mem_fraction_static: Optional[float] = None
     tp_size: int = 1
@@ -77,6 +78,12 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--chat-template",
+            type=str,
+            default=ServerArgs.chat_template,
+            help="The built-in chat template name or the path of a chat template file. This is only used for the OpenAI-compatible API server.",
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index f5db747ce..33d5b0672 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -1,8 +1,16 @@
 """
-python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --port 30000
+First run the following command to launch the server.
+Note that TinyLlama adopts different chat templates in different versions.
+For v0.4, the chat template is chatml.
 
-Output:
-The capital of France is Paris.\nThe capital of the United States is Washington, D.C.\nThe capital of Canada is Ottawa.\nThe capital of Japan is Tokyo
+python3 -m sglang.launch_server --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 \
+--port 30000 --chat-template chatml
+
+Output example:
+The capital of France is Paris.
+The capital of the United States is Washington, D.C.
+The capital of Canada is Ottawa.
+The capital of Japan is Tokyo
 """
 
 import argparse
@@ -38,13 +46,57 @@ def test_completion_stream(args):
     for r in response:
         print(r.choices[0].text, end="", flush=True)
         assert r.id
-        assert r.created
         assert r.usage.prompt_tokens > 0
         assert r.usage.completion_tokens > 0
         assert r.usage.total_tokens > 0
     print()
 
 
+def test_chat_completion(args):
+    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
+    response = client.chat.completions.create(
+        model="default",
+        messages=[
+            {"role": "system", "content": "You are a helpful AI assistant"},
+            {"role": "user", "content": "What is the capital of France?"},
+        ],
+        temperature=0,
+        max_tokens=32,
+    )
+    print(response.choices[0].message.content)
+    assert response.id
+    assert response.created
+    assert response.usage.prompt_tokens > 0
+    assert response.usage.completion_tokens > 0
+    assert response.usage.total_tokens > 0
+
+
+def test_chat_completion_stream(args):
+    client = openai.Client(api_key="EMPTY", base_url=args.base_url)
+    response = client.chat.completions.create(
+        model="default",
+        messages=[
+            {"role": "system", "content": "You are a helpful AI assistant"},
+            {"role": "user", "content": "List 3 countries and their capitals."},
+        ],
+        temperature=0,
+        max_tokens=64,
+        stream=True,
+    )
+    is_first = True
+    for chunk in response:
+        if is_first:
+            is_first = False
+            assert chunk.choices[0].delta.role == "assistant"
+            continue
+
+        data = chunk.choices[0].delta
+        if not data.content:
+            continue
+        print(data.content, end="", flush=True)
+    print()
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--base-url", type=str, default="http://127.0.0.1:30000/v1")
@@ -52,3 +104,5 @@ if __name__ == "__main__":
 
     test_completion(args)
     test_completion_stream(args)
+    test_chat_completion(args)
+    test_chat_completion_stream(args)