Improve docs & Rename Gemini -> VertexAI (#19)

This commit is contained in:
Lianmin Zheng
2024-01-17 02:54:41 -08:00
committed by GitHub
parent fd7c479239
commit bf51ddc6e5
13 changed files with 56 additions and 583 deletions

View File

@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
from sglang.backend.gemini import Gemini
from sglang.backend.vertexai import VertexAI
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config

View File

@@ -1,349 +0,0 @@
import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union
import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
StoppingCriteria,
StoppingCriteriaList,
)
from transformersgl.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
class StopReason(Enum):
    """Why a generation loop terminated."""

    EOS_TOKEN = auto()  # the model emitted an end-of-sequence token
    STOP_STR = auto()   # a caller-supplied stop string appeared in the output
    LENGTH = auto()     # the max_new_tokens budget was exhausted
def load_model(
    model_name: str,
    device,
    num_gpus,
    max_gpu_memory,
    model_kwargs=None,
    tokenizer_kwargs=None,
):
    """Load a HuggingFace causal LM and its tokenizer.

    Args:
        model_name: model path / hub id, forwarded to ``from_pretrained``.
        device: "cuda" or "cpu"; any other value raises ValueError.
        num_gpus: number of GPUs; values != 1 enable accelerate-style
            ``device_map`` sharding instead of a plain ``model.to(device)``.
        max_gpu_memory: per-GPU memory cap used directly as an accelerate
            ``max_memory`` value (presumably a string like "13GiB" — TODO
            confirm against callers); None means measure free memory and
            budget 85% of it per GPU.
        model_kwargs: extra kwargs forwarded to ``AutoModelForCausalLM.from_pretrained``.
        tokenizer_kwargs: extra kwargs forwarded to ``AutoTokenizer.from_pretrained``.

    Returns:
        ``(model, tokenizer)``. With ``num_gpus == 1`` the model is moved to
        ``device`` and put in eval mode; otherwise placement is left to
        ``device_map``.
    """
    model_kwargs = model_kwargs or {}
    tokenizer_kwargs = tokenizer_kwargs or {}
    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16  # half precision on GPU
        if num_gpus != 1:
            model_kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # No explicit cap: switch to sequential placement with measured
                # per-GPU budgets so GPUs with unequal free VRAM are handled.
                model_kwargs[
                    "device_map"
                ] = "sequential"  # This is important for GPUs that do not have the same VRAM sizes
                available_gpu_memory = [
                    get_available_gpu_memory(i, False) for i in range(num_gpus)
                ]
                # Leave ~15% headroom on each GPU for activations / KV cache.
                model_kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                # Apply the caller-supplied cap uniformly to every GPU.
                model_kwargs["max_memory"] = {
                    i: max_gpu_memory for i in range(num_gpus)
                }
    elif device == "cpu":
        model_kwargs["torch_dtype"] = torch.float32  # fp16 is not practical on CPU
    else:
        raise ValueError(f"Invalid device: {device}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
    if num_gpus == 1:
        # device_map is unset in the single-GPU/CPU path, so place manually.
        model.to(device).eval()
    return model, tokenizer
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the HF logits processors implied by the sampling parameters.

    Each knob is skipped when it sits at its no-op value, so the returned
    list contains only processors that actually reshape the distribution.
    """
    processors = LogitsProcessorList()
    # TemperatureLogitsWarper rejects 0.0, and 1.0 is a no-op — skip both.
    temperature_is_active = temperature >= 1e-5 and temperature != 1.0
    if temperature_is_active:
        processors.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processors.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processors.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processors.append(TopKLogitsWarper(top_k))
    return processors
@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
    """Boolean mask over the vocabulary for token healing.

    Entry ``t`` is True when vocab token ``t``'s string does NOT start with
    the string of ``prompt_last_token`` — i.e. tokens that cannot extend the
    prompt's trailing token and must be suppressed. Cached per (tokenizer,
    last-token) pair since the vocab scan is O(|vocab|).
    """
    prefix = tokenizer.convert_ids_to_tokens(prompt_last_token)
    mask = torch.zeros(len(tokenizer), dtype=bool)
    for piece, token_id in tokenizer.get_vocab().items():
        mask[token_id] = not piece.startswith(prefix)
    return mask
@functools.lru_cache
def get_int_token_mask(tokenizer):
    """Boolean mask that disallows every vocab token not usable in an integer.

    A token survives (mask False) when, after cleanup, it is all digits,
    empty, or a comma; the EOS token is always allowed so generation can
    terminate. Cached per tokenizer since the vocab scan is O(|vocab|).
    """
    mask = torch.zeros(len(tokenizer), dtype=bool)
    for piece, token_id in tokenizer.get_vocab().items():
        # NOTE(review): replacing "" with "" is a no-op — this line likely
        # lost a sentencepiece marker character (e.g. "▁") when the file was
        # rendered; preserved byte-for-byte pending confirmation.
        cleaned = piece.replace("", "").strip()
        mask[token_id] = not (cleaned.isdigit() or len(cleaned) == 0 or cleaned == ",")
    mask[tokenizer.eos_token_id] = 0
    return mask
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens,
    stop: List[str],
    temperature,
    top_p,
    token_healing,
    logit_mask=None,
):
    """Generate a completion for ``prompt`` with a HF causal LM, one token at a time.

    Args:
        model: a transformers causal LM, called directly for logits with a KV cache.
        tokenizer: the matching tokenizer.
        prompt: text to complete.
        max_new_tokens: decoding budget.
        stop: stop strings; output is truncated before the first match.
        temperature, top_p: sampling knobs; ``temperature < 1e-5`` or
            ``top_p < 1e-8`` switches to greedy argmax decoding.
        token_healing: if True, drop the last prompt token and regenerate it
            restricted to vocab entries that extend its string.
        logit_mask: optional boolean mask of token ids to suppress at every step.

    Returns:
        The generated text only (prompt characters sliced off the front).
        NOTE(review): the slicing assumes ``tokenizer.decode`` reproduces the
        prompt's characters exactly — confirm for tokenizers that normalize.
    """
    logits_processor = prepare_logits_processor(
        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
    )
    device = model.device
    input_ids = tokenizer.encode(prompt)
    output_ids = list(input_ids)
    prompt_len = len(prompt)  # character length, used to slice decoded text
    # Resolve stop
    stop_token_ids = [tokenizer.eos_token_id]
    # Token healing: re-sample the last prompt token, constrained to tokens
    # whose string extends the dropped one.
    token_healing = token_healing and len(input_ids) > 0
    if token_healing:
        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
        del output_ids[-1]
    # Generate
    past_key_values = None
    stop_reason = None
    # BUG FIX: output_str used to be assigned only after a non-EOS token, so
    # max_new_tokens == 0 or an immediate EOS raised NameError at the final
    # return. Initializing it to the prompt makes those cases return "".
    output_str = prompt
    for i in range(max_new_tokens):
        # Forward
        if i == 0:  # prefill: run the whole prompt to build the KV cache
            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
        else:  # decoding: feed only the last token, reuse the cache
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        logits = out.logits
        past_key_values = out.past_key_values
        # Logit mask: -1e4 effectively zeroes these tokens after softmax.
        if token_healing and i == 0:
            logits[0, -1, token_healing_mask] = -1e4
        if logit_mask is not None:
            logits[0, -1, logit_mask] = -1e4
        # Sample next token
        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)
        # Stop condition: EOS token first, then stop strings on decoded text.
        if token in stop_token_ids:
            stop_reason = StopReason.EOS_TOKEN
            break
        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop:
            pos = output_str[prompt_len:].find(stop_str)
            if pos != -1:
                stop_reason = StopReason.STOP_STR
                output_str = output_str[: prompt_len + pos]
                break
        if stop_reason:
            break
    return output_str[prompt_len:]
class HuggingFaceTransformers(BaseBackend):
    """SGLang backend that runs a local HuggingFace transformers model.

    Free-form generation goes through this module's ``generate_stream``;
    ``select`` scores each candidate continuation by its mean token
    log-probability under the model.
    """

    def __init__(
        self,
        model_name,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        model_kwargs=None,
        tokenizer_kwargs=None,
    ):
        # All loading/placement logic lives in load_model (same module).
        self.model_name = model_name
        self.device = device
        self.model, self.tokenizer = load_model(
            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
        )
        # Chat template is inferred from the model path.
        self.chat_template = get_chat_template_by_model_path(model_name)

    def get_chat_template(self):
        return self.chat_template

    # The following lifecycle hooks are no-ops for this backend: there is no
    # server-side prefix cache or per-request state to manage here.
    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: str):
        pass

    def begin_program(self, s: ProgramState):
        pass

    def end_program(self, s: ProgramState):
        pass

    def fill(self, s: ProgramState, text: str):
        # Returning False signals that the fill was not handled by the backend.
        return False

    def generate_internal(
        self,
        prompt: str,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        """Run one completion, specializing decoding by ``dtype``.

        dtype None: free-form text with token healing enabled.
        dtype str/"str"/"string": generate inside double quotes (stop at the
            closing quote) and return the value wrapped in quotes.
        dtype int/"int": restrict the vocabulary to digit-like tokens via
            ``get_int_token_mask`` and stop on whitespace/comma.

        NOTE(review): an unrecognized dtype skips every branch and raises
        NameError on ``comp`` at the return — presumably unreachable from
        callers, but an explicit ValueError would be safer.
        """
        if dtype is None:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                token_healing=True,
            )
        elif dtype in [str, "str", "string"]:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt + '"',
                max_new_tokens=max_tokens,
                stop=['"'],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
            )
            comp = '"' + comp + '"'
        elif dtype in [int, "int"]:
            logit_mask = get_int_token_mask(self.tokenizer)
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop + [" ", ","],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
                logit_mask=logit_mask,
            )
        return comp

    def generate(
        self,
        s: ProgramState,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        """Generate a completion for the program state's current text."""
        prompt = s.text
        comp = self.generate_internal(
            prompt, max_tokens, stop, temperature, top_p, dtype
        )
        return comp

    def parallel_generate(
        self,
        s: ProgramState,
        prefixes: List[str],
        join_func: Callable,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        """Generate one completion per prefix, then join the results.

        Despite the name, the branches run sequentially on the local model.
        Returns ``(join_func(prefix+completion pairs), completions)``.
        """
        prompt = s.text
        parallel_prompts = [prompt + prefix for prefix in prefixes]
        comps = []
        for i in range(len(parallel_prompts)):
            comps.append(
                self.generate_internal(
                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
                )
            )
        joined = join_func([p + c for p, c in zip(prefixes, comps)])
        return joined, comps

    @torch.inference_mode()
    def select(
        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
    ):
        """Pick the choice whose full sequence has the highest mean token logprob.

        Returns ``(decision, scores)`` where ``scores[i]`` is the mean
        log-probability over ALL positions of prompt+choice i (not only the
        choice's tokens — note the arange spans the whole sequence).
        ``temperature`` and ``top_p`` are accepted but unused by this scoring.
        """
        loss_fct = torch.nn.CrossEntropyLoss()  # NOTE(review): unused; paired with the commented-out scoring line below
        prompt = s.text
        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
        prompt_choices = [prompt + choice for choice in choices]
        scores = []
        for i in range(len(choices)):
            choice_ids = self.tokenizer.encode(
                prompt_choices[i], return_tensors="pt"
            ).to(self.model.device)
            logits = self.model(choice_ids).logits
            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
            logprobs = torch.log(torch.softmax(logits, dim=-1))
            # Position t's logits predict token t+1: gather each next-token's
            # log-probability across the sequence and average.
            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
            idx2 = choice_ids[0, 1:]
            selected_logprobs = logprobs[0, idx1, idx2]
            score = selected_logprobs.mean().item()
            scores.append(score)
        decision = choices[np.argmax(scores)]
        return decision, scores

View File

@@ -1,190 +0,0 @@
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request
class TGI(BaseBackend):
    """SGLang backend for a HuggingFace Text Generation Inference server.

    Talks to the TGI HTTP API (``/info`` for model metadata, ``/generate``
    for completions). Typed outputs (int, choice selection) are implemented
    by retrying free-form generation and extracting the expected token,
    since TGI does not expose logit biasing.
    """

    def __init__(self, base_url):
        super().__init__()
        self.base_url = base_url
        # Fetch model metadata once; the model id drives chat-template choice.
        res = http_request(self.base_url + "/info")
        assert res.status_code == 200
        self.model_info = res.json()
        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_id"]
        )

    def get_model_name(self):
        return self.model_info["model_id"]

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def adapt_params(max_tokens, stop, sampling_params, **override_params):
        """Translate SglSamplingParams into a TGI `parameters` dict.

        Clamps values TGI rejects: temperature 0 becomes greedy decoding
        (do_sample=False, temperature omitted), top_p is forced into the
        open interval (0, 1), and top_k -1 (disabled) becomes None.
        ``override_params`` wins over the computed values.
        """
        temperature = sampling_params.temperature
        do_sample = True
        if temperature == 0:
            do_sample = False
            temperature = None
        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]
        top_p = sampling_params.top_p
        if top_p == 0:
            top_p = 0.001
        if top_p == 1:
            top_p = 0.999
        top_k = sampling_params.top_k
        if top_k == -1:
            top_k = None
        params = {
            "decoder_input_details": False,
            "details": False,
            "do_sample": do_sample,
            "max_new_tokens": max_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "return_full_text": False,
        }
        params.update(override_params)
        return params

    @staticmethod
    def _extract_int(text):
        """Return the first whitespace/punctuation-delimited integer in `text`.

        Note: returns the matching word as a *string*, not an int; raises
        ValueError when no word parses as an integer.
        """
        # NOTE(review): "\ " and "\/" are invalid escape sequences (Python
        # treats them literally but warns); this should be a raw string.
        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            try:
                int(word)
                return word
            except ValueError:
                continue
        raise ValueError

    @staticmethod
    def _extract_choice(choices, text):
        """Return the first word of `text` that is one of `choices`, else raise ValueError."""
        # FIXME: Current only support the case where the choices are single words.
        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            if word in choices:
                return word
        raise ValueError

    @staticmethod
    def _truncate_to_stop(text, stop):
        # The stop sequence may not be a single token. In this case TGI will generate
        # too many tokens so we need to truncate the output.
        if stop:
            stop = [stop] if isinstance(stop, str) else stop
            for stop_seq in stop:
                pos = text.find(stop_seq)
                if pos != -1:
                    return text[:pos]
        return text

    def _make_request(self, params):
        """POST `params` to the TGI /generate endpoint; raise ValueError on non-200."""
        res = http_request(self.base_url + "/generate", json=params)
        if res.status_code != 200:
            raise ValueError(f"Error from TGI backend: {res.text}")
        return res.json()

    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        """Generate up to `retry` times until `extract_fn` accepts the output.

        `extract_fn` must raise ValueError to reject a sample; after the
        retries are exhausted, a ValueError summarizing every failed attempt
        is raised.
        """
        # TGI does not support logit_bias (yet), so we have to use an inefficient
        # sample-and-retry hack instead of constraining the decoder.
        failed = []
        while retry > 0:
            res_json = self._make_request(
                {
                    "inputs": prompt,
                    "parameters": params,
                }
            )
            text = res_json["generated_text"]
            try:
                return extract_fn(text)
            except ValueError:
                retry -= 1
                failed.append(text)
        msg = "=" * 20 + "\n"
        msg += f"Prompt:\n{prompt}\n"
        msg += "=" * 20 + "\n"
        for i, text in enumerate(failed):
            msg += f"====== Try {i+1}:\n{text}\n"
        # NOTE(review): the first two literals concatenate without a space
        # ("...generate" + "expected..."); literal preserved as-is here.
        raise ValueError(
            f"Model {self.model_info['model_id']} served by TGI backend does not generate"
            "expected output. Please improve the prompt, increase the temperature, or "
            f"use different models.\n{msg}"
        )

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        sampling_params: SglSamplingParams,
    ):
        """Pick one of `choices` by sampling and extracting a matching word.

        Returns (decision, one-hot list over choices). Uses a short budget
        (16 tokens) since only a single choice word is needed.
        """
        decision = self.retry_for_expected(
            s.text_,
            self.adapt_params(16, [], sampling_params),
            partial(self._extract_choice, choices),
        )
        return decision, [1 if choice == decision else 0 for choice in choices]

    def generate(
        self,
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
        sampling_params: SglSamplingParams,
        dtype: Optional[str] = None,
    ):
        """Generate a completion for the executor's current text.

        dtype None: plain completion, truncated at the first stop string.
        dtype str/"str"/"string": generate inside double quotes and return
            the value wrapped in quotes.
        dtype int/"int": sample-and-retry until an integer can be extracted.
        Returns (text, meta_info_dict); raises ValueError on unknown dtype.
        """
        if dtype is None:
            res_json = self._make_request(
                {
                    "inputs": s.text_,
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return self._truncate_to_stop(res_json["generated_text"], stop), {}
        if dtype in [str, "str", "string"]:
            stop = ['"']
            res_json = self._make_request(
                {
                    "inputs": f'{s.text_}"',
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return (
                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
                {},
            )
        if dtype in [int, "int"]:
            return (
                self.retry_for_expected(
                    s.text_,
                    self.adapt_params(max_tokens, stop, sampling_params),
                    self._extract_int,
                ),
                {},
            )
        raise ValueError(f"Unknown dtype: {dtype}")

View File

@@ -18,13 +18,8 @@ try:
except ImportError as e:
GenerativeModel = e
GEMINI_MODEL_NAMES = [
"gemini-pro",
"gemini-pro-vision",
]
class Gemini(BaseBackend):
class VertexAI(BaseBackend):
def __init__(self, model_name):
super().__init__()
@@ -32,7 +27,7 @@ class Gemini(BaseBackend):
raise GenerativeModel
project_id = os.environ["GCP_PROJECT_ID"]
location = os.environ["GCP_LOCATION"]
location = os.environ.get("GCP_LOCATION")
vertexai.init(project=project_id, location=location)
self.model_name = model_name
@@ -47,17 +42,17 @@ class Gemini(BaseBackend):
sampling_params: SglSamplingParams,
):
if s.messages_:
prompt = self.messages_to_gemini_input(s.messages_)
prompt = self.messages_to_vertexai_input(s.messages_)
else:
# single-turn
prompt = (
self.text_to_gemini_input(s.text_, s.cur_images)
self.text_to_vertexai_input(s.text_, s.cur_images)
if s.cur_images
else s.text_
)
ret = GenerativeModel(self.model_name).generate_content(
prompt,
generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
)
comp = ret.text
@@ -70,23 +65,23 @@ class Gemini(BaseBackend):
sampling_params: SglSamplingParams,
):
if s.messages_:
prompt = self.messages_to_gemini_input(s.messages_)
prompt = self.messages_to_vertexai_input(s.messages_)
else:
# single-turn
prompt = (
self.text_to_gemini_input(s.text_, s.cur_images)
self.text_to_vertexai_input(s.text_, s.cur_images)
if s.cur_images
else s.text_
)
generator = GenerativeModel(self.model_name).generate_content(
prompt,
stream=True,
generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
)
for ret in generator:
yield ret.text, {}
def text_to_gemini_input(self, text, images):
def text_to_vertexai_input(self, text, images):
input = []
# split with image token
text_segs = text.split(self.chat_template.image_token)
@@ -100,9 +95,9 @@ class Gemini(BaseBackend):
input.append(text_seg)
return input
def messages_to_gemini_input(self, messages):
gemini_message = []
# from openai message format to gemini message format
def messages_to_vertexai_input(self, messages):
vertexai_message = []
# from openai message format to vertexai message format
for msg in messages:
if isinstance(msg["content"], str):
text = msg["content"]
@@ -110,14 +105,14 @@ class Gemini(BaseBackend):
text = msg["content"][0]["text"]
if msg["role"] == "system":
warnings.warn("Warning: system prompt is not supported in Gemini.")
gemini_message.append(
warnings.warn("Warning: system prompt is not supported in VertexAI.")
vertexai_message.append(
{
"role": "user",
"parts": [{"text": "System prompt: " + text}],
}
)
gemini_message.append(
vertexai_message.append(
{
"role": "model",
"parts": [{"text": "Understood."}],
@@ -125,12 +120,12 @@ class Gemini(BaseBackend):
)
continue
if msg["role"] == "user":
gemini_msg = {
vertexai_msg = {
"role": "user",
"parts": [{"text": text}],
}
elif msg["role"] == "assistant":
gemini_msg = {
vertexai_msg = {
"role": "model",
"parts": [{"text": text}],
}
@@ -139,7 +134,7 @@ class Gemini(BaseBackend):
if isinstance(msg["content"], list) and len(msg["content"]) > 1:
for image in msg["content"][1:]:
assert image["type"] == "image_url"
gemini_msg["parts"].append(
vertexai_msg["parts"].append(
{
"inline_data": {
"data": image["image_url"]["url"].split(",")[1],
@@ -148,5 +143,5 @@ class Gemini(BaseBackend):
}
)
gemini_message.append(gemini_msg)
return gemini_message
vertexai_message.append(vertexai_msg)
return vertexai_message

View File

@@ -2,6 +2,7 @@
import dataclasses
import inspect
import warnings
from typing import List, Optional, Union
from sglang.global_config import global_config
@@ -40,6 +41,8 @@ class SglSamplingParams:
def to_openai_kwargs(self):
# OpenAI does not support top_k, so we drop it here
if self.regex is not None:
warnings.warn("Regular expression is not supported in the OpenAI backend.")
return {
"max_tokens": self.max_new_tokens,
"stop": self.stop or None,
@@ -49,7 +52,9 @@ class SglSamplingParams:
"presence_penalty": self.presence_penalty,
}
def to_gemini_kwargs(self):
def to_vertexai_kwargs(self):
if self.regex is not None:
warnings.warn("Regular expression is not supported in the VertexAI backend.")
return {
"candidate_count": 1,
"max_output_tokens": self.max_new_tokens,
@@ -61,6 +66,8 @@ class SglSamplingParams:
def to_anthropic_kwargs(self):
# Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
if self.regex is not None:
warnings.warn("Regular expression is not supported in the Anthropic backend.")
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop,

View File

@@ -28,7 +28,7 @@ class RouterManager:
self.model_client = model_client
self.recv_reqs = []
# Init Some Configs
# Init some configs
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
if has_finished:
await asyncio.sleep(self.extend_dependency_time)
await asyncio.sleep(0.001)
await asyncio.sleep(0.0006)
async def loop_for_recv_requests(self):
while True:

View File

@@ -108,7 +108,7 @@ class ModelRpcServer(rpyc.Service):
self.running_batch: Batch = None
self.out_pyobjs = []
self.decode_forward_ct = 0
self.stream_interval = 2
self.stream_interval = server_args.stream_interval
# Init the FSM cache for constrained generation
self.regex_fsm_cache = FSMCache(self.tokenizer)

View File

@@ -17,6 +17,7 @@ class ServerArgs:
model_mode: List[str] = ()
schedule_heuristic: str = "lpm"
random_seed: int = 42
stream_interval: int = 2
disable_log_stats: bool = False
log_stats_interval: int = 10
log_level: str = "info"
@@ -108,6 +109,12 @@ class ServerArgs:
default=ServerArgs.random_seed,
help="Random seed.",
)
parser.add_argument(
    "--stream-interval",
    type=int,
    # BUG FIX: the default previously read ServerArgs.random_seed — a
    # copy-paste from the --random-seed argument above — which would make
    # streaming flush every 42 tokens instead of every 2.
    default=ServerArgs.stream_interval,
    help="The interval in terms of token length for streaming",
)
parser.add_argument(
"--log-level",
type=str,