Improve docs & Rename Gemini -> VertexAI (#19)
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union

from sglang.backend.anthropic import Anthropic
from sglang.backend.base_backend import BaseBackend
-from sglang.backend.gemini import Gemini
+from sglang.backend.vertexai import VertexAI
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.global_config import global_config
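The import swap above is the user-visible half of the rename: code that previously constructed a Gemini backend now constructs a VertexAI backend. A minimal usage sketch, assuming sglang re-exports the backend class and a set_default_backend helper (neither appears in this hunk), with "gemini-pro" as an illustrative model name:

    from sglang import VertexAI, set_default_backend

    # Assumed usage: select the renamed backend for subsequent sglang programs.
    # "gemini-pro" is one of the Gemini models served through Vertex AI.
    set_default_backend(VertexAI("gemini-pro"))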
@@ -1,349 +0,0 @@
import functools
from enum import Enum, auto
from typing import Callable, List, Optional, Union

import numpy as np
import torch
import transformers
from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import ProgramState
from sglang.utils import get_available_gpu_memory
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


class StopReason(Enum):
    EOS_TOKEN = auto()
    STOP_STR = auto()
    LENGTH = auto()


def load_model(
    model_name: str,
    device,
    num_gpus,
    max_gpu_memory,
    model_kwargs=None,
    tokenizer_kwargs=None,
):
    model_kwargs = model_kwargs or {}
    tokenizer_kwargs = tokenizer_kwargs or {}

    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16
        if num_gpus != 1:
            model_kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                model_kwargs[
                    "device_map"
                ] = "sequential"  # This is important for not the same VRAM sizes
                available_gpu_memory = [
                    get_available_gpu_memory(i, False) for i in range(num_gpus)
                ]
                model_kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                model_kwargs["max_memory"] = {
                    i: max_gpu_memory for i in range(num_gpus)
                }
    elif device == "cpu":
        model_kwargs["torch_dtype"] = torch.float32
    else:
        raise ValueError(f"Invalid device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

    if num_gpus == 1:
        model.to(device).eval()

    return model, tokenizer


def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        if not s.startswith(last_str):
            disallowed[t_id] = 1
    return disallowed


@functools.lru_cache
def get_int_token_mask(tokenizer):
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        s = s.replace("▁", "").strip()
        if not (s.isdigit() or len(s) == 0 or s == ","):
            disallowed[t_id] = 1
    disallowed[tokenizer.eos_token_id] = 0
    return disallowed


@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens,
    stop: List[str],
    temperature,
    top_p,
    token_healing,
    logit_mask=None,
):
    logits_processor = prepare_logits_processor(
        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
    )
    device = model.device
    input_ids = tokenizer.encode(prompt)
    output_ids = list(input_ids)
    prompt_len = len(prompt)

    # Resolve stop
    stop_token_ids = [tokenizer.eos_token_id]

    # Token healing
    token_healing = token_healing and len(input_ids) > 0
    if token_healing:
        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
        del output_ids[-1]

    # Generate
    past_key_values = None
    stop_reason = None
    for i in range(max_new_tokens):
        # Forward
        if i == 0:  # prefill
            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
        else:  # decoding
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        logits = out.logits
        past_key_values = out.past_key_values

        # Logit mask
        if token_healing and i == 0:
            logits[0, -1, token_healing_mask] = -1e4
        if logit_mask is not None:
            logits[0, -1, logit_mask] = -1e4

        # Sample next token
        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        # Stop condition
        if token in stop_token_ids:
            stop_reason = StopReason.EOS_TOKEN
            break

        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop:
            pos = output_str[prompt_len:].find(stop_str)
            if pos != -1:
                stop_reason = StopReason.STOP_STR
                output_str = output_str[: prompt_len + pos]
                break

        if stop_reason:
            break

    return output_str[prompt_len:]


class HuggingFaceTransformers(BaseBackend):
    def __init__(
        self,
        model_name,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        model_kwargs=None,
        tokenizer_kwargs=None,
    ):
        self.model_name = model_name
        self.device = device

        self.model, self.tokenizer = load_model(
            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
        )

        self.chat_template = get_chat_template_by_model_path(model_name)

    def get_chat_template(self):
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        pass

    def uncache_prefix(self, rid: str):
        pass

    def end_request(self, rid: str):
        pass

    def begin_program(self, s: ProgramState):
        pass

    def end_program(self, s: ProgramState):
        pass

    def fill(self, s: ProgramState, text: str):
        return False

    def generate_internal(
        self,
        prompt: str,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                token_healing=True,
            )
        elif dtype in [str, "str", "string"]:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt + '"',
                max_new_tokens=max_tokens,
                stop=['"'],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
            )
            comp = '"' + comp + '"'
        elif dtype in [int, "int"]:
            logit_mask = get_int_token_mask(self.tokenizer)
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop + [" ", ","],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
                logit_mask=logit_mask,
            )
        return comp

    def generate(
        self,
        s: ProgramState,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        comp = self.generate_internal(
            prompt, max_tokens, stop, temperature, top_p, dtype
        )
        return comp

    def parallel_generate(
        self,
        s: ProgramState,
        prefixes: List[str],
        join_func: Callable,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        prompt = s.text
        parallel_prompts = [prompt + prefix for prefix in prefixes]

        comps = []
        for i in range(len(parallel_prompts)):
            comps.append(
                self.generate_internal(
                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
                )
            )

        joined = join_func([p + c for p, c in zip(prefixes, comps)])
        return joined, comps

    @torch.inference_mode()
    def select(
        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
    ):
        loss_fct = torch.nn.CrossEntropyLoss()
        prompt = s.text

        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
        prompt_choices = [prompt + choice for choice in choices]

        scores = []
        for i in range(len(choices)):
            choice_ids = self.tokenizer.encode(
                prompt_choices[i], return_tensors="pt"
            ).to(self.model.device)
            logits = self.model(choice_ids).logits

            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()

            logprobs = torch.log(torch.softmax(logits, dim=-1))
            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
            idx2 = choice_ids[0, 1:]
            selected_logprobs = logprobs[0, idx1, idx2]
            score = selected_logprobs.mean().item()
            scores.append(score)

        decision = choices[np.argmax(scores)]
        return decision, scores

@@ -1,190 +0,0 @@
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from itertools import repeat
from typing import List, Optional, Union

from sglang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template_by_model_path
from sglang.lang.interpreter import StreamExecutor
from sglang.lang.ir import SglSamplingParams
from sglang.utils import http_request


class TGI(BaseBackend):
    def __init__(self, base_url):
        super().__init__()

        self.base_url = base_url

        res = http_request(self.base_url + "/info")
        assert res.status_code == 200
        self.model_info = res.json()
        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_id"]
        )

    def get_model_name(self):
        return self.model_info["model_id"]

    def get_chat_template(self):
        return self.chat_template

    @staticmethod
    def adapt_params(max_tokens, stop, sampling_params, **override_params):
        temperature = sampling_params.temperature
        do_sample = True
        if temperature == 0:
            do_sample = False
            temperature = None

        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]

        top_p = sampling_params.top_p
        if top_p == 0:
            top_p = 0.001
        if top_p == 1:
            top_p = 0.999

        top_k = sampling_params.top_k
        if top_k == -1:
            top_k = None

        params = {
            "decoder_input_details": False,
            "details": False,
            "do_sample": do_sample,
            "max_new_tokens": max_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "return_full_text": False,
        }
        params.update(override_params)
        return params

    @staticmethod
    def _extract_int(text):
        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            try:
                int(word)
                return word
            except ValueError:
                continue
        raise ValueError

    @staticmethod
    def _extract_choice(choices, text):
        # FIXME: Currently this only supports the case where the choices are single words.
        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
        for word in words:
            if word in choices:
                return word
        raise ValueError

    @staticmethod
    def _truncate_to_stop(text, stop):
        # The stop sequence may not be a single token. In this case TGI will generate
        # too many tokens, so we need to truncate the output.
        if stop:
            stop = [stop] if isinstance(stop, str) else stop
            for stop_seq in stop:
                pos = text.find(stop_seq)
                if pos != -1:
                    return text[:pos]
        return text

    def _make_request(self, params):
        res = http_request(self.base_url + "/generate", json=params)
        if res.status_code != 200:
            raise ValueError(f"Error from TGI backend: {res.text}")
        return res.json()

    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        # TGI does not support logit_bias (yet), so we have to use an inefficient hack.
        failed = []
        while retry > 0:
            res_json = self._make_request(
                {
                    "inputs": prompt,
                    "parameters": params,
                }
            )
            text = res_json["generated_text"]
            try:
                return extract_fn(text)
            except ValueError:
                retry -= 1
                failed.append(text)

        msg = "=" * 20 + "\n"
        msg += f"Prompt:\n{prompt}\n"
        msg += "=" * 20 + "\n"
        for i, text in enumerate(failed):
            msg += f"====== Try {i+1}:\n{text}\n"

        raise ValueError(
            f"Model {self.model_info['model_id']} served by the TGI backend does not generate "
            "expected output. Please improve the prompt, increase the temperature, or "
            f"use different models.\n{msg}"
        )

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        sampling_params: SglSamplingParams,
    ):
        decision = self.retry_for_expected(
            s.text_,
            self.adapt_params(16, [], sampling_params),
            partial(self._extract_choice, choices),
        )
        return decision, [1 if choice == decision else 0 for choice in choices]

    def generate(
        self,
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
        sampling_params: SglSamplingParams,
        dtype: Optional[str] = None,
    ):
        if dtype is None:
            res_json = self._make_request(
                {
                    "inputs": s.text_,
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return self._truncate_to_stop(res_json["generated_text"], stop), {}

        if dtype in [str, "str", "string"]:
            stop = ['"']
            res_json = self._make_request(
                {
                    "inputs": f'{s.text_}"',
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return (
                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
                {},
            )

        if dtype in [int, "int"]:
            return (
                self.retry_for_expected(
                    s.text_,
                    self.adapt_params(max_tokens, stop, sampling_params),
                    self._extract_int,
                ),
                {},
            )

        raise ValueError(f"Unknown dtype: {dtype}")

@@ -18,13 +18,8 @@ try:
except ImportError as e:
    GenerativeModel = e

-GEMINI_MODEL_NAMES = [
-    "gemini-pro",
-    "gemini-pro-vision",
-]
-
-class Gemini(BaseBackend):
+class VertexAI(BaseBackend):
    def __init__(self, model_name):
        super().__init__()
@@ -32,7 +27,7 @@ class Gemini(BaseBackend):
            raise GenerativeModel

        project_id = os.environ["GCP_PROJECT_ID"]
-        location = os.environ["GCP_LOCATION"]
+        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)

        self.model_name = model_name
@@ -47,17 +42,17 @@ class Gemini(BaseBackend):
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
-            prompt = self.messages_to_gemini_input(s.messages_)
+            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
-                self.text_to_gemini_input(s.text_, s.cur_images)
+                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        ret = GenerativeModel(self.model_name).generate_content(
            prompt,
-            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
        )

        comp = ret.text
@@ -70,23 +65,23 @@ class Gemini(BaseBackend):
        sampling_params: SglSamplingParams,
    ):
        if s.messages_:
-            prompt = self.messages_to_gemini_input(s.messages_)
+            prompt = self.messages_to_vertexai_input(s.messages_)
        else:
            # single-turn
            prompt = (
-                self.text_to_gemini_input(s.text_, s.cur_images)
+                self.text_to_vertexai_input(s.text_, s.cur_images)
                if s.cur_images
                else s.text_
            )
        generator = GenerativeModel(self.model_name).generate_content(
            prompt,
            stream=True,
-            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
        )
        for ret in generator:
            yield ret.text, {}

-    def text_to_gemini_input(self, text, images):
+    def text_to_vertexai_input(self, text, images):
        input = []
        # split with image token
        text_segs = text.split(self.chat_template.image_token)
@@ -100,9 +95,9 @@ class Gemini(BaseBackend):
            input.append(text_seg)
        return input

-    def messages_to_gemini_input(self, messages):
-        gemini_message = []
-        # from openai message format to gemini message format
+    def messages_to_vertexai_input(self, messages):
+        vertexai_message = []
+        # from openai message format to vertexai message format
        for msg in messages:
            if isinstance(msg["content"], str):
                text = msg["content"]
@@ -110,14 +105,14 @@ class Gemini(BaseBackend):
                text = msg["content"][0]["text"]

            if msg["role"] == "system":
-                warnings.warn("Warning: system prompt is not supported in Gemini.")
-                gemini_message.append(
+                warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                vertexai_message.append(
                    {
                        "role": "user",
                        "parts": [{"text": "System prompt: " + text}],
                    }
                )
-                gemini_message.append(
+                vertexai_message.append(
                    {
                        "role": "model",
                        "parts": [{"text": "Understood."}],
@@ -125,12 +120,12 @@ class Gemini(BaseBackend):
                )
                continue
            if msg["role"] == "user":
-                gemini_msg = {
+                vertexai_msg = {
                    "role": "user",
                    "parts": [{"text": text}],
                }
            elif msg["role"] == "assistant":
-                gemini_msg = {
+                vertexai_msg = {
                    "role": "model",
                    "parts": [{"text": text}],
                }
@@ -139,7 +134,7 @@ class Gemini(BaseBackend):
            if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                for image in msg["content"][1:]:
                    assert image["type"] == "image_url"
-                    gemini_msg["parts"].append(
+                    vertexai_msg["parts"].append(
                        {
                            "inline_data": {
                                "data": image["image_url"]["url"].split(",")[1],
@@ -148,5 +143,5 @@ class Gemini(BaseBackend):
                        }
                    )

-            gemini_message.append(gemini_msg)
-        return gemini_message
+            vertexai_message.append(vertexai_msg)
+        return vertexai_message
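For reference, the renamed helper keeps the same OpenAI-to-Vertex AI message mapping shown in the hunks above: a system message is emulated as a user part prefixed with "System prompt: " followed by a model part "Understood.", user messages keep the "user" role, and assistant messages become "model". A small sketch with illustrative messages (not taken from the diff):

    messages = [
        {"role": "system", "content": "You are concise."},
        {"role": "user", "content": "Hello"},
    ]
    # Expected shape of messages_to_vertexai_input(messages):
    # [
    #     {"role": "user", "parts": [{"text": "System prompt: You are concise."}]},
    #     {"role": "model", "parts": [{"text": "Understood."}]},
    #     {"role": "user", "parts": [{"text": "Hello"}]},
    # ]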

@@ -2,6 +2,7 @@

import dataclasses
import inspect
+import warnings
from typing import List, Optional, Union

from sglang.global_config import global_config
@@ -40,6 +41,8 @@ class SglSamplingParams:

    def to_openai_kwargs(self):
        # OpenAI does not support top_k, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the OpenAI backend.")
        return {
            "max_tokens": self.max_new_tokens,
            "stop": self.stop or None,
@@ -49,7 +52,9 @@
            "presence_penalty": self.presence_penalty,
        }

-    def to_gemini_kwargs(self):
+    def to_vertexai_kwargs(self):
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the VertexAI backend.")
        return {
            "candidate_count": 1,
            "max_output_tokens": self.max_new_tokens,
@@ -61,6 +66,8 @@

    def to_anthropic_kwargs(self):
        # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the Anthropic backend.")
        return {
            "max_tokens_to_sample": self.max_new_tokens,
            "stop_sequences": self.stop,

@@ -28,7 +28,7 @@ class RouterManager:
        self.model_client = model_client
        self.recv_reqs = []

-        # Init Some Configs
+        # Init some configs
        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time

    async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
            if has_finished:
                await asyncio.sleep(self.extend_dependency_time)

-            await asyncio.sleep(0.001)
+            await asyncio.sleep(0.0006)

    async def loop_for_recv_requests(self):
        while True:

@@ -108,7 +108,7 @@ class ModelRpcServer(rpyc.Service):
        self.running_batch: Batch = None
        self.out_pyobjs = []
        self.decode_forward_ct = 0
-        self.stream_interval = 2
+        self.stream_interval = server_args.stream_interval

        # Init the FSM cache for constrained generation
        self.regex_fsm_cache = FSMCache(self.tokenizer)

@@ -17,6 +17,7 @@ class ServerArgs:
    model_mode: List[str] = ()
    schedule_heuristic: str = "lpm"
    random_seed: int = 42
+    stream_interval: int = 2
    disable_log_stats: bool = False
    log_stats_interval: int = 10
    log_level: str = "info"
@@ -108,6 +109,12 @@ class ServerArgs:
            default=ServerArgs.random_seed,
            help="Random seed.",
        )
+        parser.add_argument(
+            "--stream-interval",
+            type=int,
+            default=ServerArgs.random_seed,
+            help="The interval in terms of token length for streaming",
+        )
        parser.add_argument(
            "--log-level",
            type=str,
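The new --stream-interval flag reaches ModelRpcServer through ServerArgs instead of the previously hard-coded value of 2. A small sketch of setting it programmatically, assuming ServerArgs lives in sglang.srt.server_args and accepts a model_path field (the path below is a placeholder):

    from sglang.srt.server_args import ServerArgs

    # stream_interval is the number of decode steps between streamed partial
    # outputs, per the help text and the ModelRpcServer hunk above.
    args = ServerArgs(model_path="path/to/model", stream_interval=4)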