Update Readme (#660)

Author: Ying Sheng
Date: 2024-07-19 09:54:01 -07:00
Committed by: GitHub
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Parent: dc4e4a6acc
Commit: 51fda1439f

25 changed files with 200 additions and 185 deletions

python/sglang/lang/backend/anthropic.py

@@ -0,0 +1,77 @@
from typing import List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import anthropic
except ImportError as e:
anthropic = e
class Anthropic(BaseBackend):
def __init__(self, model_name, *args, **kwargs):
super().__init__()
if isinstance(anthropic, Exception):
raise anthropic
self.model_name = model_name
self.chat_template = get_chat_template("claude")
self.client = anthropic.Anthropic(*args, **kwargs)
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if s.messages_:
messages = s.messages_
else:
messages = [{"role": "user", "content": s.text_}]
if messages and messages[0]["role"] == "system":
system = messages.pop(0)["content"]
else:
system = ""
ret = self.client.messages.create(
model=self.model_name,
system=system,
messages=messages,
**sampling_params.to_anthropic_kwargs(),
)
comp = ret.content[0].text
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if s.messages_:
messages = s.messages_
else:
messages = [{"role": "user", "content": s.text_}]
if messages and messages[0]["role"] == "system":
system = messages.pop(0)["content"]
else:
system = ""
with self.client.messages.stream(
model=self.model_name,
system=system,
messages=messages,
**sampling_params.to_anthropic_kwargs(),
) as stream:
for text in stream.text_stream:
yield text, {}
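
A minimal usage sketch of this backend through the sglang frontend (assuming the anthropic package is installed, ANTHROPIC_API_KEY is set in the environment, and the model name below is available to your account):

import sglang as sgl

@sgl.function
def qa(s, question):
    # With the "claude" template, the leading system turn is popped into
    # the `system` kwarg by the generate() method above.
    s += sgl.system("You are a concise assistant.")
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=256))

sgl.set_default_backend(sgl.Anthropic("claude-3-haiku-20240307"))
state = qa.run(question="What is the capital of France?")
print(state["answer"])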

python/sglang/lang/backend/base_backend.py

@@ -0,0 +1,80 @@
from typing import Callable, List, Optional, Union
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
class BaseBackend:
def __init__(self) -> None:
self.support_concate_and_append = False
self.chat_template = get_chat_template("default")
def get_model_name(self):
raise NotImplementedError()
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
pass
def uncache_prefix(self, rid: str):
pass
def end_request(self, rid: Union[str, List[str]]):
pass
def begin_program(self, s: StreamExecutor):
pass
def end_program(self, s: Union[StreamExecutor, List[StreamExecutor]]):
pass
def commit_lazy_operations(self, s: StreamExecutor):
pass
def fork_program(
self,
src: StreamExecutor,
dst: List[StreamExecutor],
position_ids_offset: Optional[List[int]] = None,
):
pass
def fill_image(self, s: StreamExecutor):
pass
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
raise NotImplementedError()
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
raise NotImplementedError()
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
raise NotImplementedError()
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
raise NotImplementedError()
def shutdown(self):
pass
def flush_cache(self):
pass
def get_server_args(self):
pass
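
For illustration, a hypothetical minimal subclass (not part of this commit): only generate and generate_stream are strictly required for basic text programs; every other hook is an optional no-op.

from sglang.lang.backend.base_backend import BaseBackend

class EchoBackend(BaseBackend):
    """Hypothetical backend that echoes the tail of the prompt back."""

    def generate(self, s, sampling_params):
        # Backends return a (completion_text, meta_info) pair.
        return "echo: " + s.text_[-40:], {}

    def generate_stream(self, s, sampling_params):
        comp, meta = self.generate(s, sampling_params)
        for piece in comp.split(" "):
            yield piece + " ", meta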

python/sglang/lang/backend/litellm.py

@@ -0,0 +1,90 @@
from typing import Mapping, Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import litellm
except ImportError as e:
litellm = e
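# Note: when the import fails, `litellm` is the captured ImportError, and the
# assignment below still succeeds on that object; this keeps the `max_retries`
# default in LiteLLM.__init__ evaluable even without litellm installed.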
litellm.num_retries = 1
class LiteLLM(BaseBackend):
def __init__(
self,
model_name,
chat_template=None,
api_key=None,
organization: Optional[str] = None,
base_url: Optional[str] = None,
timeout: Optional[float] = 600,
max_retries: Optional[int] = litellm.num_retries,
default_headers: Optional[Mapping[str, str]] = None,
):
super().__init__()
if isinstance(litellm, Exception):
raise litellm
self.model_name = model_name
self.chat_template = chat_template or get_chat_template_by_model_path(
model_name
)
self.client_params = {
"api_key": api_key,
"organization": organization,
"base_url": base_url,
"timeout": timeout,
"max_retries": max_retries,
"default_headers": default_headers,
}
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if s.messages_:
messages = s.messages_
else:
messages = [{"role": "user", "content": s.text_}]
ret = litellm.completion(
model=self.model_name,
messages=messages,
**self.client_params,
            **sampling_params.to_litellm_kwargs(),
)
comp = ret.choices[0].message.content
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if s.messages_:
messages = s.messages_
else:
messages = [{"role": "user", "content": s.text_}]
ret = litellm.completion(
model=self.model_name,
messages=messages,
stream=True,
**self.client_params,
**sampling_params.to_litellm_kwargs(),
)
for chunk in ret:
text = chunk.choices[0].delta.content
if text is not None:
yield text, {}
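
A usage sketch (assuming litellm is installed and the provider key implied by the model string, e.g. OPENAI_API_KEY, is set):

import sglang as sgl

@sgl.function
def translate(s, text):
    s += sgl.user("Translate to French: " + text)
    s += sgl.assistant(sgl.gen("fr", max_tokens=128))

sgl.set_default_backend(sgl.LiteLLM("gpt-3.5-turbo"))
print(translate.run(text="good morning")["fr"])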

python/sglang/lang/backend/openai.py

@@ -0,0 +1,438 @@
import dataclasses
import logging
import time
import warnings
from typing import Callable, List, Optional, Union
import numpy as np
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import openai
import tiktoken
except ImportError as e:
openai = tiktoken = e
logger = logging.getLogger("openai")
def create_logit_bias_int(tokenizer):
"""Get logit bias for integer numbers."""
int_token_ids = []
tokens = tokenizer._mergeable_ranks
for token, token_id in tokens.items():
s = tokenizer.decode([token_id])
if all([c.isdigit() for c in s]) or s in [" "]:
int_token_ids.append(token_id)
if len(int_token_ids) >= 300: # OpenAI API limit
break
special_tokens = tokenizer._special_tokens
mask = {t: 100 for t in int_token_ids[:299]}
mask[special_tokens["<|endoftext|>"]] = 100
return mask
INSTRUCT_MODEL_NAMES = [
"gpt-3.5-turbo-instruct",
]
@dataclasses.dataclass
class TokenUsage:
prompt_tokens: int
completion_tokens: int
def reset(self):
self.prompt_tokens = self.completion_tokens = 0
class OpenAI(BaseBackend):
def __init__(
self,
model_name: str,
is_chat_model: Optional[bool] = None,
chat_template: Optional[ChatTemplate] = None,
is_azure: bool = False,
*args,
**kwargs,
):
super().__init__()
if isinstance(openai, Exception):
raise openai
if is_azure:
self.client = openai.AzureOpenAI(*args, **kwargs)
else:
self.client = openai.OpenAI(*args, **kwargs)
self.model_name = model_name
try:
self.tokenizer = tiktoken.encoding_for_model(model_name)
except KeyError:
self.tokenizer = tiktoken.get_encoding("cl100k_base")
self.logit_bias_int = create_logit_bias_int(self.tokenizer)
self.chat_template = chat_template or get_chat_template_by_model_path(
model_name
)
if is_chat_model is not None:
self.is_chat_model = is_chat_model
else:
if model_name in INSTRUCT_MODEL_NAMES:
self.is_chat_model = False
else:
self.is_chat_model = True
self.chat_prefix = self.chat_template.role_prefix_and_suffix["assistant"][0]
# Usage
self.token_usage = TokenUsage(0, 0)
# API speculative execution
# TODO(ying): This does not support multi-threading (run_batch)
self.spec_kwargs = {}
self.spec_format = []
self.spec_max_num_tries = 3
def get_chat_template(self):
return self.chat_template
def _prepare_spec_execution(
self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
if key in ["stop"]:
continue
if key in ["max_tokens"]:
warnings.warn(
"The parameter max_tokens will be overwritten by speculated number of tokens."
)
continue
if key not in self.spec_kwargs:
self.spec_kwargs[key] = value
else:
                assert (
                    value == self.spec_kwargs[key]
                ), "Sampling parameters should be consistent when API speculative execution is enabled."
self.spec_format.append(
{"text": "", "stop": params["stop"], "name": spec_var_name}
)
return "", {}
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
        spec_var_name: Optional[str] = None,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if s.num_api_spec_tokens is None:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported if api speculative execution is off. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant. "
"Example of adding api speculative execution: @function(num_api_spec_tokens=128)."
)
prompt = s.messages_
else:
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
)
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
elif sampling_params.dtype in [str, "str", "string"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_ + '"',
stop='"',
**kwargs,
)
comp = '"' + comp + '"'
elif sampling_params.dtype in [int, "int"]:
assert (
not self.is_chat_model
), "constrained type not supported on chat model"
kwargs = sampling_params.to_openai_kwargs()
kwargs.pop("stop")
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.text_,
logit_bias=self.logit_bias_int,
stop=[" "],
**kwargs,
)
else:
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
return comp, {}
def spec_fill(self, value: str):
assert self.is_chat_model
self.spec_format.append({"text": value, "stop": None, "name": None})
def spec_pattern_match(self, comp):
for i, term in enumerate(self.spec_format):
text = term["text"]
if text != "":
if comp.startswith(text):
comp = comp[len(text) :]
else:
return False
else:
pos = comp.find(term["stop"])
if pos != -1:
term["text"] = comp[:pos]
comp = comp[pos:]
else:
if i == len(self.spec_format) - 1:
term["text"] = comp
else:
return False
return True
def role_end_generate(
self,
s: StreamExecutor,
):
if s.num_api_spec_tokens is None or not s.text_.endswith(self.chat_prefix):
return
comp = ""
if not all(x["name"] is None for x in self.spec_format):
# TODO(ying): throw errors or warnings
for i in range(self.spec_max_num_tries):
comp = openai_completion(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=s.messages_,
**self.spec_kwargs,
)
if self.spec_pattern_match(comp):
break
for term in self.spec_format:
s.text_ += term["text"]
name = term["name"]
if name is not None:
s.variables[name] = term["text"]
s.meta_info[name] = {}
s.variable_event[name].set()
self.spec_kwargs = {}
self.spec_format = []
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
if self.is_chat_model:
if not s.text_.endswith(self.chat_prefix):
raise RuntimeError(
"This use case is not supported. "
"For OpenAI chat models, sgl.gen must be right after sgl.assistant"
)
prompt = s.messages_
else:
prompt = s.text_
kwargs = sampling_params.to_openai_kwargs()
generator = openai_completion_stream(
client=self.client,
token_usage=self.token_usage,
is_chat=self.is_chat_model,
model=self.model_name,
prompt=prompt,
**kwargs,
)
return generator
else:
raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
if self.is_chat_model:
raise NotImplementedError(
"select/choices is not supported for chat models. "
"Please try to use a non-chat model such as gpt-3.5-turbo-instruct"
)
n_choices = len(choices)
token_ids = [self.tokenizer.encode(x) for x in choices]
scores = [0] * n_choices
valid = [len(x) > 0 for x in token_ids]
prompt_tokens = self.tokenizer.encode(s.text_)
max_len = max([len(x) for x in token_ids])
for step in range(max_len):
# Build logit bias
logit_bias = {}
for i in range(n_choices):
if valid[i]:
logit_bias[token_ids[i][step]] = 100
# Call API
ret = self.client.completions.create(
model=self.model_name,
prompt=prompt_tokens,
logit_bias=logit_bias,
max_tokens=1,
temperature=temperature,
)
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
            self.token_usage.completion_tokens += ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
# 2. compute logits of the full choice
# 3. consider chunk-based decoding
# Update valid
hit = False
for i in range(n_choices):
if valid[i]:
if step == len(token_ids[i]) - 1:
valid[i] = False
if ret_token == token_ids[i][step]:
scores[i] += 1
hit = True
else:
valid[i] = False
assert hit
if np.sum(valid) <= 1:
break
prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)]
return decision, scores, None, None
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
ret = client.chat.completions.create(messages=prompt, **kwargs)
comp = ret.choices[0].message.content
else:
ret = client.completions.create(prompt=prompt, **kwargs)
if isinstance(prompt, (list, tuple)):
comp = [c.text for c in ret.choices]
else:
comp = ret.choices[0].text
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
time.sleep(5)
if attempt == retries - 1:
raise e
except Exception as e:
logger.error(f"RuntimeError {e}.")
raise e
return comp
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
continue
try:
content = ret.choices[0].delta.content
except IndexError:
content = None
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
continue
content = ret.choices[0].text
yield content or "", {}
token_usage.prompt_tokens += ret.usage.prompt_tokens
token_usage.completion_tokens += ret.usage.completion_tokens
break
except (openai.APIError, openai.APIConnectionError, openai.RateLimitError) as e:
logger.error(f"OpenAI Error: {e}. Waiting 5 seconds...")
time.sleep(5)
if attempt == retries - 1:
raise e
except Exception as e:
logger.error(f"RuntimeError {e}.")
raise e
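
A usage sketch of the API speculative execution path, following the hint in the RuntimeError message above (the model name is an assumption): the constant strings inside the assistant turn go through spec_fill, each sgl.gen registers a stop pattern via _prepare_spec_execution, and role_end_generate issues one completion of up to num_api_spec_tokens tokens that spec_pattern_match splits back into the named variables.

import sglang as sgl

@sgl.function(num_api_spec_tokens=128)
def profile(s, person):
    s += sgl.user("Give a name line and a job line for " + person + ".")
    # Both gens in this single assistant turn are filled from one API call.
    s += sgl.assistant(
        "Name: " + sgl.gen("name", stop="\n") + "\nJob: " + sgl.gen("job", stop="\n")
    )

sgl.set_default_backend(sgl.OpenAI("gpt-3.5-turbo"))
state = profile.run(person="Ada Lovelace")
print(state["name"], "/", state["job"])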

python/sglang/lang/backend/runtime_endpoint.py

@@ -0,0 +1,283 @@
import json
from typing import List, Optional
import numpy as np
from sglang.global_config import global_config
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
def __init__(
self,
base_url: str,
auth_token: Optional[str] = None,
api_key: Optional[str] = None,
verify: Optional[str] = None,
):
super().__init__()
self.support_concate_and_append = True
self.base_url = base_url
self.auth_token = auth_token
self.api_key = api_key
self.verify = verify
res = http_request(
self.base_url + "/get_model_info",
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"]
)
def get_model_name(self):
return self.model_info["model_path"]
def flush_cache(self):
res = http_request(
self.base_url + "/flush_cache",
auth_token=self.auth_token,
verify=self.verify,
)
self._assert_success(res)
def get_server_args(self):
res = http_request(
self.base_url + "/get_server_args",
auth_token=self.auth_token,
verify=self.verify,
)
self._assert_success(res)
return res.json()
def get_chat_template(self):
return self.chat_template
def cache_prefix(self, prefix_str: str):
res = http_request(
self.base_url + "/generate",
json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in [
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
comp = obj["text"]
return comp, obj["meta_info"]
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if sampling_params.dtype is None:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
**sampling_params.to_srt_kwargs(),
},
}
elif sampling_params.dtype in [int, "int"]:
data = {
"text": s.text_,
"sampling_params": {
"skip_special_tokens": global_config.skip_special_tokens_in_output,
"spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
"dtype": "int",
**sampling_params.to_srt_kwargs(),
},
}
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in [
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
data["stream"] = True
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
stream=True,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
pos = 0
for chunk in res.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
chunk_text = data["text"][pos:]
meta_info = data["meta_info"]
pos += len(chunk_text)
yield chunk_text, meta_info
def select(
self,
s: StreamExecutor,
choices: List[str],
temperature: float,
):
assert temperature <= 1e-5
# Cache common prefix
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
prompt_len = res.json()["meta_info"]["prompt_tokens"]
# Compute logprob
data = {
"text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0},
"return_logprob": True,
"logprob_start_len": max(prompt_len - 2, 0),
}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
obj = res.json()
normalized_prompt_logprobs = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
return (
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
)
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request(
self.base_url + "/concate_and_append_request",
json={"src_rids": src_rids, "dst_rid": dst_rid},
auth_token=self.auth_token,
api_key=self.api_key,
verify=self.verify,
)
self._assert_success(res)
def _add_images(self, s: StreamExecutor, data):
if s.images_:
assert len(s.images_) == 1, "Only support one image."
data["image_data"] = s.images_[0][1]
def _assert_success(self, res):
if res.status_code != 200:
raise RuntimeError(res.json())
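
A usage sketch against a local SGLang server (assuming one is already running, e.g. python -m sglang.launch_server --model-path <model> --port 30000); gen with choices= is dispatched to the select() method above:

import sglang as sgl

@sgl.function
def sentiment(s, review):
    s += "Review: " + review + "\nSentiment: "
    s += sgl.gen("label", choices=["positive", "negative"])

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
print(sentiment.run(review="Great food, slow service.")["label"])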

python/sglang/lang/backend/vertexai.py

@@ -0,0 +1,149 @@
import os
import warnings
from typing import Optional
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
try:
import vertexai
from vertexai.preview.generative_models import (
GenerationConfig,
GenerativeModel,
Image,
)
except ImportError as e:
GenerativeModel = e
class VertexAI(BaseBackend):
def __init__(self, model_name, safety_settings=None):
super().__init__()
if isinstance(GenerativeModel, Exception):
raise GenerativeModel
project_id = os.environ["GCP_PROJECT_ID"]
location = os.environ.get("GCP_LOCATION")
vertexai.init(project=project_id, location=location)
self.model_name = model_name
self.chat_template = get_chat_template("default")
self.safety_settings = safety_settings
def get_chat_template(self):
return self.chat_template
def generate(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if s.messages_:
prompt = self.messages_to_vertexai_input(s.messages_)
else:
# single-turn
prompt = (
self.text_to_vertexai_input(s.text_, s.cur_images)
if s.cur_images
else s.text_
)
ret = GenerativeModel(self.model_name).generate_content(
prompt,
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
safety_settings=self.safety_settings,
)
comp = ret.text
return comp, {}
def generate_stream(
self,
s: StreamExecutor,
sampling_params: SglSamplingParams,
):
if s.messages_:
prompt = self.messages_to_vertexai_input(s.messages_)
else:
# single-turn
prompt = (
self.text_to_vertexai_input(s.text_, s.cur_images)
if s.cur_images
else s.text_
)
generator = GenerativeModel(self.model_name).generate_content(
prompt,
stream=True,
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
safety_settings=self.safety_settings,
)
for ret in generator:
yield ret.text, {}
def text_to_vertexai_input(self, text, images):
input = []
# split with image token
text_segs = text.split(self.chat_template.image_token)
for image_path, image_base64_data in images:
text_seg = text_segs.pop(0)
if text_seg != "":
input.append(text_seg)
input.append(Image.from_bytes(image_base64_data))
text_seg = text_segs.pop(0)
if text_seg != "":
input.append(text_seg)
return input
def messages_to_vertexai_input(self, messages):
vertexai_message = []
# from openai message format to vertexai message format
for msg in messages:
if isinstance(msg["content"], str):
text = msg["content"]
else:
text = msg["content"][0]["text"]
if msg["role"] == "system":
warnings.warn("Warning: system prompt is not supported in VertexAI.")
vertexai_message.append(
{
"role": "user",
"parts": [{"text": "System prompt: " + text}],
}
)
vertexai_message.append(
{
"role": "model",
"parts": [{"text": "Understood."}],
}
)
continue
if msg["role"] == "user":
vertexai_msg = {
"role": "user",
"parts": [{"text": text}],
}
elif msg["role"] == "assistant":
vertexai_msg = {
"role": "model",
"parts": [{"text": text}],
}
# images
if isinstance(msg["content"], list) and len(msg["content"]) > 1:
for image in msg["content"][1:]:
assert image["type"] == "image_url"
vertexai_msg["parts"].append(
{
"inline_data": {
"data": image["image_url"]["url"].split(",")[1],
"mime_type": "image/jpeg",
}
}
)
vertexai_message.append(vertexai_msg)
return vertexai_message
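
A usage sketch; as in __init__ above, GCP_PROJECT_ID must be set (GCP_LOCATION is optional) and Google application-default credentials configured. The model name is an assumption:

import sglang as sgl

@sgl.function
def qa(s, question):
    s += sgl.user(question)
    s += sgl.assistant(sgl.gen("answer", max_tokens=128))

sgl.set_default_backend(sgl.VertexAI("gemini-pro"))
print(qa.run(question="Name three uses of Python.")["answer"])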

python/sglang/lang/tracer.py

@@ -3,8 +3,8 @@
 import uuid
 from typing import Any, Callable, Dict, List, Optional, Union
-from sglang.backend.base_backend import BaseBackend
 from sglang.global_config import global_config
+from sglang.lang.backend.base_backend import BaseBackend
 from sglang.lang.interpreter import ProgramState, ProgramStateGroup
 from sglang.lang.ir import (
     SglArgument,