From 48efec7b052354865aa2f0605a5bf778721f3cbb Mon Sep 17 00:00:00 2001 From: woodx <124784234+woodx9@users.noreply.github.com> Date: Mon, 17 Mar 2025 09:26:19 +0800 Subject: [PATCH] Feature: support code completion (#3612) --- python/sglang/srt/code_completion_parser.py | 174 ++++++++++++++++++++ python/sglang/srt/entrypoints/engine.py | 4 + python/sglang/srt/openai_api/adapter.py | 10 +- python/sglang/srt/server_args.py | 7 + test/srt/run_suite.py | 1 + test/srt/test_fim_completion.py | 71 ++++++++ 6 files changed, 266 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/code_completion_parser.py create mode 100644 test/srt/test_fim_completion.py diff --git a/python/sglang/srt/code_completion_parser.py b/python/sglang/srt/code_completion_parser.py new file mode 100644 index 000000000..94a98b0fd --- /dev/null +++ b/python/sglang/srt/code_completion_parser.py @@ -0,0 +1,174 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Completion templates.""" + + +import dataclasses +import json +import logging +import os +from enum import auto + +from sglang.srt.openai_api.protocol import ChatCompletionRequest + +logger = logging.getLogger(__name__) +completion_template_name = None + + +class FimPosition: + """Postion of fim middle token.""" + + MIDDLE = auto() + END = auto() + + +@dataclasses.dataclass +class CompletionTemplate: + """A class that manages completion prompt templates. only for code completion currently.""" + + # The name of this template + name: str + + # the fim begin token + fim_begin_token: str + + # The fim middle token + fim_middle_token: str + + # The fim end token + fim_end_token: str + + # The position of the fim middle token + fim_position: FimPosition + + +# A global registry for all completion templates +completion_templates: dict[str, CompletionTemplate] = {} + + +def load_completion_template_for_openai_api(completion_template_arg): + global completion_template_name + + logger.info( + f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}" + ) + + if not completion_template_exists(completion_template_arg): + if not os.path.exists(completion_template_arg): + raise RuntimeError( + f"Completion template {completion_template_arg} is not a built-in template name " + "or a valid completion template file path." + ) + + assert completion_template_arg.endswith( + ".json" + ), "unrecognized format of completion template file" + with open(completion_template_arg, "r") as filep: + template = json.load(filep) + try: + fim_position = FimPosition[template["fim_position"]] + except KeyError: + raise ValueError( + f"Unknown fim position: {template['fim_position']}" + ) from None + register_completion_template( + CompletionTemplate( + name=template["name"], + fim_begin_token=template["fim_begin_token"], + fim_middle_token=template["fim_middle_token"], + fim_end_token=template["fim_end_token"], + fim_position=fim_position, + ), + override=True, + ) + completion_template_name = template["name"] + else: + completion_template_name = completion_template_arg + + +def register_completion_template(template: CompletionTemplate, override: bool = False): + """Register a new completion template.""" + if not override: + assert ( + template.name not in completion_templates + ), f"{template.name} has been registered." + + completion_templates[template.name] = template + + +def completion_template_exists(template_name: str) -> bool: + return template_name in completion_templates + + +def is_completion_template_defined() -> bool: + global completion_template_name + return completion_template_name != None + + +def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str: + global completion_template_name + if request.suffix == "": + return request.prompt + + return generate_completion_prompt( + request.prompt, request.suffix, completion_template_name + ) + + +def generate_completion_prompt(prompt: str, suffix: str, template_name: str) -> str: + + completion_template = completion_templates[template_name] + fim_begin_token = completion_template.fim_begin_token + fim_middle_token = completion_template.fim_middle_token + fim_end_token = completion_template.fim_end_token + fim_position = completion_template.fim_position + + if fim_position == FimPosition.MIDDLE: + prompt = f"{fim_begin_token}{prompt}{fim_middle_token}{suffix}{fim_end_token}" + elif fim_position == FimPosition.END: + prompt = f"{fim_begin_token}{prompt}{fim_end_token}{suffix}{fim_middle_token}" + + return prompt + + +register_completion_template( + CompletionTemplate( + name="deepseek_coder", + fim_begin_token="<|fim▁begin|>", + fim_middle_token="<|fim▁hole|>", + fim_end_token="<|fim▁end|>", + fim_position=FimPosition.MIDDLE, + ) +) + + +register_completion_template( + CompletionTemplate( + name="star_coder", + fim_begin_token="", + fim_middle_token="", + fim_end_token="", + fim_position=FimPosition.END, + ) +) + +register_completion_template( + CompletionTemplate( + name="qwen_coder", + fim_begin_token="<|fim_prefix|>", + fim_middle_token="<|fim_middle|>", + fim_end_token="<|fim_suffix|>", + fim_position=FimPosition.END, + ) +) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index ec4ea515b..8fc792b6c 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -36,6 +36,7 @@ setattr(threading, "_register_atexit", lambda *args, **kwargs: None) import torch import uvloop +from sglang.srt.code_completion_parser import load_completion_template_for_openai_api from sglang.srt.managers.data_parallel_controller import ( run_data_parallel_controller_process, ) @@ -538,6 +539,9 @@ def _launch_subprocesses( tokenizer_manager, server_args.chat_template, server_args.model_path ) + if server_args.completion_template: + load_completion_template_for_openai_api(server_args.completion_template) + # Wait for the model to finish loading scheduler_infos = [] for i in range(len(scheduler_pipe_readers)): diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 2ac4e3ed8..ba4ca7b4f 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -33,6 +33,10 @@ except ImportError: # outlines.integrations.utils from outlines.integrations.utils import convert_json_schema_to_str +from sglang.srt.code_completion_parser import ( + generate_completion_prompt_from_request, + is_completion_template_defined, +) from sglang.srt.conversation import ( Conversation, SeparatorStyle, @@ -504,7 +508,11 @@ def v1_generate_request( "To compute logprobs of input prompt, please use the native /generate API." ) - prompts.append(request.prompt) + prompt = request.prompt + if is_completion_template_defined(): + prompt = generate_completion_prompt_from_request(request) + prompts.append(prompt) + lora_paths.append(request.lora_path) if request.echo and request.logprobs: current_logprob_start_len = 0 diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 545cb62e5..2901ba57e 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -56,6 +56,7 @@ class ServerArgs: device: Optional[str] = None served_model_name: Optional[str] = None chat_template: Optional[str] = None + completion_template: Optional[str] = None is_embedding: bool = False revision: Optional[str] = None @@ -456,6 +457,12 @@ class ServerArgs: default=ServerArgs.chat_template, help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.", ) + parser.add_argument( + "--completion-template", + type=str, + default=ServerArgs.completion_template, + help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.", + ) parser.add_argument( "--is-embedding", action="store_true", diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5d386afef..dbf7a7b84 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -70,6 +70,7 @@ suites = { TestFile("test_vision_chunked_prefill.py", 223), TestFile("test_vision_llm.py", 18.4), TestFile("test_vision_openai_server.py", 344), + TestFile("test_fim_completion.py", 120), TestFile("test_w8a8_quantization.py", 46), TestFile("test_eval_fp8_accuracy.py", 172), TestFile("test_create_kvindices.py", 2), diff --git a/test/srt/test_fim_completion.py b/test/srt/test_fim_completion.py new file mode 100644 index 000000000..132911e65 --- /dev/null +++ b/test/srt/test_fim_completion.py @@ -0,0 +1,71 @@ +import unittest + +import openai + +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestFimCompletion(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "deepseek-ai/deepseek-coder-1.3b-base" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + other_args = ["--completion-template", "deepseek_coder"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=other_args, + ) + cls.base_url += "/v1" + cls.tokenizer = get_tokenizer(cls.model) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_fim_completion(self, number_of_completion): + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + prompt = "function sum(a: number, b: number): number{\n" + suffix = "}" + + prompt_input = self.tokenizer.encode(prompt) + self.tokenizer.encode(suffix) + num_prompt_tokens = len(prompt_input) + 2 + + response = client.completions.create( + model=self.model, + prompt=prompt, + suffix=suffix, + temperature=0.3, + max_tokens=32, + stream=False, + n=number_of_completion, + ) + + print(response) + print(len(response.choices)) + assert len(response.choices) == number_of_completion + assert response.id + assert response.created + assert response.object == "text_completion" + assert ( + response.usage.prompt_tokens == num_prompt_tokens + ), f"{response.usage.prompt_tokens} vs {num_prompt_tokens}" + assert response.usage.completion_tokens > 0 + assert response.usage.total_tokens > 0 + + def test_fim_completion(self): + for number_of_completion in [1, 3]: + self.run_fim_completion(number_of_completion) + + +if __name__ == "__main__": + unittest.main()