Litellm Backend (#502)
This commit is contained in:
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
[project]
|
[project]
|
||||||
name = "sglang"
|
name = "sglang"
|
||||||
version = "0.1.16"
|
version = "0.1.16"
|
||||||
description = "A structured generation language for LLMs."
|
description = "A structured generation language for LLMs."
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.8"
|
requires-python = ">=3.8"
|
||||||
license = {file = "LICENSE"}
|
license = {file = "LICENSE"}
|
||||||
@@ -23,7 +23,8 @@ srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
|
|||||||
"zmq", "vllm==0.4.3", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
|
"zmq", "vllm==0.4.3", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
|
||||||
openai = ["openai>=1.0", "numpy", "tiktoken"]
|
openai = ["openai>=1.0", "numpy", "tiktoken"]
|
||||||
anthropic = ["anthropic>=0.20.0", "numpy"]
|
anthropic = ["anthropic>=0.20.0", "numpy"]
|
||||||
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
|
litellm = ["litellm>=1.0.0"]
|
||||||
|
all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
"Homepage" = "https://github.com/sgl-project/sglang"
|
"Homepage" = "https://github.com/sgl-project/sglang"
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from sglang.backend.anthropic import Anthropic
|
|||||||
from sglang.backend.openai import OpenAI
|
from sglang.backend.openai import OpenAI
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.backend.vertexai import VertexAI
|
from sglang.backend.vertexai import VertexAI
|
||||||
|
from sglang.backend.litellm import LiteLLM
|
||||||
|
|
||||||
# Global Configurations
|
# Global Configurations
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
@@ -35,6 +36,7 @@ from sglang.global_config import global_config
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"global_config",
|
"global_config",
|
||||||
"Anthropic",
|
"Anthropic",
|
||||||
|
"LiteLLM",
|
||||||
"OpenAI",
|
"OpenAI",
|
||||||
"RuntimeEndpoint",
|
"RuntimeEndpoint",
|
||||||
"VertexAI",
|
"VertexAI",
|
||||||
|
|||||||
89
python/sglang/backend/litellm.py
Normal file
89
python/sglang/backend/litellm.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
from typing import Mapping, Optional
|
||||||
|
|
||||||
|
from sglang.backend.base_backend import BaseBackend
|
||||||
|
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||||
|
from sglang.lang.interpreter import StreamExecutor
|
||||||
|
from sglang.lang.ir import SglSamplingParams
|
||||||
|
|
||||||
|
# Deferred-import pattern: if the optional ``litellm`` dependency is missing,
# stash the ImportError in the ``litellm`` name so that importing this module
# succeeds and the error is raised only when a LiteLLM backend is constructed.
try:
    import litellm
except ImportError as e:
    litellm = e
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM(BaseBackend):
    """SGLang backend that generates text through the ``litellm`` client library.

    ``litellm.completion`` exposes an OpenAI-compatible interface over many
    providers, so this backend only needs to translate SGLang prompts and
    sampling parameters into that one call.
    """

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = None,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        """Create a LiteLLM backend.

        Args:
            model_name: Provider/model identifier passed to ``litellm.completion``.
            chat_template: Optional explicit chat template; when None it is
                inferred from ``model_name``.
            api_key, organization, base_url, timeout, max_retries,
            default_headers: Client options forwarded to every
                ``litellm.completion`` call. ``max_retries`` defaults to
                ``litellm.num_retries`` when left as None.

        Raises:
            ImportError: If the optional ``litellm`` package is not installed.
        """
        super().__init__()

        # The module-level import stores the ImportError in the ``litellm``
        # name on failure; surface it only when the backend is actually used.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name)

        # BUG FIX: the original signature used ``litellm.num_retries`` as the
        # parameter default, which is evaluated at class-definition time and
        # raises AttributeError whenever ``litellm`` failed to import (the
        # name is then an ImportError instance). Resolve the default lazily
        # here instead, after the import guard above.
        if max_retries is None:
            max_retries = litellm.num_retries

        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template used to render multi-turn prompts."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; returns ``(text, meta_info)``."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        # BUG FIX: the original called ``to_anthropic_kwargs()`` here while
        # the streaming path used ``to_litellm_kwargs()``; use the LiteLLM
        # mapping consistently on both paths.
        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Stream a completion, yielding ``(text_chunk, meta_info)`` pairs."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            # Some stream chunks (e.g. role-only or final chunks) carry no
            # content; skip them rather than yielding None.
            if text is not None:
                yield text, {}
||||||
@@ -81,6 +81,21 @@ class SglSamplingParams:
|
|||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def to_litellm_kwargs(self):
    """Map these sampling parameters onto ``litellm.completion`` keyword args.

    ``regex`` has no LiteLLM equivalent, so a warning is emitted when it is
    set and it is otherwise ignored.
    """
    if self.regex is not None:
        warnings.warn(
            "Regular expression is not supported in the LiteLLM backend."
        )
    # An empty stop list is normalized to None so LiteLLM receives "no stop
    # sequences" rather than an empty container.
    kwargs = {
        "max_tokens": self.max_new_tokens,
        "stop": self.stop if self.stop else None,
        "temperature": self.temperature,
        "top_p": self.top_p,
        "top_k": self.top_k,
        "frequency_penalty": self.frequency_penalty,
        "presence_penalty": self.presence_penalty,
    }
    return kwargs
||||||
|
|
||||||
def to_srt_kwargs(self):
|
def to_srt_kwargs(self):
|
||||||
return {
|
return {
|
||||||
|
|||||||
27
test/lang/test_litellm_backend.py
Normal file
27
test/lang/test_litellm_backend.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from sglang import LiteLLM, set_default_backend
|
||||||
|
from sglang.test.test_programs import test_mt_bench, test_stream
|
||||||
|
|
||||||
|
|
||||||
|
class TestLiteLLMBackend(unittest.TestCase):
    """Smoke tests for the LiteLLM backend against ``gpt-3.5-turbo``.

    NOTE(review): the original class was named ``TestAnthropicBackend`` — a
    copy-paste leftover from the Anthropic test file even though it constructs
    a ``LiteLLM`` backend. Renamed to match what it tests; unittest discovery
    is unaffected (``Test`` prefix preserved).
    """

    # Shared across test methods so the backend is constructed only once.
    backend = None
    # Presumably reserved for a chat-model variant; unused in this file.
    chat_backend = None

    def setUp(self):
        # Lazily build the backend on the class, mirroring a classmethod
        # setUpClass without requiring one.
        cls = type(self)

        if cls.backend is None:
            cls.backend = LiteLLM("gpt-3.5-turbo")
            set_default_backend(cls.backend)

    def test_mt_bench(self):
        # Requires network access and provider credentials via litellm.
        test_mt_bench()

    def test_stream(self):
        test_stream()
||||||
|
|
||||||
|
|
||||||
|
# Allow running this file directly; warnings="ignore" silences unittest's
# warning filters (e.g. deprecation noise from client libraries) during the run.
if __name__ == "__main__":
    unittest.main(warnings="ignore")
|
||||||
Reference in New Issue
Block a user