Improve docs & Rename Gemini -> VertexAI (#19)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
# SGLang
|
# SGLang
|
||||||
|
| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |
|
||||||
|
|
||||||
SGLang is a structured generation language designed for large language models (LLMs).
|
SGLang is a structured generation language designed for large language models (LLMs).
|
||||||
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
|
It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
|
||||||
@@ -42,7 +43,7 @@ The example below shows how to use sglang to answer a multi-turn question.
|
|||||||
### Using OpenAI Models
|
### Using OpenAI Models
|
||||||
Set the OpenAI API Key
|
Set the OpenAI API Key
|
||||||
```
|
```
|
||||||
export OPENAI_API_KEY=sk-xxxxxx
|
export OPENAI_API_KEY=sk-******
|
||||||
```
|
```
|
||||||
|
|
||||||
Then, answer a multi-turn question.
|
Then, answer a multi-turn question.
|
||||||
@@ -100,6 +101,7 @@ for m in state.messages():
|
|||||||
|
|
||||||
### More Examples
|
### More Examples
|
||||||
|
|
||||||
|
Anthropic and VertexAI (Gemini) models are also supported.
|
||||||
You can find more examples at [examples/quick_start](examples/quick_start).
|
You can find more examples at [examples/quick_start](examples/quick_start).
|
||||||
|
|
||||||
## Frontend: Structured Generation Language (SGLang)
|
## Frontend: Structured Generation Language (SGLang)
|
||||||
@@ -251,6 +253,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
- Mixtral
|
- Mixtral
|
||||||
- LLaVA
|
- LLaVA
|
||||||
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
|
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
|
||||||
|
- AWQ quantization
|
||||||
|
|
||||||
## Benchmark And Performance
|
## Benchmark And Performance
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from sglang import function, gen, set_default_backend, Gemini
|
from sglang import function, gen, set_default_backend, VertexAI
|
||||||
|
|
||||||
|
|
||||||
@function
|
@function
|
||||||
@@ -16,7 +16,7 @@ A: Rome
|
|||||||
s += "A:" + gen("answer", stop="\n", temperature=0)
|
s += "A:" + gen("answer", stop="\n", temperature=0)
|
||||||
|
|
||||||
|
|
||||||
set_default_backend(Gemini("gemini-pro"))
|
set_default_backend(VertexAI("gemini-pro"))
|
||||||
|
|
||||||
state = few_shot_qa.run(question="What is the capital of the United States?")
|
state = few_shot_qa.run(question="What is the capital of the United States?")
|
||||||
answer = state["answer"].strip().lower()
|
answer = state["answer"].strip().lower()
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from sglang import function, user, assistant, gen, image, set_default_backend, Gemini
|
from sglang import function, user, assistant, gen, image, set_default_backend, VertexAI
|
||||||
|
|
||||||
|
|
||||||
@function
|
@function
|
||||||
@@ -6,7 +6,7 @@ def image_qa(s, image_file1, image_file2, question):
|
|||||||
s += user(image(image_file1) + image(image_file2) + question)
|
s += user(image(image_file1) + image(image_file2) + question)
|
||||||
s += assistant(gen("answer_1", max_tokens=256))
|
s += assistant(gen("answer_1", max_tokens=256))
|
||||||
|
|
||||||
set_default_backend(Gemini("gemini-pro-vision"))
|
set_default_backend(VertexAI("gemini-pro-vision"))
|
||||||
|
|
||||||
state = image_qa.run(
|
state = image_qa.run(
|
||||||
image_file1="./images/cat.jpeg",
|
image_file1="./images/cat.jpeg",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from sglang import function, user, assistant, gen, set_default_backend, Gemini
|
from sglang import function, user, assistant, gen, set_default_backend, VertexAI
|
||||||
|
|
||||||
|
|
||||||
@function
|
@function
|
||||||
@@ -8,7 +8,7 @@ def multi_turn_question(s, question_1, question_2):
|
|||||||
s += user(question_2)
|
s += user(question_2)
|
||||||
s += assistant(gen("answer_2", max_tokens=256))
|
s += assistant(gen("answer_2", max_tokens=256))
|
||||||
|
|
||||||
set_default_backend(Gemini("gemini-pro"))
|
set_default_backend(VertexAI("gemini-pro"))
|
||||||
|
|
||||||
state = multi_turn_question.run(
|
state = multi_turn_question.run(
|
||||||
question_1="What is the capital of the United States?",
|
question_1="What is the capital of the United States?",
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
|
|||||||
|
|
||||||
from sglang.backend.anthropic import Anthropic
|
from sglang.backend.anthropic import Anthropic
|
||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
from sglang.backend.gemini import Gemini
|
from sglang.backend.vertexai import VertexAI
|
||||||
from sglang.backend.openai import OpenAI
|
from sglang.backend.openai import OpenAI
|
||||||
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
from sglang.backend.runtime_endpoint import RuntimeEndpoint
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
|
|||||||
@@ -1,349 +0,0 @@
|
|||||||
import functools
|
|
||||||
from enum import Enum, auto
|
|
||||||
from typing import Callable, List, Optional, Union
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
from sglang.backend.base_backend import BaseBackend
|
|
||||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
|
||||||
from sglang.lang.interpreter import ProgramState
|
|
||||||
from sglang.utils import get_available_gpu_memory
|
|
||||||
from transformers import (
|
|
||||||
AutoModelForCausalLM,
|
|
||||||
AutoTokenizer,
|
|
||||||
StoppingCriteria,
|
|
||||||
StoppingCriteriaList,
|
|
||||||
)
|
|
||||||
from transformersgl.generation.logits_process import (
|
|
||||||
LogitsProcessorList,
|
|
||||||
RepetitionPenaltyLogitsProcessor,
|
|
||||||
TemperatureLogitsWarper,
|
|
||||||
TopKLogitsWarper,
|
|
||||||
TopPLogitsWarper,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class StopReason(Enum):
    """Enumerates the reasons a generation run can stop."""

    # Values are spelled out; they are the same 1, 2, 3 that `auto()` assigns.
    EOS_TOKEN = 1  # an end-of-sequence token was sampled
    STOP_STR = 2  # one of the stop strings appeared in the output
    LENGTH = 3  # token budget reached
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(
    model_name: str,
    device,
    num_gpus,
    max_gpu_memory,
    model_kwargs=None,
    tokenizer_kwargs=None,
):
    """Load a causal language model and its tokenizer with ``from_pretrained``.

    Args:
        model_name: Name or path handed to ``AutoModelForCausalLM`` and
            ``AutoTokenizer``.
        device: ``"cuda"`` (loads in float16) or ``"cpu"`` (loads in float32).
        num_gpus: Number of GPUs. When it is not 1 on CUDA, a ``device_map``
            and a per-GPU ``max_memory`` table are passed to the loader.
        max_gpu_memory: Per-GPU memory limit used for every GPU. When ``None``,
            85% of each GPU's currently available memory is used instead.
        model_kwargs: Extra keyword arguments for the model loader.
        tokenizer_kwargs: Extra keyword arguments for the tokenizer loader.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``device`` is neither ``"cuda"`` nor ``"cpu"``.
    """
    # Work on copies: the entries added below (dtype, device_map, max_memory)
    # must not leak into dictionaries owned by the caller.
    model_kwargs = dict(model_kwargs or {})
    tokenizer_kwargs = dict(tokenizer_kwargs or {})

    if device == "cuda":
        model_kwargs["torch_dtype"] = torch.float16
        if num_gpus != 1:
            model_kwargs["device_map"] = "auto"
            if max_gpu_memory is None:
                # This is important when the GPUs do not have the same VRAM size.
                model_kwargs["device_map"] = "sequential"
                available_gpu_memory = [
                    get_available_gpu_memory(i, False) for i in range(num_gpus)
                ]
                model_kwargs["max_memory"] = {
                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                    for i in range(num_gpus)
                }
            else:
                model_kwargs["max_memory"] = {
                    i: max_gpu_memory for i in range(num_gpus)
                }
    elif device == "cpu":
        model_kwargs["torch_dtype"] = torch.float32
    else:
        raise ValueError(f"Invalid device: {device}")

    model = AutoModelForCausalLM.from_pretrained(
        model_name, low_cpu_mem_usage=True, **model_kwargs
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)

    # With a single device the model is moved explicitly; in the multi-GPU
    # case placement is left to the ``device_map`` passed to the loader.
    if num_gpus == 1:
        model.to(device).eval()

    return model, tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the logits processors that apply to the given sampling settings.

    Processors are appended in a fixed order: temperature, repetition
    penalty, top-p, top-k. Each one is included only when its setting falls
    in the range checked below.
    """
    # (enabled, constructor, argument) in application order.
    # TemperatureLogitsWarper doesn't accept 0.0, and 1.0 makes it a no-op,
    # so both of those cases are skipped.
    candidates = (
        (
            temperature >= 1e-5 and temperature != 1.0,
            TemperatureLogitsWarper,
            temperature,
        ),
        (
            repetition_penalty > 1.0,
            RepetitionPenaltyLogitsProcessor,
            repetition_penalty,
        ),
        (1e-8 <= top_p < 1.0, TopPLogitsWarper, top_p),
        (top_k > 0, TopKLogitsWarper, top_k),
    )

    chain = LogitsProcessorList()
    for enabled, build, setting in candidates:
        if enabled:
            chain.append(build(setting))
    return chain
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
def get_token_healing_mask(tokenizer, prompt_last_token):
    """Return a boolean mask of vocabulary ids that may NOT follow the prompt.

    An id is masked (``True``) when its token string does not start with the
    string of ``prompt_last_token``. Results are memoized per
    ``(tokenizer, prompt_last_token)`` pair, so both must be hashable.
    """
    prefix = tokenizer.convert_ids_to_tokens(prompt_last_token)
    blocked_ids = [
        token_id
        for token_str, token_id in tokenizer.get_vocab().items()
        if not token_str.startswith(prefix)
    ]

    mask = torch.zeros(len(tokenizer), dtype=bool)
    # An explicit long tensor keeps the empty-list case a valid index.
    mask[torch.tensor(blocked_ids, dtype=torch.long)] = True
    return mask
|
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
def get_int_token_mask(tokenizer):
    """Return a boolean mask of vocabulary ids disallowed while emitting an int.

    After removing the "▁" marker and surrounding whitespace, a token stays
    allowed (``False``) only if it is all digits, empty, or a single comma.
    The EOS token, when the tokenizer defines one, is always allowed so that
    generation can terminate. The result is memoized per tokenizer, which
    therefore must be hashable.
    """
    disallowed = torch.zeros(len(tokenizer), dtype=bool)
    for s, t_id in tokenizer.get_vocab().items():
        s = s.replace("▁", "").strip()
        if not (s.isdigit() or len(s) == 0 or s == ","):
            disallowed[t_id] = 1

    # Guard against tokenizers without an EOS token: indexing a tensor with
    # ``None`` selects *everything*, which would silently clear the mask.
    eos_token_id = tokenizer.eos_token_id
    if eos_token_id is not None:
        disallowed[eos_token_id] = 0
    return disallowed
|
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
def generate_stream(
    model,
    tokenizer,
    prompt,
    max_new_tokens,
    stop: List[str],
    temperature,
    top_p,
    token_healing,
    logit_mask=None,
):
    """Decode a completion for ``prompt`` token by token and return it.

    Despite the name, nothing is yielded: the completion string is returned
    once decoding stops.

    Args:
        model: Called as ``model(input_ids, use_cache=True, ...)``; its output
            must expose ``logits`` and ``past_key_values``.
        tokenizer: Provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt: Text to continue.
        max_new_tokens: Upper bound on the number of sampled tokens.
        stop: Stop strings. A bare string or ``None`` is also accepted and is
            treated as a one-element / empty list.
        temperature: Sampling temperature; below ``1e-5`` decoding is greedy.
        top_p: Nucleus sampling threshold; below ``1e-8`` decoding is greedy.
        token_healing: When true (and the prompt is non-empty), the last
            prompt token is dropped and the first sampled token is restricted
            to tokens that start with its string.
        logit_mask: Optional boolean mask of token ids to suppress at every
            step.

    Returns:
        The decoded text that follows the first ``len(prompt)`` characters,
        cut just before the first matching stop string. An empty string is
        returned when no token is kept (``max_new_tokens <= 0`` or EOS sampled
        first).
    """
    # Normalize ``stop`` so a bare string is matched as a whole instead of
    # being iterated character by character.
    if stop is None:
        stop = []
    elif isinstance(stop, str):
        stop = [stop]

    logits_processor = prepare_logits_processor(
        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
    )
    device = model.device
    input_ids = tokenizer.encode(prompt)
    output_ids = list(input_ids)
    # NOTE(review): this is a character count; slicing the decoded text with it
    # assumes decoding reproduces the prompt verbatim — confirm for tokenizers
    # that normalize whitespace.
    prompt_len = len(prompt)

    # Resolve stop
    stop_token_ids = [tokenizer.eos_token_id]

    # Token healing
    token_healing = token_healing and len(input_ids) > 0
    if token_healing:
        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
        del output_ids[-1]

    # Generate
    past_key_values = None
    stop_reason = None
    # Start from the prompt so the final slice yields "" when the loop never
    # decodes anything (previously this raised UnboundLocalError).
    output_str = prompt
    for i in range(max_new_tokens):
        # Forward
        if i == 0:  # prefill
            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
        else:  # decoding: feed only the newest token, reuse the KV cache
            out = model(
                input_ids=torch.as_tensor([[token]], device=device),
                use_cache=True,
                past_key_values=past_key_values,
            )
        logits = out.logits
        past_key_values = out.past_key_values

        # Logit mask
        if token_healing and i == 0:
            logits[0, -1, token_healing_mask] = -1e4
        if logit_mask is not None:
            logits[0, -1, logit_mask] = -1e4

        # Sample next token
        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
        output_ids.append(token)

        # Stop condition
        if token in stop_token_ids:
            stop_reason = StopReason.EOS_TOKEN
            break

        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
        for stop_str in stop:
            pos = output_str[prompt_len:].find(stop_str)
            if pos != -1:
                stop_reason = StopReason.STOP_STR
                output_str = output_str[: prompt_len + pos]
                break

        if stop_reason:
            break

    return output_str[prompt_len:]
|
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceTransformers(BaseBackend):
    """Backend that runs generation in-process on a locally loaded model."""

    def __init__(
        self,
        model_name,
        device="cuda",
        num_gpus=1,
        max_gpu_memory=None,
        model_kwargs=None,
        tokenizer_kwargs=None,
    ):
        """Load the model/tokenizer and pick the chat template for ``model_name``.

        The arguments are forwarded unchanged to ``load_model``.
        """
        # Run the base initializer first, as the other backends do.
        super().__init__()
        self.model_name = model_name
        self.device = device

        self.model, self.tokenizer = load_model(
            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
        )

        self.chat_template = get_chat_template_by_model_path(model_name)

    @staticmethod
    def _normalize_stop(stop):
        """Return ``stop`` as a new list (``None`` -> ``[]``, ``str`` -> ``[str]``)."""
        if stop is None:
            return []
        if isinstance(stop, str):
            return [stop]
        return list(stop)

    def get_chat_template(self):
        """Return the chat template selected in ``__init__``."""
        return self.chat_template

    def cache_prefix(self, prefix_str: str):
        """No-op for this backend."""
        pass

    def uncache_prefix(self, rid: str):
        """No-op for this backend."""
        pass

    def end_request(self, rid: str):
        """No-op for this backend."""
        pass

    def begin_program(self, s: ProgramState):
        """No-op for this backend."""
        pass

    def end_program(self, s: ProgramState):
        """No-op for this backend."""
        pass

    def fill(self, s: ProgramState, text: str):
        """Always return ``False``; nothing is done with ``text`` here."""
        return False

    def generate_internal(
        self,
        prompt: str,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        """Generate a completion for ``prompt``, optionally constrained by ``dtype``.

        * ``None``: free-form text, with token healing enabled.
        * ``str`` / ``"str"`` / ``"string"``: an opening quote is appended to
          the prompt, generation stops at the closing quote, and the result
          is returned wrapped in double quotes.
        * ``int`` / ``"int"``: the vocabulary is restricted with the integer
          token mask and a space or comma also ends the output.

        Raises:
            ValueError: If ``dtype`` is none of the values above.
        """
        # A bare string must be matched as one stop sequence, and the int
        # branch below concatenates lists, so normalize once up front.
        stop = self._normalize_stop(stop)

        if dtype is None:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop,
                temperature=temperature,
                top_p=top_p,
                token_healing=True,
            )
        elif dtype in [str, "str", "string"]:
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt + '"',
                max_new_tokens=max_tokens,
                stop=['"'],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
            )
            comp = '"' + comp + '"'
        elif dtype in [int, "int"]:
            logit_mask = get_int_token_mask(self.tokenizer)
            comp = generate_stream(
                self.model,
                self.tokenizer,
                prompt,
                max_new_tokens=max_tokens,
                stop=stop + [" ", ","],
                temperature=temperature,
                top_p=top_p,
                token_healing=False,
                logit_mask=logit_mask,
            )
        else:
            # Previously this fell through to an unbound ``comp``.
            raise ValueError(f"Unknown dtype: {dtype}")
        return comp

    def generate(
        self,
        s: ProgramState,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        """Generate a completion for the text accumulated in ``s``."""
        prompt = s.text
        comp = self.generate_internal(
            prompt, max_tokens, stop, temperature, top_p, dtype
        )
        return comp

    def parallel_generate(
        self,
        s: ProgramState,
        prefixes: List[str],
        join_func: Callable,
        max_tokens: int,
        stop: Union[str, List[str]],
        temperature: float,
        top_p: float,
        dtype: Optional[str] = None,
    ):
        """Generate one completion per prefix (sequentially) and join them.

        Returns:
            ``(joined, comps)``, where ``comps`` holds the completion for each
            prefix and ``joined`` is ``join_func`` applied to the list of
            ``prefix + completion`` strings.
        """
        prompt = s.text
        comps = [
            self.generate_internal(
                prompt + prefix, max_tokens, stop, temperature, top_p, dtype
            )
            for prefix in prefixes
        ]

        joined = join_func([p + c for p, c in zip(prefixes, comps)])
        return joined, comps

    @torch.inference_mode()
    def select(
        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
    ):
        """Pick the choice whose continuation of ``s.text`` scores highest.

        The score of a choice is the mean log-probability the model assigns
        to each token of ``prompt + choice`` given the preceding tokens.
        ``temperature`` and ``top_p`` are accepted but not used.

        Returns:
            ``(decision, scores)`` with one score per choice.
        """
        prompt = s.text

        scores = []
        for choice in choices:
            choice_ids = self.tokenizer.encode(
                prompt + choice, return_tensors="pt"
            ).to(self.model.device)
            logits = self.model(choice_ids).logits

            # log_softmax avoids the -inf that log(softmax(x)) produces when a
            # probability underflows to zero.
            logprobs = torch.log_softmax(logits, dim=-1)
            # Position t predicts token t + 1, hence the one-step shift.
            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
            idx2 = choice_ids[0, 1:]
            selected_logprobs = logprobs[0, idx1, idx2]
            # NOTE(review): the mean covers prompt tokens too, not only the
            # choice tokens — confirm that this is the intended score.
            scores.append(selected_logprobs.mean().item())

        decision = choices[np.argmax(scores)]
        return decision, scores
|
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
import re
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
from itertools import repeat
|
|
||||||
from typing import List, Optional, Union
|
|
||||||
|
|
||||||
from sglang.backend.base_backend import BaseBackend
|
|
||||||
from sglang.lang.chat_template import get_chat_template_by_model_path
|
|
||||||
from sglang.lang.interpreter import StreamExecutor
|
|
||||||
from sglang.lang.ir import SglSamplingParams
|
|
||||||
from sglang.utils import http_request
|
|
||||||
|
|
||||||
|
|
||||||
class TGI(BaseBackend):
    """Backend that sends generation requests to a TGI server over HTTP.

    The server's ``/info`` endpoint is queried once at construction time and
    ``/generate`` is used for every request.
    """

    # Word separators used when scanning generated text: space, single quote,
    # slash, parentheses, newline, period and comma.
    _WORD_SPLIT_RE = re.compile(r"[ '/()\n.,]")

    def __init__(self, base_url):
        """Fetch the server's model info and select the matching chat template.

        Args:
            base_url: Base URL of the server; endpoint paths are appended to it.
        """
        super().__init__()

        self.base_url = base_url

        res = http_request(self.base_url + "/info")
        assert res.status_code == 200
        self.model_info = res.json()
        self.chat_template = get_chat_template_by_model_path(
            self.model_info["model_id"]
        )

    def get_model_name(self):
        """Return the ``model_id`` reported by the server."""
        return self.model_info["model_id"]

    def get_chat_template(self):
        """Return the chat template selected in ``__init__``."""
        return self.chat_template

    @staticmethod
    def adapt_params(max_tokens, stop, sampling_params, **override_params):
        """Translate sampling settings into the ``parameters`` dict of a request.

        Args:
            max_tokens: Sent as ``max_new_tokens``.
            stop: ``None``, a single stop string, or a list of stop strings.
            sampling_params: Object providing ``temperature``, ``top_p`` and
                ``top_k``.
            **override_params: Entries that replace or extend the computed
                ones.

        Returns:
            The parameters dictionary.
        """
        # Temperature 0 means greedy decoding: disable sampling entirely.
        temperature = sampling_params.temperature
        do_sample = True
        if temperature == 0:
            do_sample = False
            temperature = None

        if stop is None:
            stop = []
        elif isinstance(stop, str):
            stop = [stop]

        # The exact bounds 0 and 1 are nudged inward — presumably the server
        # requires 0 < top_p < 1; confirm against the server's validation.
        top_p = sampling_params.top_p
        if top_p == 0:
            top_p = 0.001
        if top_p == 1:
            top_p = 0.999

        # -1 is mapped to "not set".
        top_k = sampling_params.top_k
        if top_k == -1:
            top_k = None

        params = {
            "decoder_input_details": False,
            "details": False,
            "do_sample": do_sample,
            "max_new_tokens": max_tokens,
            "stop": stop,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "return_full_text": False,
        }
        params.update(override_params)
        return params

    @staticmethod
    def _extract_int(text):
        """Return the first word of ``text`` that parses as an int.

        The word is returned as a string.

        Raises:
            ValueError: If no word parses as an int.
        """
        for word in TGI._WORD_SPLIT_RE.split(text):
            try:
                int(word)
                return word
            except ValueError:
                continue
        raise ValueError

    @staticmethod
    def _extract_choice(choices, text):
        """Return the first word of ``text`` that is one of ``choices``.

        Raises:
            ValueError: If no word matches.
        """
        # FIXME: Currently this only supports choices that are single words.
        for word in TGI._WORD_SPLIT_RE.split(text):
            if word in choices:
                return word
        raise ValueError

    @staticmethod
    def _truncate_to_stop(text, stop):
        """Cut ``text`` before the first stop sequence (in list order) found in it."""
        # The stop sequence may not be a single token. In this case TGI will
        # generate too many tokens so we need to truncate the output.
        if stop:
            stop = [stop] if isinstance(stop, str) else stop
            for stop_seq in stop:
                pos = text.find(stop_seq)
                if pos != -1:
                    return text[:pos]
        return text

    def _make_request(self, params):
        """POST ``params`` to ``/generate`` and return the decoded JSON body.

        Raises:
            ValueError: If the response status is not 200.
        """
        res = http_request(self.base_url + "/generate", json=params)
        if res.status_code != 200:
            raise ValueError(f"Error from TGI backend: {res.text}")
        return res.json()

    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
        """Re-issue a request until ``extract_fn`` accepts the generated text.

        Args:
            prompt: Sent as the request's ``inputs``.
            params: Sent as the request's ``parameters``.
            extract_fn: Called with the generated text; must raise
                ``ValueError`` to reject it.
            retry: Maximum number of attempts.

        Returns:
            Whatever ``extract_fn`` returns for the first accepted text.

        Raises:
            ValueError: If every attempt is rejected; the message lists the
                prompt and each rejected output.
        """
        # TGI does not support logit_bias (yet), so we have to use an
        # inefficient hack: sample repeatedly and keep the first usable output.
        failed = []
        while retry > 0:
            res_json = self._make_request(
                {
                    "inputs": prompt,
                    "parameters": params,
                }
            )
            text = res_json["generated_text"]
            try:
                return extract_fn(text)
            except ValueError:
                retry -= 1
                failed.append(text)

        msg = "=" * 20 + "\n"
        msg += f"Prompt:\n{prompt}\n"
        msg += "=" * 20 + "\n"
        for i, text in enumerate(failed):
            msg += f"====== Try {i+1}:\n{text}\n"

        # The trailing space after "generate" was missing, which produced
        # "generateexpected" in the message.
        raise ValueError(
            f"Model {self.model_info['model_id']} served by TGI backend does not generate "
            "expected output. Please improve the prompt, increase the temperature, or "
            f"use different models.\n{msg}"
        )

    def select(
        self,
        s: StreamExecutor,
        choices: List[str],
        sampling_params: SglSamplingParams,
    ):
        """Ask the server to continue ``s.text_`` and pick the choice it names.

        Returns:
            ``(decision, scores)`` where ``scores`` is 1 for the chosen entry
            and 0 for every other one.
        """
        decision = self.retry_for_expected(
            s.text_,
            self.adapt_params(16, [], sampling_params),
            partial(self._extract_choice, choices),
        )
        return decision, [1 if choice == decision else 0 for choice in choices]

    def generate(
        self,
        s: StreamExecutor,
        max_tokens: int,
        stop: Union[str, List[str]],
        sampling_params: SglSamplingParams,
        dtype: Optional[str] = None,
    ):
        """Generate a completion for ``s.text_``, optionally constrained by ``dtype``.

        * ``None``: free-form text truncated at the first stop sequence.
        * ``str`` / ``"str"`` / ``"string"``: an opening quote is appended to
          the prompt, output stops at the closing quote, and the result is
          returned wrapped in double quotes.
        * ``int`` / ``"int"``: requests are retried until the output contains
          a word that parses as an int.

        Returns:
            ``(text, {})``.

        Raises:
            ValueError: If ``dtype`` is none of the values above, or the
                server keeps returning unusable output.
        """
        if dtype is None:
            res_json = self._make_request(
                {
                    "inputs": s.text_,
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return self._truncate_to_stop(res_json["generated_text"], stop), {}

        if dtype in [str, "str", "string"]:
            stop = ['"']
            res_json = self._make_request(
                {
                    "inputs": f'{s.text_}"',
                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
                }
            )
            return (
                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
                {},
            )

        if dtype in [int, "int"]:
            return (
                self.retry_for_expected(
                    s.text_,
                    self.adapt_params(max_tokens, stop, sampling_params),
                    self._extract_int,
                ),
                {},
            )

        raise ValueError(f"Unknown dtype: {dtype}")
|
|
||||||
@@ -18,13 +18,8 @@ try:
|
|||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
GenerativeModel = e
|
GenerativeModel = e
|
||||||
|
|
||||||
GEMINI_MODEL_NAMES = [
|
|
||||||
"gemini-pro",
|
|
||||||
"gemini-pro-vision",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
class VertexAI(BaseBackend):
|
||||||
class Gemini(BaseBackend):
|
|
||||||
def __init__(self, model_name):
|
def __init__(self, model_name):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -32,7 +27,7 @@ class Gemini(BaseBackend):
|
|||||||
raise GenerativeModel
|
raise GenerativeModel
|
||||||
|
|
||||||
project_id = os.environ["GCP_PROJECT_ID"]
|
project_id = os.environ["GCP_PROJECT_ID"]
|
||||||
location = os.environ["GCP_LOCATION"]
|
location = os.environ.get("GCP_LOCATION")
|
||||||
vertexai.init(project=project_id, location=location)
|
vertexai.init(project=project_id, location=location)
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
@@ -47,17 +42,17 @@ class Gemini(BaseBackend):
|
|||||||
sampling_params: SglSamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
if s.messages_:
|
if s.messages_:
|
||||||
prompt = self.messages_to_gemini_input(s.messages_)
|
prompt = self.messages_to_vertexai_input(s.messages_)
|
||||||
else:
|
else:
|
||||||
# single-turn
|
# single-turn
|
||||||
prompt = (
|
prompt = (
|
||||||
self.text_to_gemini_input(s.text_, s.cur_images)
|
self.text_to_vertexai_input(s.text_, s.cur_images)
|
||||||
if s.cur_images
|
if s.cur_images
|
||||||
else s.text_
|
else s.text_
|
||||||
)
|
)
|
||||||
ret = GenerativeModel(self.model_name).generate_content(
|
ret = GenerativeModel(self.model_name).generate_content(
|
||||||
prompt,
|
prompt,
|
||||||
generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
|
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||||
)
|
)
|
||||||
|
|
||||||
comp = ret.text
|
comp = ret.text
|
||||||
@@ -70,23 +65,23 @@ class Gemini(BaseBackend):
|
|||||||
sampling_params: SglSamplingParams,
|
sampling_params: SglSamplingParams,
|
||||||
):
|
):
|
||||||
if s.messages_:
|
if s.messages_:
|
||||||
prompt = self.messages_to_gemini_input(s.messages_)
|
prompt = self.messages_to_vertexai_input(s.messages_)
|
||||||
else:
|
else:
|
||||||
# single-turn
|
# single-turn
|
||||||
prompt = (
|
prompt = (
|
||||||
self.text_to_gemini_input(s.text_, s.cur_images)
|
self.text_to_vertexai_input(s.text_, s.cur_images)
|
||||||
if s.cur_images
|
if s.cur_images
|
||||||
else s.text_
|
else s.text_
|
||||||
)
|
)
|
||||||
generator = GenerativeModel(self.model_name).generate_content(
|
generator = GenerativeModel(self.model_name).generate_content(
|
||||||
prompt,
|
prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
|
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||||
)
|
)
|
||||||
for ret in generator:
|
for ret in generator:
|
||||||
yield ret.text, {}
|
yield ret.text, {}
|
||||||
|
|
||||||
def text_to_gemini_input(self, text, images):
|
def text_to_vertexai_input(self, text, images):
|
||||||
input = []
|
input = []
|
||||||
# split with image token
|
# split with image token
|
||||||
text_segs = text.split(self.chat_template.image_token)
|
text_segs = text.split(self.chat_template.image_token)
|
||||||
@@ -100,9 +95,9 @@ class Gemini(BaseBackend):
|
|||||||
input.append(text_seg)
|
input.append(text_seg)
|
||||||
return input
|
return input
|
||||||
|
|
||||||
def messages_to_gemini_input(self, messages):
|
def messages_to_vertexai_input(self, messages):
|
||||||
gemini_message = []
|
vertexai_message = []
|
||||||
# from openai message format to gemini message format
|
# from openai message format to vertexai message format
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg["content"], str):
|
if isinstance(msg["content"], str):
|
||||||
text = msg["content"]
|
text = msg["content"]
|
||||||
@@ -110,14 +105,14 @@ class Gemini(BaseBackend):
|
|||||||
text = msg["content"][0]["text"]
|
text = msg["content"][0]["text"]
|
||||||
|
|
||||||
if msg["role"] == "system":
|
if msg["role"] == "system":
|
||||||
warnings.warn("Warning: system prompt is not supported in Gemini.")
|
warnings.warn("Warning: system prompt is not supported in VertexAI.")
|
||||||
gemini_message.append(
|
vertexai_message.append(
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": [{"text": "System prompt: " + text}],
|
"parts": [{"text": "System prompt: " + text}],
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
gemini_message.append(
|
vertexai_message.append(
|
||||||
{
|
{
|
||||||
"role": "model",
|
"role": "model",
|
||||||
"parts": [{"text": "Understood."}],
|
"parts": [{"text": "Understood."}],
|
||||||
@@ -125,12 +120,12 @@ class Gemini(BaseBackend):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
if msg["role"] == "user":
|
if msg["role"] == "user":
|
||||||
gemini_msg = {
|
vertexai_msg = {
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": [{"text": text}],
|
"parts": [{"text": text}],
|
||||||
}
|
}
|
||||||
elif msg["role"] == "assistant":
|
elif msg["role"] == "assistant":
|
||||||
gemini_msg = {
|
vertexai_msg = {
|
||||||
"role": "model",
|
"role": "model",
|
||||||
"parts": [{"text": text}],
|
"parts": [{"text": text}],
|
||||||
}
|
}
|
||||||
@@ -139,7 +134,7 @@ class Gemini(BaseBackend):
|
|||||||
if isinstance(msg["content"], list) and len(msg["content"]) > 1:
|
if isinstance(msg["content"], list) and len(msg["content"]) > 1:
|
||||||
for image in msg["content"][1:]:
|
for image in msg["content"][1:]:
|
||||||
assert image["type"] == "image_url"
|
assert image["type"] == "image_url"
|
||||||
gemini_msg["parts"].append(
|
vertexai_msg["parts"].append(
|
||||||
{
|
{
|
||||||
"inline_data": {
|
"inline_data": {
|
||||||
"data": image["image_url"]["url"].split(",")[1],
|
"data": image["image_url"]["url"].split(",")[1],
|
||||||
@@ -148,5 +143,5 @@ class Gemini(BaseBackend):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
gemini_message.append(gemini_msg)
|
vertexai_message.append(vertexai_msg)
|
||||||
return gemini_message
|
return vertexai_message
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
@@ -40,6 +41,8 @@ class SglSamplingParams:
|
|||||||
|
|
||||||
def to_openai_kwargs(self):
|
def to_openai_kwargs(self):
|
||||||
# OpenAI does not support top_k, so we drop it here
|
# OpenAI does not support top_k, so we drop it here
|
||||||
|
if self.regex is not None:
|
||||||
|
warnings.warn("Regular expression is not supported in the OpenAI backend.")
|
||||||
return {
|
return {
|
||||||
"max_tokens": self.max_new_tokens,
|
"max_tokens": self.max_new_tokens,
|
||||||
"stop": self.stop or None,
|
"stop": self.stop or None,
|
||||||
@@ -49,7 +52,9 @@ class SglSamplingParams:
|
|||||||
"presence_penalty": self.presence_penalty,
|
"presence_penalty": self.presence_penalty,
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_gemini_kwargs(self):
|
def to_vertexai_kwargs(self):
|
||||||
|
if self.regex is not None:
|
||||||
|
warnings.warn("Regular expression is not supported in the VertexAI backend.")
|
||||||
return {
|
return {
|
||||||
"candidate_count": 1,
|
"candidate_count": 1,
|
||||||
"max_output_tokens": self.max_new_tokens,
|
"max_output_tokens": self.max_new_tokens,
|
||||||
@@ -61,6 +66,8 @@ class SglSamplingParams:
|
|||||||
|
|
||||||
def to_anthropic_kwargs(self):
|
def to_anthropic_kwargs(self):
|
||||||
# Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
|
# Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
|
||||||
|
if self.regex is not None:
|
||||||
|
warnings.warn("Regular expression is not supported in the Anthropic backend.")
|
||||||
return {
|
return {
|
||||||
"max_tokens_to_sample": self.max_new_tokens,
|
"max_tokens_to_sample": self.max_new_tokens,
|
||||||
"stop_sequences": self.stop,
|
"stop_sequences": self.stop,
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class RouterManager:
|
|||||||
self.model_client = model_client
|
self.model_client = model_client
|
||||||
self.recv_reqs = []
|
self.recv_reqs = []
|
||||||
|
|
||||||
# Init Some Configs
|
# Init some configs
|
||||||
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
|
self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
|
||||||
|
|
||||||
async def loop_for_forward(self):
|
async def loop_for_forward(self):
|
||||||
@@ -46,7 +46,7 @@ class RouterManager:
|
|||||||
if has_finished:
|
if has_finished:
|
||||||
await asyncio.sleep(self.extend_dependency_time)
|
await asyncio.sleep(self.extend_dependency_time)
|
||||||
|
|
||||||
await asyncio.sleep(0.001)
|
await asyncio.sleep(0.0006)
|
||||||
|
|
||||||
async def loop_for_recv_requests(self):
|
async def loop_for_recv_requests(self):
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -108,7 +108,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.running_batch: Batch = None
|
self.running_batch: Batch = None
|
||||||
self.out_pyobjs = []
|
self.out_pyobjs = []
|
||||||
self.decode_forward_ct = 0
|
self.decode_forward_ct = 0
|
||||||
self.stream_interval = 2
|
self.stream_interval = server_args.stream_interval
|
||||||
|
|
||||||
# Init the FSM cache for constrained generation
|
# Init the FSM cache for constrained generation
|
||||||
self.regex_fsm_cache = FSMCache(self.tokenizer)
|
self.regex_fsm_cache = FSMCache(self.tokenizer)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class ServerArgs:
|
|||||||
model_mode: List[str] = ()
|
model_mode: List[str] = ()
|
||||||
schedule_heuristic: str = "lpm"
|
schedule_heuristic: str = "lpm"
|
||||||
random_seed: int = 42
|
random_seed: int = 42
|
||||||
|
stream_interval: int = 2
|
||||||
disable_log_stats: bool = False
|
disable_log_stats: bool = False
|
||||||
log_stats_interval: int = 10
|
log_stats_interval: int = 10
|
||||||
log_level: str = "info"
|
log_level: str = "info"
|
||||||
@@ -108,6 +109,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.random_seed,
|
default=ServerArgs.random_seed,
|
||||||
help="Random seed.",
|
help="Random seed.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stream-interval",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.stream_interval,
|
||||||
|
help="The interval in terms of token length for streaming",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ from sglang.test.test_programs import (
|
|||||||
test_stream,
|
test_stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang import Gemini, set_default_backend
|
from sglang import VertexAI, set_default_backend
|
||||||
|
|
||||||
|
|
||||||
class TestGeminiBackend(unittest.TestCase):
|
class TestVertexAIBackend(unittest.TestCase):
|
||||||
backend = None
|
backend = None
|
||||||
chat_backend = None
|
chat_backend = None
|
||||||
chat_vision_backend = None
|
chat_vision_backend = None
|
||||||
@@ -22,9 +22,9 @@ class TestGeminiBackend(unittest.TestCase):
|
|||||||
cls = type(self)
|
cls = type(self)
|
||||||
|
|
||||||
if cls.backend is None:
|
if cls.backend is None:
|
||||||
cls.backend = Gemini("gemini-pro")
|
cls.backend = VertexAI("gemini-pro")
|
||||||
cls.chat_backend = Gemini("gemini-pro")
|
cls.chat_backend = VertexAI("gemini-pro")
|
||||||
cls.chat_vision_backend = Gemini("gemini-pro-vision")
|
cls.chat_vision_backend = VertexAI("gemini-pro-vision")
|
||||||
|
|
||||||
def test_few_shot_qa(self):
|
def test_few_shot_qa(self):
|
||||||
set_default_backend(self.backend)
|
set_default_backend(self.backend)
|
||||||
@@ -61,6 +61,6 @@ if __name__ == "__main__":
|
|||||||
# from sglang.global_config import global_config
|
# from sglang.global_config import global_config
|
||||||
|
|
||||||
# global_config.verbosity = 2
|
# global_config.verbosity = 2
|
||||||
# t = TestGeminiBackend()
|
# t = TestVertexAIBackend()
|
||||||
# t.setUp()
|
# t.setUp()
|
||||||
# t.test_stream()
|
# t.test_stream()
|
||||||
Reference in New Issue
Block a user