From 87260b7bfd7c46cfb4511024b44bc9fc43073ad5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=83=A1=E8=AF=91=E6=96=87?= <1020030101@qq.com>
Date: Sat, 8 Jun 2024 03:24:28 +0800
Subject: [PATCH] Litellm Backend (#502)

---
 python/pyproject.toml             |  5 +-
 python/sglang/__init__.py         |  2 +
 python/sglang/backend/litellm.py  | 89 +++++++++++++++++++++++++++++++
 python/sglang/lang/ir.py          | 15 ++++++
 test/lang/test_litellm_backend.py | 27 ++++++++++
 5 files changed, 136 insertions(+), 2 deletions(-)
 create mode 100644 python/sglang/backend/litellm.py
 create mode 100644 test/lang/test_litellm_backend.py

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 247c4e5cc..343b555f3 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "sglang"
 version = "0.1.16"
-description = "A structured generation langauge for LLMs."
+description = "A structured generation language for LLMs."
 readme = "README.md"
 requires-python = ">=3.8"
 license = {file = "LICENSE"}
@@ -23,7 +23,8 @@ srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
     "zmq", "vllm==0.4.3", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
 openai = ["openai>=1.0", "numpy", "tiktoken"]
 anthropic = ["anthropic>=0.20.0", "numpy"]
-all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
+litellm = ["litellm>=1.0.0"]
+all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
 
 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py
index 0d3010bbe..fd2bbe14d 100644
--- a/python/sglang/__init__.py
+++ b/python/sglang/__init__.py
@@ -27,6 +27,7 @@ from sglang.backend.anthropic import Anthropic
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
+from sglang.backend.litellm import LiteLLM
 
 # Global Configurations
 from sglang.global_config import global_config
@@ -35,6 +36,7 @@ from sglang.global_config import global_config
 __all__ = [
     "global_config",
     "Anthropic",
+    "LiteLLM",
     "OpenAI",
     "RuntimeEndpoint",
     "VertexAI",
diff --git a/python/sglang/backend/litellm.py b/python/sglang/backend/litellm.py
new file mode 100644
index 000000000..9a0060f33
--- /dev/null
+++ b/python/sglang/backend/litellm.py
@@ -0,0 +1,89 @@
+from typing import Mapping, Optional
+
+from sglang.backend.base_backend import BaseBackend
+from sglang.lang.chat_template import get_chat_template_by_model_path
+from sglang.lang.interpreter import StreamExecutor
+from sglang.lang.ir import SglSamplingParams
+
+try:
+    import litellm
+except ImportError as e:
+    litellm = e
+
+
+class LiteLLM(BaseBackend):
+
+    def __init__(
+        self,
+        model_name,
+        chat_template=None,
+        api_key=None,
+        organization: Optional[str] = None,
+        base_url: Optional[str] = None,
+        timeout: Optional[float] = 600,
+        max_retries: Optional[int] = None,
+        default_headers: Optional[Mapping[str, str]] = None,
+    ):
+        super().__init__()
+
+        if isinstance(litellm, Exception):
+            raise litellm
+
+        self.model_name = model_name
+
+        self.chat_template = chat_template or get_chat_template_by_model_path(
+            model_name)
+
+        self.client_params = {
+            "api_key": api_key,
+            "organization": organization,
+            "base_url": base_url,
+            "timeout": timeout,
+            "max_retries": max_retries if max_retries is not None else litellm.num_retries,
+            "default_headers": default_headers,
+        }
+
+    def get_chat_template(self):
+        return self.chat_template
+
+    def generate(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            **self.client_params,
+            **sampling_params.to_litellm_kwargs(),
+        )
+        comp = ret.choices[0].message.content
+
+        return comp, {}
+
+    def generate_stream(
+        self,
+        s: StreamExecutor,
+        sampling_params: SglSamplingParams,
+    ):
+        if s.messages_:
+            messages = s.messages_
+        else:
+            messages = [{"role": "user", "content": s.text_}]
+
+        ret = litellm.completion(
+            model=self.model_name,
+            messages=messages,
+            stream=True,
+            **self.client_params,
+            **sampling_params.to_litellm_kwargs(),
+        )
+        for chunk in ret:
+            text = chunk.choices[0].delta.content
+            if text is not None:
+                yield text, {}
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index 2265a0a7a..c2b041fe3 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -81,6 +81,21 @@ class SglSamplingParams:
             "top_p": self.top_p,
             "top_k": self.top_k,
         }
+
+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the LiteLLM backend."
+            )
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
 
     def to_srt_kwargs(self):
         return {
diff --git a/test/lang/test_litellm_backend.py b/test/lang/test_litellm_backend.py
new file mode 100644
index 000000000..15d83bd51
--- /dev/null
+++ b/test/lang/test_litellm_backend.py
@@ -0,0 +1,27 @@
+import json
+import unittest
+
+from sglang import LiteLLM, set_default_backend
+from sglang.test.test_programs import test_mt_bench, test_stream
+
+
+class TestLiteLLMBackend(unittest.TestCase):
+    backend = None
+    chat_backend = None
+
+    def setUp(self):
+        cls = type(self)
+
+        if cls.backend is None:
+            cls.backend = LiteLLM("gpt-3.5-turbo")
+            set_default_backend(cls.backend)
+
+    def test_mt_bench(self):
+        test_mt_bench()
+
+    def test_stream(self):
+        test_stream()
+
+
+if __name__ == "__main__":
+    unittest.main(warnings="ignore")