Litellm Backend (#502)
pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "sglang"
version = "0.1.16"
description = "A structured generation language for LLMs."
readme = "README.md"
requires-python = ">=3.8"
license = {file = "LICENSE"}
@@ -23,7 +23,8 @@ srt = ["aiohttp", "fastapi", "psutil", "rpyc", "torch", "uvloop", "uvicorn",
        "zmq", "vllm==0.4.3", "interegular", "pydantic", "pillow", "packaging", "huggingface_hub", "hf_transfer", "outlines>=0.0.34"]
 openai = ["openai>=1.0", "numpy", "tiktoken"]
 anthropic = ["anthropic>=0.20.0", "numpy"]
-all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]"]
+litellm = ["litellm>=1.0.0"]
+all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"]

 [project.urls]
 "Homepage" = "https://github.com/sgl-project/sglang"
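With the extra declared, the LiteLLM dependency can be installed on its own or through the "all" bundle; assuming a published sglang release, that is simply:

pip install "sglang[litellm]"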
python/sglang/__init__.py
@@ -27,6 +27,7 @@ from sglang.backend.anthropic import Anthropic
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
+from sglang.backend.litellm import LiteLLM

 # Global Configurations
 from sglang.global_config import global_config
@@ -35,6 +36,7 @@ from sglang.global_config import global_config
 __all__ = [
     "global_config",
     "Anthropic",
+    "LiteLLM",
     "OpenAI",
     "RuntimeEndpoint",
     "VertexAI",
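With LiteLLM re-exported from the package root, it plugs in like the other backends. A minimal frontend sketch, assuming the usual sglang primitives (sgl.function, sgl.gen, sgl.set_default_backend) and a placeholder model string that litellm can route:

import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=64))

# Any provider/model litellm understands; "gpt-3.5-turbo" is illustrative.
sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))

state = qa.run(question="What is structured generation?")
print(state["answer"])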
python/sglang/backend/litellm.py (new file, 89 lines)
@@ -0,0 +1,89 @@
from typing import Mapping, Optional

from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams

try:
    import litellm
except ImportError as e:
    # Defer the import error until the backend is actually instantiated.
    litellm = e


class LiteLLM(BaseBackend):
    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = None,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()

        if isinstance(litellm, Exception):
            raise litellm

        # Resolve the retry default lazily; evaluating litellm.num_retries in
        # the signature would break module import when litellm is missing.
        if max_retries is None:
            max_retries = litellm.num_retries

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Forwarded to every litellm.completion() call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            if text is not None:
                yield text, {}
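Everything collected in client_params rides along with every litellm.completion call, so connection settings are configured once per backend instance. A hedged sketch of direct construction plus a streaming run, reusing the qa function above (the model string, timeout, and the text_iter streaming interface are assumptions for illustration, not guarantees of this commit):

import sglang as sgl
from sglang.backend.litellm import LiteLLM

backend = LiteLLM(
    "claude-3-haiku-20240307",  # illustrative; any litellm-routable model id
    timeout=120,                # forwarded to litellm.completion via client_params
)
sgl.set_default_backend(backend)

state = qa.run(question="What is structured generation?", stream=True)
for text in state.text_iter():
    print(text, end="", flush=True)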
python/sglang/lang/ir.py
@@ -81,6 +81,21 @@ class SglSamplingParams:
             "top_p": self.top_p,
             "top_k": self.top_k,
         }

+    def to_litellm_kwargs(self):
+        if self.regex is not None:
+            warnings.warn(
+                "Regular expression is not supported in the LiteLLM backend."
+            )
+        return {
+            "max_tokens": self.max_new_tokens,
+            "stop": self.stop or None,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+        }
+
     def to_srt_kwargs(self):
         return {