Update Readme (#660)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
0
python/sglang/lang/backend/__init__.py
Normal file
0
python/sglang/lang/backend/__init__.py
Normal file
77
python/sglang/lang/backend/anthropic.py
Normal file
77
python/sglang/lang/backend/anthropic.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as e:
|
||||
anthropic = e
|
||||
|
||||
|
||||
class Anthropic(BaseBackend):
    """Backend for the Anthropic Messages API (Claude models)."""

    def __init__(self, model_name, *args, **kwargs):
        super().__init__()

        # `anthropic` is bound to the ImportError when the package is missing;
        # surface that error lazily, only when this backend is actually used.
        if isinstance(anthropic, Exception):
            raise anthropic

        self.model_name = model_name
        self.chat_template = get_chat_template("claude")
        self.client = anthropic.Anthropic(*args, **kwargs)

    def get_chat_template(self):
        """Return the chat template used to format conversations."""
        return self.chat_template

    @staticmethod
    def _messages_and_system(s: StreamExecutor):
        """Build the (messages, system) pair for a Messages API call.

        Copies ``s.messages_`` before extracting a leading system message.
        The previous implementation popped from ``s.messages_`` in place,
        permanently removing the system message from the executor's history
        so later calls would silently lose it.
        """
        if s.messages_:
            messages = list(s.messages_)
        else:
            messages = [{"role": "user", "content": s.text_}]

        if messages and messages[0]["role"] == "system":
            system = messages.pop(0)["content"]
        else:
            system = ""
        return messages, system

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; return (text, meta_info)."""
        messages, system = self._messages_and_system(s)

        ret = self.client.messages.create(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        )
        comp = ret.content[0].text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming completion."""
        messages, system = self._messages_and_system(s)

        with self.client.messages.stream(
            model=self.model_name,
            system=system,
            messages=messages,
            **sampling_params.to_anthropic_kwargs(),
        ) as stream:
            for text in stream.text_stream:
                yield text, {}
|
||||
80
python/sglang/lang/backend/base_backend.py
Normal file
80
python/sglang/lang/backend/base_backend.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
|
||||
class BaseBackend:
    """Abstract interface that every sglang language backend implements.

    Subclasses override the generation/selection methods; the cache- and
    request-lifecycle hooks default to no-ops so backends without server-side
    state (e.g. hosted API backends) need not implement them.
    """

    def __init__(self) -> None:
        # Whether the backend supports concatenate_and_append; backends with
        # a server-side cache set this to True.
        self.support_concate_and_append = False
        self.chat_template = get_chat_template("default")

    def get_model_name(self):
        """Return the name/path of the underlying model. Must be overridden."""
        raise NotImplementedError()

    def get_chat_template(self):
        """Return the chat template used to format conversations."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Hook: warm a server-side cache with a common prefix. No-op here."""
        pass

    def uncache_prefix(self, rid: str):
        """Hook: release a cached prefix by request id. No-op here."""
        pass

    def end_request(self, rid: Union[str, List[str]]):
        """Hook: finish one or more requests by id. No-op here."""
        pass

    def begin_program(self, s: StreamExecutor):
        """Hook: called when a program starts executing. No-op here."""
        pass

    def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
        """Hook: called when a program finishes executing. No-op here."""
        pass

    def commit_lazy_operations(self, s: StreamExecutor):
        """Hook: flush any buffered operations for the executor. No-op here."""
        pass

    def fork_program(
        self,
        src: StreamExecutor,
        dst: List[StreamExecutor],
        position_ids_offset: Optional[List[int]] = None,
    ):
        """Hook: fork one executor's state into several. No-op here."""
        pass

    def fill_image(self, s: StreamExecutor):
        """Hook: process a pending image in the executor. No-op here."""
        pass

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Produce a completion; must return (text, meta_info). Must be overridden."""
        raise NotImplementedError()

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Produce a streaming completion yielding (chunk, meta_info). Must be overridden."""
        raise NotImplementedError()

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        """Choose among fixed string choices. Must be overridden."""
        raise NotImplementedError()

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Merge source requests into a destination request server-side.

        Only meaningful when ``support_concate_and_append`` is True.
        """
        raise NotImplementedError()

    def shutdown(self):
        """Hook: tear down backend resources. No-op here."""
        pass

    def flush_cache(self):
        """Hook: clear any server-side cache. No-op here."""
        pass

    def get_server_args(self):
        """Hook: return the backing server's arguments, if any. No-op here."""
        pass
|
||||
90
python/sglang/lang/backend/litellm.py
Normal file
90
python/sglang/lang/backend/litellm.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from typing import Mapping, Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError as e:
|
||||
litellm = e
|
||||
litellm.num_retries = 1
|
||||
|
||||
|
||||
class LiteLLM(BaseBackend):
    """Backend that routes chat completions through the litellm library."""

    def __init__(
        self,
        model_name,
        chat_template=None,
        api_key=None,
        organization: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[float] = 600,
        max_retries: Optional[int] = litellm.num_retries,
        default_headers: Optional[Mapping[str, str]] = None,
    ):
        super().__init__()

        # `litellm` is bound to the ImportError when the package is missing;
        # surface that error lazily, only when this backend is actually used.
        if isinstance(litellm, Exception):
            raise litellm

        self.model_name = model_name

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        # Forwarded verbatim to every litellm.completion call.
        self.client_params = {
            "api_key": api_key,
            "organization": organization,
            "base_url": base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
        }

    def get_chat_template(self):
        """Return the chat template used to format conversations."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming completion; return (text, meta_info)."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            **self.client_params,
            # Bug fix: this previously called to_anthropic_kwargs(), a
            # copy-paste error — generate_stream below correctly uses the
            # litellm kwargs mapping.
            **sampling_params.to_litellm_kwargs(),
        )
        comp = ret.choices[0].message.content

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming completion."""
        if s.messages_:
            messages = s.messages_
        else:
            messages = [{"role": "user", "content": s.text_}]

        ret = litellm.completion(
            model=self.model_name,
            messages=messages,
            stream=True,
            **self.client_params,
            **sampling_params.to_litellm_kwargs(),
        )
        for chunk in ret:
            text = chunk.choices[0].delta.content
            if text is not None:
                yield text, {}
|
||||
438
python/sglang/lang/backend/openai.py
Normal file
438
python/sglang/lang/backend/openai.py
Normal file
@@ -0,0 +1,438 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
import time
|
||||
import warnings
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import openai
|
||||
import tiktoken
|
||||
except ImportError as e:
|
||||
openai = tiktoken = e
|
||||
|
||||
|
||||
logger = logging.getLogger("openai")
|
||||
|
||||
|
||||
def create_logit_bias_int(tokenizer):
    """Build a logit-bias map that steers sampling toward integer-like tokens.

    Collects token ids whose decoded text is all digits (or a single space)
    from the tokenizer's mergeable ranks, capped at the OpenAI API limit of
    300 biased tokens (299 digit tokens plus the end-of-text token), each
    biased with the maximum value of 100.
    """
    digit_ids = []

    for _token, token_id in tokenizer._mergeable_ranks.items():
        piece = tokenizer.decode([token_id])
        # A lone space is allowed so the model can terminate the number.
        if piece == " " or all(ch.isdigit() for ch in piece):
            digit_ids.append(token_id)
        if len(digit_ids) >= 300:  # OpenAI API limit
            break

    bias = {tid: 100 for tid in digit_ids[:299]}
    # Reserve one slot for end-of-text so generation can stop.
    bias[tokenizer._special_tokens["<|endoftext|>"]] = 100
    return bias
|
||||
|
||||
|
||||
# Model names served through the legacy /completions (non-chat) endpoint;
# used below to default `is_chat_model` to False for these models.
INSTRUCT_MODEL_NAMES = [
    "gpt-3.5-turbo-instruct",
]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class TokenUsage:
    """Running counters of prompt/completion tokens consumed via the API."""

    prompt_tokens: int
    completion_tokens: int

    def reset(self):
        """Zero both counters."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
|
||||
|
||||
|
||||
class OpenAI(BaseBackend):
    """Backend for the OpenAI (and Azure OpenAI) completion APIs.

    Supports chat and instruct models, constrained str/int generation via
    logit bias, greedy `select` over fixed choices, and an experimental
    "API speculative execution" mode that batches several `gen` calls into
    a single API request and pattern-matches the result back.
    """

    def __init__(
        self,
        model_name: str,
        is_chat_model: Optional[bool] = None,
        chat_template: Optional[ChatTemplate] = None,
        is_azure: bool = False,
        *args,
        **kwargs,
    ):
        super().__init__()

        # `openai` is bound to the ImportError when the package is missing;
        # surface that error lazily, only when this backend is actually used.
        if isinstance(openai, Exception):
            raise openai

        if is_azure:
            self.client = openai.AzureOpenAI(*args, **kwargs)
        else:
            self.client = openai.OpenAI(*args, **kwargs)

        self.model_name = model_name
        try:
            self.tokenizer = tiktoken.encoding_for_model(model_name)
        except KeyError:
            # Unknown model name: fall back to the common cl100k_base encoding.
            self.tokenizer = tiktoken.get_encoding("cl100k_base")
        self.logit_bias_int = create_logit_bias_int(self.tokenizer)

        self.chat_template = chat_template or get_chat_template_by_model_path(
            model_name
        )

        if is_chat_model is not None:
            self.is_chat_model = is_chat_model
        else:
            if model_name in INSTRUCT_MODEL_NAMES:
                self.is_chat_model = False
            else:
                self.is_chat_model = True

        # Text that must immediately precede a `gen` call for chat models.
        self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]

        # Usage
        self.token_usage = TokenUsage(0, 0)

        # API speculative execution
        # TODO(ying): This does not support multi-threading (run_batch)
        self.spec_kwargs = {}
        self.spec_format = []
        self.spec_max_num_tries = 3

    def get_chat_template(self):
        """Return the chat template used to format conversations."""
        return self.chat_template

    def _prepare_spec_execution(
        self,
        sampling_params: SglSamplingParams,
        num_api_spec_tokens: int,
        spec_var_name: str,
    ):
        """Record one speculated `gen` call instead of issuing an API request.

        Accumulates sampling kwargs (which must be consistent across the
        speculated calls) and a format entry later filled in by
        `role_end_generate`. Returns an empty placeholder completion.
        """
        if "max_tokens" not in self.spec_kwargs:
            self.spec_kwargs["max_tokens"] = num_api_spec_tokens
        else:
            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens

        params = sampling_params.to_openai_kwargs()
        for key, value in params.items():
            if key in ["stop"]:
                continue
            if key in ["max_tokens"]:
                warnings.warn(
                    "The parameter max_tokens will be overwritten by speculated number of tokens."
                )
                continue
            if key not in self.spec_kwargs:
                self.spec_kwargs[key] = value
            else:
                assert (
                    value == self.spec_kwargs[key]
                ), "sampling parameters should be consistent if turn on api speculative execution."
        self.spec_format.append(
            {"text": "", "stop": params["stop"], "name": spec_var_name}
        )
        return "", {}

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
        spec_var_name: str = None,
    ):
        """Run one non-streaming completion; return (text, meta_info).

        dtype None uses the plain API; str/int dtypes constrain the output
        via quoting / digit logit bias (instruct models only).
        """
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if s.num_api_spec_tokens is None:
                    if not s.text_.endswith(self.chat_prefix):
                        raise RuntimeError(
                            "This use case is not supported if api speculative execution is off. "
                            "For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
                            "Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
                        )
                    prompt = s.messages_
                else:
                    # Speculative mode: defer the actual API call.
                    return self._prepare_spec_execution(
                        sampling_params, s.num_api_spec_tokens, spec_var_name
                    )
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
        elif sampling_params.dtype in [str, "str", "string"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Open a quote in the prompt and stop at the closing quote so the
            # model produces a quoted string literal.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_ + '"',
                stop='"',
                **kwargs,
            )
            comp = '"' + comp + '"'
        elif sampling_params.dtype in [int, "int"]:
            assert (
                not self.is_chat_model
            ), "constrained type not supported on chat model"
            kwargs = sampling_params.to_openai_kwargs()
            kwargs.pop("stop")
            # Bias sampling toward digit tokens; a space terminates the number.
            comp = openai_completion(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=s.text_,
                logit_bias=self.logit_bias_int,
                stop=[" "],
                **kwargs,
            )
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

        return comp, {}

    def spec_fill(self, value: str):
        """Record a literal text segment in the speculative format pattern."""
        assert self.is_chat_model
        self.spec_format.append({"text": value, "stop": None, "name": None})

    def spec_pattern_match(self, comp):
        """Match a full completion against the recorded speculative pattern.

        Literal segments must match exactly; variable segments consume text
        up to their stop string (the last variable may take the remainder).
        Fills each variable term's "text" in place; returns False on mismatch.
        """
        for i, term in enumerate(self.spec_format):
            text = term["text"]
            if text != "":
                if comp.startswith(text):
                    comp = comp[len(text) :]
                else:
                    return False
            else:
                pos = comp.find(term["stop"])
                if pos != -1:
                    term["text"] = comp[:pos]
                    comp = comp[pos:]
                else:
                    if i == len(self.spec_format) - 1:
                        term["text"] = comp
                    else:
                        return False
        return True

    def role_end_generate(
        self,
        s: StreamExecutor,
    ):
        """Flush speculated `gen` calls with one API request at role end.

        Retries up to `spec_max_num_tries` until the completion matches the
        recorded pattern, then writes the captured variables back into the
        executor and resets the speculative state.
        """
        if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
            return

        comp = ""
        if not all(x["name"] is None for x in self.spec_format):
            # TODO(ying): throw errors or warnings
            for i in range(self.spec_max_num_tries):
                comp = openai_completion(
                    client=self.client,
                    token_usage=self.token_usage,
                    is_chat=self.is_chat_model,
                    model=self.model_name,
                    prompt=s.messages_,
                    **self.spec_kwargs,
                )
                if self.spec_pattern_match(comp):
                    break

        for term in self.spec_format:
            s.text_ += term["text"]
            name = term["name"]
            if name is not None:
                s.variables[name] = term["text"]
                s.meta_info[name] = {}
                s.variable_event[name].set()

        self.spec_kwargs = {}
        self.spec_format = []

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Return a generator of (text_chunk, meta_info) for streaming output."""
        if sampling_params.dtype is None:
            if self.is_chat_model:
                if not s.text_.endswith(self.chat_prefix):
                    raise RuntimeError(
                        "This use case is not supported. "
                        "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
                    )
                prompt = s.messages_
            else:
                prompt = s.text_

            kwargs = sampling_params.to_openai_kwargs()
            generator = openai_completion_stream(
                client=self.client,
                token_usage=self.token_usage,
                is_chat=self.is_chat_model,
                model=self.model_name,
                prompt=prompt,
                **kwargs,
            )
            return generator
        else:
            raise ValueError(f"Unknown dtype: {sampling_params.dtype}")

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        """Greedily pick one of `choices` by biased token-by-token decoding.

        At each step, heavily biases the next token of every still-valid
        choice, samples one token, and scores choices whose next token
        matches. Instruct models only.
        """
        if self.is_chat_model:
            raise NotImplementedError(
                "select/choices is not supported for chat models. "
                "Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
            )

        n_choices = len(choices)
        token_ids = [self.tokenizer.encode(x) for x in choices]
        scores = [0] * n_choices
        valid = [len(x) > 0 for x in token_ids]
        prompt_tokens = self.tokenizer.encode(s.text_)

        max_len = max([len(x) for x in token_ids])
        for step in range(max_len):
            # Build logit bias
            logit_bias = {}
            for i in range(n_choices):
                if valid[i]:
                    logit_bias[token_ids[i][step]] = 100

            # Call API
            ret = self.client.completions.create(
                model=self.model_name,
                prompt=prompt_tokens,
                logit_bias=logit_bias,
                max_tokens=1,
                temperature=temperature,
            )
            ret_str = ret.choices[0].text
            ret_token = self.tokenizer.encode(ret_str)[0]
            self.token_usage.prompt_tokens += ret.usage.prompt_tokens
            # Bug fix: this previously used "=" and overwrote the running
            # completion-token count instead of accumulating it like the
            # prompt_tokens line above (and every other call site).
            self.token_usage.completion_tokens += ret.usage.completion_tokens

            # TODO:
            # 1. return logits as the scores
            # 2. compute logits of the full choice
            # 3. consider chunk-based decoding

            # Update valid
            hit = False
            for i in range(n_choices):
                if valid[i]:
                    if step == len(token_ids[i]) - 1:
                        valid[i] = False

                    if ret_token == token_ids[i][step]:
                        scores[i] += 1
                        hit = True
                    else:
                        valid[i] = False
            assert hit

            if np.sum(valid) <= 1:
                break

            prompt_tokens.append(ret_token)

        decision = choices[np.argmax(scores)]
        return decision, scores, None, None
|
||||
|
||||
|
||||
def openai_completion(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Call the OpenAI completion API with retries; return the completion text.

    Retries transient API errors up to `retries` times, sleeping 5 seconds
    between attempts, and accumulates token counts into `token_usage`.
    For chat calls `prompt` is a message list; for instruct calls it may be
    a single prompt or a list/tuple of prompts (returning a list of texts).
    """
    for attempt_idx in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit None "stop"; drop it.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                response = client.chat.completions.create(messages=prompt, **kwargs)
                completion = response.choices[0].message.content
            else:
                response = client.completions.create(prompt=prompt, **kwargs)
                completion = (
                    [c.text for c in response.choices]
                    if isinstance(prompt, (list, tuple))
                    else response.choices[0].text
                )

            token_usage.prompt_tokens += response.usage.prompt_tokens
            token_usage.completion_tokens += response.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            # Transient failure: back off and retry; re-raise on the last try.
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt_idx == retries - 1:
                raise e
        except Exception as e:
            # Anything unexpected is fatal immediately.
            logger.error(f"RuntimeError {e}.")
            raise e

    return completion
|
||||
|
||||
|
||||
def openai_completion_stream(
    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
    """Stream a completion, yielding (text_chunk, meta_info) pairs.

    Retries transient API errors up to `retries` times with a 5-second
    pause, and accumulates token counts from the final usage chunk into
    `token_usage` (requested via stream_options include_usage).

    NOTE(review): a transient error raised mid-stream restarts the request
    after chunks were already yielded, so the consumer would see duplicated
    text — confirm whether mid-stream retry is intended.
    """
    for attempt in range(retries):
        try:
            if is_chat:
                # The chat endpoint rejects an explicit None "stop"; drop it.
                if "stop" in kwargs and kwargs["stop"] is None:
                    kwargs.pop("stop")
                generator = client.chat.completions.create(
                    messages=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    # With include_usage, the final chunk has no choices;
                    # it is skipped here but `ret` keeps it for usage below.
                    if len(ret.choices) == 0:
                        continue
                    try:
                        content = ret.choices[0].delta.content
                    except IndexError:
                        content = None
                    yield content or "", {}
            else:
                generator = client.completions.create(
                    prompt=prompt,
                    stream=True,
                    stream_options={"include_usage": True},
                    **kwargs,
                )
                for ret in generator:
                    if len(ret.choices) == 0:
                        continue
                    content = ret.choices[0].text
                    yield content or "", {}

            # `ret` is the last chunk seen; with include_usage it carries
            # the usage totals for the whole stream.
            token_usage.prompt_tokens += ret.usage.prompt_tokens
            token_usage.completion_tokens += ret.usage.completion_tokens
            break
        except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
            # Transient failure: back off and retry; re-raise on the last try.
            logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
            time.sleep(5)
            if attempt == retries - 1:
                raise e
        except Exception as e:
            # Anything unexpected is fatal immediately.
            logger.error(f"RuntimeError {e}.")
            raise e
|
||||
283
python/sglang/lang/backend/runtime_endpoint.py
Normal file
283
python/sglang/lang/backend/runtime_endpoint.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
from sglang.utils import http_request
|
||||
|
||||
|
||||
class RuntimeEndpoint(BaseBackend):
    """Backend that talks to a running sglang server over HTTP."""

    def __init__(
        self,
        base_url: str,
        auth_token: Optional[str] = None,
        api_key: Optional[str] = None,
        verify: Optional[str] = None,
    ):
        super().__init__()
        # The server supports concatenate_and_append request merging.
        self.support_concate_and_append = True

        self.base_url = base_url
        self.auth_token = auth_token
        self.api_key = api_key
        self.verify = verify

        res = http_request(
            self.base_url + "/get_model_info",
            **self._auth_kwargs(),
        )
        self._assert_success(res)
        self.model_info = res.json()

        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_path"]
        )

    def _auth_kwargs(self):
        """Authentication kwargs shared by every http_request call.

        Fix: flush_cache and get_server_args previously omitted `api_key`
        while every other endpoint sent it; routing all requests through
        this helper keeps the calls consistent.
        """
        return {
            "auth_token": self.auth_token,
            "api_key": self.api_key,
            "verify": self.verify,
        }

    def get_model_name(self):
        """Return the model path reported by the server."""
        return self.model_info["model_path"]

    def flush_cache(self):
        """Ask the server to drop its KV/prefix cache."""
        res = http_request(
            self.base_url + "/flush_cache",
            **self._auth_kwargs(),
        )
        self._assert_success(res)

    def get_server_args(self):
        """Return the server's launch arguments as a dict."""
        res = http_request(
            self.base_url + "/get_server_args",
            **self._auth_kwargs(),
        )
        self._assert_success(res)
        return res.json()

    def get_chat_template(self):
        """Return the chat template used to format conversations."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """Warm the server cache by generating 0 tokens for the prefix."""
        res = http_request(
            self.base_url + "/generate",
            json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
            **self._auth_kwargs(),
        )
        self._assert_success(res)

    def commit_lazy_operations(self, s: StreamExecutor):
        """Flush buffered text to the server with a 0-token generation."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            **self._auth_kwargs(),
        )
        self._assert_success(res)

    def fill_image(self, s: StreamExecutor):
        """Send the executor's pending image with a 0-token generation."""
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            **self._auth_kwargs(),
        )
        self._assert_success(res)

    def _build_generate_data(self, s: StreamExecutor, sampling_params: SglSamplingParams):
        """Build the /generate JSON payload shared by generate and generate_stream.

        Raises RuntimeError for dtypes other than None or int.
        """
        if sampling_params.dtype is None:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                    **sampling_params.to_srt_kwargs(),
                },
            }
        elif sampling_params.dtype in [int, "int"]:
            data = {
                "text": s.text_,
                "sampling_params": {
                    "skip_special_tokens": global_config.skip_special_tokens_in_output,
                    "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
                    "dtype": "int",
                    **sampling_params.to_srt_kwargs(),
                },
            }
        else:
            raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")

        # Forward logprob-related options only when explicitly set.
        for item in [
            "return_logprob",
            "logprob_start_len",
            "top_logprobs_num",
            "return_text_in_logprobs",
        ]:
            value = getattr(sampling_params, item, None)
            if value is not None:
                data[item] = value

        self._add_images(s, data)
        return data

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; return (text, meta_info)."""
        data = self._build_generate_data(s, sampling_params)

        res = http_request(
            self.base_url + "/generate",
            json=data,
            **self._auth_kwargs(),
        )
        self._assert_success(res)

        obj = res.json()
        comp = obj["text"]
        return comp, obj["meta_info"]

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming generation.

        Parses the server's server-sent-events stream; each event carries
        the full text so far, from which the new suffix is extracted.
        """
        data = self._build_generate_data(s, sampling_params)
        data["stream"] = True

        res = http_request(
            self.base_url + "/generate",
            json=data,
            stream=True,
            **self._auth_kwargs(),
        )
        self._assert_success(res)
        pos = 0

        for chunk in res.iter_lines(decode_unicode=False):
            chunk = chunk.decode("utf-8")
            if chunk and chunk.startswith("data:"):
                if chunk == "data: [DONE]":
                    break
                data = json.loads(chunk[5:].strip("\n"))
                chunk_text = data["text"][pos:]
                meta_info = data["meta_info"]
                pos += len(chunk_text)
                yield chunk_text, meta_info

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        temperature: float,
    ):
        """Pick the choice with the highest normalized prompt logprob.

        Only greedy selection is supported (temperature must be ~0).
        """
        assert temperature <= 1e-5

        # Cache common prefix
        data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            **self._auth_kwargs(),
        )
        self._assert_success(res)
        prompt_len = res.json()["meta_info"]["prompt_tokens"]

        # Compute logprob
        data = {
            "text": [s.text_ + c for c in choices],
            "sampling_params": {"max_new_tokens": 0},
            "return_logprob": True,
            "logprob_start_len": max(prompt_len - 2, 0),
        }
        self._add_images(s, data)
        res = http_request(
            self.base_url + "/generate",
            json=data,
            **self._auth_kwargs(),
        )
        self._assert_success(res)
        obj = res.json()
        normalized_prompt_logprobs = [
            r["meta_info"]["normalized_prompt_logprob"] for r in obj
        ]
        decision = choices[np.argmax(normalized_prompt_logprobs)]
        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]

        return (
            decision,
            normalized_prompt_logprobs,
            prefill_token_logprobs,
            decode_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
        """Merge source requests into the destination request server-side."""
        res = http_request(
            self.base_url + "/concate_and_append_request",
            json={"src_rids": src_rids, "dst_rid": dst_rid},
            **self._auth_kwargs(),
        )
        self._assert_success(res)

    def _add_images(self, s: StreamExecutor, data):
        """Attach the executor's image (at most one) to the request payload."""
        if s.images_:
            assert len(s.images_) == 1, "Only support one image."
            data["image_data"] = s.images_[0][1]

    def _assert_success(self, res):
        """Raise RuntimeError with the server's response body on non-200."""
        if res.status_code != 200:
            raise RuntimeError(res.json())
|
||||
149
python/sglang/lang/backend/vertexai.py
Normal file
149
python/sglang/lang/backend/vertexai.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
from sglang.lang.interpreter import StreamExecutor
|
||||
from sglang.lang.ir import SglSamplingParams
|
||||
|
||||
try:
|
||||
import vertexai
|
||||
from vertexai.preview.generative_models import (
|
||||
GenerationConfig,
|
||||
GenerativeModel,
|
||||
Image,
|
||||
)
|
||||
except ImportError as e:
|
||||
GenerativeModel = e
|
||||
|
||||
|
||||
class VertexAI(BaseBackend):
    """Backend for Google VertexAI generative models.

    Requires the GCP_PROJECT_ID environment variable (GCP_LOCATION is
    optional) to initialize the vertexai SDK.
    """

    def __init__(self, model_name, safety_settings=None):
        super().__init__()

        # `GenerativeModel` is bound to the ImportError when the package is
        # missing; surface that error lazily, only when actually used.
        if isinstance(GenerativeModel, Exception):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
        self.chat_template = get_chat_template("default")
        self.safety_settings = safety_settings

    def get_chat_template(self):
        """Return the chat template used to format conversations."""
        return self.chat_template

    def generate(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Run one non-streaming generation; return (text, meta_info)."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )

        comp = ret.text

        return comp, {}

    def generate_stream(
        self,
        s: StreamExecutor,
        sampling_params: SglSamplingParams,
    ):
        """Yield (text_chunk, meta_info) pairs from a streaming generation."""
        if s.messages_:
            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
            safety_settings=self.safety_settings,
        )
        for ret in generator:
            yield ret.text, {}

    def text_to_vertexai_input(self, text, images):
        """Interleave text segments and images into a VertexAI content list.

        Assumes `text` contains one image token per entry in `images`.
        NOTE(review): the variable name suggests base64 text but
        Image.from_bytes expects raw bytes — confirm against callers.
        """
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
        for image_path, image_base64_data in images:
            text_seg = text_segs.pop(0)
            if text_seg != "":
                input.append(text_seg)
            input.append(Image.from_bytes(image_base64_data))
        text_seg = text_segs.pop(0)
        if text_seg != "":
            input.append(text_seg)
        return input

    def messages_to_vertexai_input(self, messages):
        """Convert OpenAI-style messages to the VertexAI message format.

        System messages (unsupported by VertexAI) are emulated with a
        user/model exchange; image parts are attached as inline JPEG data.
        """
        vertexai_message = []
        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
            else:
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
                warnings.warn("Warning: system prompt is not supported in VertexAI.")
                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
                    }
                )
                continue
            if msg["role"] == "user":
                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
            else:
                # Bug fix: an unrecognized role previously fell through with
                # `vertexai_msg` unbound (NameError) or stale from the
                # previous iteration (silently duplicating a message).
                raise ValueError(f"Unsupported role: {msg['role']}")

            # images
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
                                "mime_type": "image/jpeg",
                            }
                        }
                    )

            vertexai_message.append(vertexai_msg)
        return vertexai_message
|
||||
@@ -3,8 +3,8 @@
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.global_config import global_config
|
||||
from sglang.lang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.interpreter import ProgramState, ProgramStateGroup
|
||||
from sglang.lang.ir import (
|
||||
SglArgument,
|
||||
|
||||
Reference in New Issue
Block a user