From bf51ddc6e52d872700724ddc181089162811319a Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Wed, 17 Jan 2024 02:54:41 -0800
Subject: [PATCH] Improve docs & Rename Gemini -> VertexAI (#19)

---
 README.md                                     |   5 +-
 .../quick_start/gemini_example_complete.py    |   4 +-
 .../gemini_example_multimodal_chat.py         |   6 +-
 examples/quick_start/gemini_example_stream.py |   4 +-
 python/sglang/api.py                          |   2 +-
 python/sglang/backend/huggingface.py          | 349 ------------------
 python/sglang/backend/tgi.py                  | 190 ----------
 .../sglang/backend/{gemini.py => vertexai.py} |  45 +--
 python/sglang/lang/ir.py                      |   9 +-
 python/sglang/srt/managers/router/manager.py  |   4 +-
 .../sglang/srt/managers/router/model_rpc.py   |   2 +-
 python/sglang/srt/server_args.py              |   7 +
 ...ni_backend.py => test_vertexai_backend.py} |  12 +-
 13 files changed, 56 insertions(+), 583 deletions(-)
 delete mode 100644 python/sglang/backend/huggingface.py
 delete mode 100644 python/sglang/backend/tgi.py
 rename python/sglang/backend/{gemini.py => vertexai.py} (80%)
 rename test/lang/{test_gemini_backend.py => test_vertexai_backend.py} (81%)

diff --git a/README.md b/README.md
index a420cd2b9..ab7bad411 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
 # SGLang
+| [**Blog**](https://lmsys.org/blog/2024-01-17-sglang/) | [**Paper**](https://arxiv.org/abs/2312.07104) |
 
 SGLang is a structured generation language designed for large language models (LLMs).
 It makes your interaction with LLMs faster and more controllable by co-designing the frontend language and the runtime system.
@@ -42,7 +43,7 @@ The example below shows how to use sglang to answer a mulit-turn question.
 ### Using OpenAI Models
 Set the OpenAI API Key
 ```
-export OPENAI_API_KEY=sk-xxxxxx
+export OPENAI_API_KEY=sk-******
 ```
 
 Then, answer a multi-turn question.
@@ -100,6 +101,7 @@ for m in state.messages():
 
 ### More Examples
 
+Anthropic and VertexAI (Gemini) models are also supported.
 You can find more examples at [examples/quick_start](examples/quick_start).
 
 ## Frontend: Structured Generation Langauge (SGLang)
@@ -251,6 +253,7 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 - Mixtral
 - LLaVA
   - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --port 30000`
+- AWQ quantization
 
 ## Benchmark And Performance
 
diff --git a/examples/quick_start/gemini_example_complete.py b/examples/quick_start/gemini_example_complete.py
index e3fe028a1..abaaec7c9 100644
--- a/examples/quick_start/gemini_example_complete.py
+++ b/examples/quick_start/gemini_example_complete.py
@@ -1,4 +1,4 @@
-from sglang import function, gen, set_default_backend, Gemini
+from sglang import function, gen, set_default_backend, VertexAI
 
 
 @function
@@ -16,7 +16,7 @@ A: Rome
     s += "A:" + gen("answer", stop="\n", temperature=0)
 
 
-set_default_backend(Gemini("gemini-pro"))
+set_default_backend(VertexAI("gemini-pro"))
 
 state = few_shot_qa.run(question="What is the capital of the United States?")
 answer = state["answer"].strip().lower()
diff --git a/examples/quick_start/gemini_example_multimodal_chat.py b/examples/quick_start/gemini_example_multimodal_chat.py
index 312679a7e..ac5409a4e 100644
--- a/examples/quick_start/gemini_example_multimodal_chat.py
+++ b/examples/quick_start/gemini_example_multimodal_chat.py
@@ -1,4 +1,4 @@
-from sglang import function, user, assistant, gen, image, set_default_backend, Gemini
+from sglang import function, user, assistant, gen, image, set_default_backend, VertexAI
 
 
 @function
@@ -6,7 +6,7 @@ def image_qa(s, image_file1, image_file2, question):
     s += user(image(image_file1) + image(image_file2) + question)
     s += assistant(gen("answer_1", max_tokens=256))
 
-set_default_backend(Gemini("gemini-pro-vision"))
+set_default_backend(VertexAI("gemini-pro-vision"))
 
 state = image_qa.run(
     image_file1="./images/cat.jpeg",
@@ -16,4 +16,4 @@ state = image_qa.run(
 )
 
 for out in state.text_iter():
-    print(out, end="", flush=True)
\ No newline at end of file
+    print(out, end="", flush=True)
diff --git a/examples/quick_start/gemini_example_stream.py b/examples/quick_start/gemini_example_stream.py
index 8416ea648..431e7115d 100644
--- a/examples/quick_start/gemini_example_stream.py
+++ b/examples/quick_start/gemini_example_stream.py
@@ -1,4 +1,4 @@
-from sglang import function, user, assistant, gen, set_default_backend, Gemini
+from sglang import function, user, assistant, gen, set_default_backend, VertexAI
 
 
 @function
@@ -8,7 +8,7 @@ def multi_turn_question(s, question_1, question_2):
     s += user(question_2)
     s += assistant(gen("answer_2", max_tokens=256))
 
-set_default_backend(Gemini("gemini-pro"))
+set_default_backend(VertexAI("gemini-pro"))
 
 state = multi_turn_question.run(
     question_1="What is the capital of the United States?",
diff --git a/python/sglang/api.py b/python/sglang/api.py
index 5e519257a..ed31ec141 100644
--- a/python/sglang/api.py
+++ b/python/sglang/api.py
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
 
 from sglang.backend.anthropic import Anthropic
 from sglang.backend.base_backend import BaseBackend
-from sglang.backend.gemini import Gemini
+from sglang.backend.vertexai import VertexAI
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.global_config import global_config
diff --git a/python/sglang/backend/huggingface.py b/python/sglang/backend/huggingface.py
deleted file mode 100644
index acd6e251f..000000000
--- a/python/sglang/backend/huggingface.py
+++ /dev/null
@@ -1,349 +0,0 @@
-import functools
-from enum import Enum, auto
-from typing import Callable, List, Optional, Union
-
-import numpy as np
-import torch
-import transformers
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import ProgramState
-from sglang.utils import get_available_gpu_memory
-from transformers import (
-    AutoModelForCausalLM,
-    AutoTokenizer,
-    StoppingCriteria,
-    StoppingCriteriaList,
-)
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
-
-
-class StopReason(Enum):
-    EOS_TOKEN = auto()
-    STOP_STR = auto()
-    LENGTH = auto()
-
-
-def load_model(
-    model_name: str,
-    device,
-    num_gpus,
-    max_gpu_memory,
-    model_kwargs=None,
-    tokenizer_kwargs=None,
-):
-    model_kwargs = model_kwargs or {}
-    tokenizer_kwargs = tokenizer_kwargs or {}
-
-    if device == "cuda":
-        model_kwargs["torch_dtype"] = torch.float16
-        if num_gpus != 1:
-            model_kwargs["device_map"] = "auto"
-            if max_gpu_memory is None:
-                model_kwargs[
-                    "device_map"
-                ] = "sequential"  # This is important for not the same VRAM sizes
-                available_gpu_memory = [
-                    get_available_gpu_memory(i, False) for i in range(num_gpus)
-                ]
-                model_kwargs["max_memory"] = {
-                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
-                    for i in range(num_gpus)
-                }
-            else:
-                model_kwargs["max_memory"] = {
-                    i: max_gpu_memory for i in range(num_gpus)
-                }
-    elif device == "cpu":
-        model_kwargs["torch_dtype"] = torch.float32
-    else:
-        raise ValueError(f"Invalid device: {device}")
-
-    model = AutoModelForCausalLM.from_pretrained(
-        model_name, low_cpu_mem_usage=True, **model_kwargs
-    )
-    tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_kwargs)
-
-    if num_gpus == 1:
-        model.to(device).eval()
-
-    return model, tokenizer
-
-
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-@functools.lru_cache
-def get_token_healing_mask(tokenizer, prompt_last_token):
-    last_str = tokenizer.convert_ids_to_tokens(prompt_last_token)
-    disallowed = torch.zeros(len(tokenizer), dtype=bool)
-    for s, t_id in tokenizer.get_vocab().items():
-        if not s.startswith(last_str):
-            disallowed[t_id] = 1
-    return disallowed
-
-
-@functools.lru_cache
-def get_int_token_mask(tokenizer):
-    disallowed = torch.zeros(len(tokenizer), dtype=bool)
-    for s, t_id in tokenizer.get_vocab().items():
-        s = s.replace("▁", "").strip()
-        if not (s.isdigit() or len(s) == 0 or s == ","):
-            disallowed[t_id] = 1
-    disallowed[tokenizer.eos_token_id] = 0
-    return disallowed
-
-
-@torch.inference_mode()
-def generate_stream(
-    model,
-    tokenizer,
-    prompt,
-    max_new_tokens,
-    stop: List[str],
-    temperature,
-    top_p,
-    token_healing,
-    logit_mask=None,
-):
-    logits_processor = prepare_logits_processor(
-        temperature=temperature, repetition_penalty=1.0, top_p=top_p, top_k=0
-    )
-    device = model.device
-    input_ids = tokenizer.encode(prompt)
-    output_ids = list(input_ids)
-    prompt_len = len(prompt)
-
-    # Resolve stop
-    stop_token_ids = [tokenizer.eos_token_id]
-
-    # Token healing
-    token_healing = token_healing and len(input_ids) > 0
-    if token_healing:
-        token_healing_mask = get_token_healing_mask(tokenizer, input_ids[-1])
-        del output_ids[-1]
-
-    # Generate
-    past_key_values = None
-    stop_reason = None
-    for i in range(max_new_tokens):
-        # Forward
-        if i == 0:  # prefill
-            out = model(torch.as_tensor([output_ids], device=device), use_cache=True)
-        else:  # decoding
-            out = model(
-                input_ids=torch.as_tensor([[token]], device=device),
-                use_cache=True,
-                past_key_values=past_key_values,
-            )
-        logits = out.logits
-        past_key_values = out.past_key_values
-
-        # Logit mask
-        if token_healing and i == 0:
-            logits[0, -1, token_healing_mask] = -1e4
-        if logit_mask is not None:
-            logits[0, -1, logit_mask] = -1e4
-
-        # Sample next token
-        last_token_logits = logits_processor(None, logits[:, -1, :])[0]
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            token = int(torch.argmax(last_token_logits))
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            token = int(torch.multinomial(probs, num_samples=1))
-        output_ids.append(token)
-
-        # Stop condition
-        if token in stop_token_ids:
-            stop_reason = StopReason.EOS_TOKEN
-            break
-
-        output_str = tokenizer.decode(output_ids, skip_special_tokens=True)
-        for stop_str in stop:
-            pos = output_str[prompt_len:].find(stop_str)
-            if pos != -1:
-                stop_reason = StopReason.STOP_STR
-                output_str = output_str[: prompt_len + pos]
-                break
-
-        if stop_reason:
-            break
-
-    return output_str[prompt_len:]
-
-
-class HuggingFaceTransformers(BaseBackend):
-    def __init__(
-        self,
-        model_name,
-        device="cuda",
-        num_gpus=1,
-        max_gpu_memory=None,
-        model_kwargs=None,
-        tokenizer_kwargs=None,
-    ):
-        self.model_name = model_name
-        self.device = device
-
-        self.model, self.tokenizer = load_model(
-            model_name, device, num_gpus, max_gpu_memory, model_kwargs, tokenizer_kwargs
-        )
-
-        self.chat_template = get_chat_template_by_model_path(model_name)
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    def cache_prefix(self, prefix_str: str):
-        pass
-
-    def uncache_prefix(self, rid: str):
-        pass
-
-    def end_request(self, rid: str):
-        pass
-
-    def begin_program(self, s: ProgramState):
-        pass
-
-    def end_program(self, s: ProgramState):
-        pass
-
-    def fill(self, s: ProgramState, text: str):
-        return False
-
-    def generate_internal(
-        self,
-        prompt: str,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        if dtype is None:
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt,
-                max_new_tokens=max_tokens,
-                stop=stop,
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=True,
-            )
-        elif dtype in [str, "str", "string"]:
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt + '"',
-                max_new_tokens=max_tokens,
-                stop=['"'],
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=False,
-            )
-            comp = '"' + comp + '"'
-        elif dtype in [int, "int"]:
-            logit_mask = get_int_token_mask(self.tokenizer)
-            comp = generate_stream(
-                self.model,
-                self.tokenizer,
-                prompt,
-                max_new_tokens=max_tokens,
-                stop=stop + [" ", ","],
-                temperature=temperature,
-                top_p=top_p,
-                token_healing=False,
-                logit_mask=logit_mask,
-            )
-        return comp
-
-    def generate(
-        self,
-        s: ProgramState,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        prompt = s.text
-        comp = self.generate_internal(
-            prompt, max_tokens, stop, temperature, top_p, dtype
-        )
-        return comp
-
-    def parallel_generate(
-        self,
-        s: ProgramState,
-        prefixes: List[str],
-        join_func: Callable,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        temperature: float,
-        top_p: float,
-        dtype: Optional[str] = None,
-    ):
-        prompt = s.text
-        parallel_prompts = [prompt + prefix for prefix in prefixes]
-
-        comps = []
-        for i in range(len(parallel_prompts)):
-            comps.append(
-                self.generate_internal(
-                    parallel_prompts[i], max_tokens, stop, temperature, top_p, dtype
-                )
-            )
-
-        joined = join_func([p + c for p, c in zip(prefixes, comps)])
-        return joined, comps
-
-    @torch.inference_mode()
-    def select(
-        self, s: ProgramState, choices: List[str], temperature: float, top_p: float
-    ):
-        loss_fct = torch.nn.CrossEntropyLoss()
-        prompt = s.text
-
-        prompt_len = self.tokenizer.encode(prompt, return_tensors="pt").shape[1]
-        prompt_choices = [prompt + choice for choice in choices]
-
-        scores = []
-        for i in range(len(choices)):
-            choice_ids = self.tokenizer.encode(
-                prompt_choices[i], return_tensors="pt"
-            ).to(self.model.device)
-            logits = self.model(choice_ids).logits
-
-            # score = -loss_fct(logits[0, :-1, :], choice_ids[0, 1:]).item()
-
-            logprobs = torch.log(torch.softmax(logits, dim=-1))
-            idx1 = torch.arange(0, logits.shape[1] - 1, device=logits.device)
-            idx2 = choice_ids[0, 1:]
-            selected_logprobs = logprobs[0, idx1, idx2]
-            score = selected_logprobs.mean().item()
-            scores.append(score)
-
-        decision = choices[np.argmax(scores)]
-        return decision, scores
diff --git a/python/sglang/backend/tgi.py b/python/sglang/backend/tgi.py
deleted file mode 100644
index be3f3fea4..000000000
--- a/python/sglang/backend/tgi.py
+++ /dev/null
@@ -1,190 +0,0 @@
-import re
-from concurrent.futures import ThreadPoolExecutor
-from functools import partial
-from itertools import repeat
-from typing import List, Optional, Union
-
-from sglang.backend.base_backend import BaseBackend
-from sglang.lang.chat_template import get_chat_template_by_model_path
-from sglang.lang.interpreter import StreamExecutor
-from sglang.lang.ir import SglSamplingParams
-from sglang.utils import http_request
-
-
-class TGI(BaseBackend):
-    def __init__(self, base_url):
-        super().__init__()
-
-        self.base_url = base_url
-
-        res = http_request(self.base_url + "/info")
-        assert res.status_code == 200
-        self.model_info = res.json()
-        self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_id"]
-        )
-
-    def get_model_name(self):
-        return self.model_info["model_id"]
-
-    def get_chat_template(self):
-        return self.chat_template
-
-    @staticmethod
-    def adapt_params(max_tokens, stop, sampling_params, **override_params):
-        temperature = sampling_params.temperature
-        do_sample = True
-        if temperature == 0:
-            do_sample = False
-            temperature = None
-
-        if stop is None:
-            stop = []
-        elif isinstance(stop, str):
-            stop = [stop]
-
-        top_p = sampling_params.top_p
-        if top_p == 0:
-            top_p = 0.001
-        if top_p == 1:
-            top_p = 0.999
-
-        top_k = sampling_params.top_k
-        if top_k == -1:
-            top_k = None
-
-        params = {
-            "decoder_input_details": False,
-            "details": False,
-            "do_sample": do_sample,
-            "max_new_tokens": max_tokens,
-            "stop": stop,
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "return_full_text": False,
-        }
-        params.update(override_params)
-        return params
-
-    @staticmethod
-    def _extract_int(text):
-        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-        for word in words:
-            try:
-                int(word)
-                return word
-            except ValueError:
-                continue
-        raise ValueError
-
-    @staticmethod
-    def _extract_choice(choices, text):
-        # FIXME: Current only support the case where the choices are single words.
-        words = re.split("\ |'|\/|\(|\)|\n|\.|,", text)
-        for word in words:
-            if word in choices:
-                return word
-        raise ValueError
-
-    @staticmethod
-    def _truncate_to_stop(text, stop):
-        # The stop sequence may not be a single token. In this case TGI will generate
-        # too many tokens so we need to truncate the output.
-        if stop:
-            stop = [stop] if isinstance(stop, str) else stop
-            for stop_seq in stop:
-                pos = text.find(stop_seq)
-                if pos != -1:
-                    return text[:pos]
-        return text
-
-    def _make_request(self, params):
-        res = http_request(self.base_url + "/generate", json=params)
-        if res.status_code != 200:
-            raise ValueError(f"Error from TGI backend: {res.text}")
-        return res.json()
-
-    def retry_for_expected(self, prompt, params, extract_fn, retry=5):
-        # TGI does not support logis_bias (yet), so we have to use an inefficient hack.
-        failed = []
-        while retry > 0:
-            res_json = self._make_request(
-                {
-                    "inputs": prompt,
-                    "parameters": params,
-                }
-            )
-            text = res_json["generated_text"]
-            try:
-                return extract_fn(text)
-            except ValueError:
-                retry -= 1
-                failed.append(text)
-
-        msg = "=" * 20 + "\n"
-        msg += f"Prompt:\n{prompt}\n"
-        msg += "=" * 20 + "\n"
-        for i, text in enumerate(failed):
-            msg += f"====== Try {i+1}:\n{text}\n"
-
-        raise ValueError(
-            f"Model {self.model_info['model_id']} served by TGI backend does not generate"
-            "expected output. Please improve the prompt, increase the temperature, or "
-            f"use different models.\n{msg}"
-        )
-
-    def select(
-        self,
-        s: StreamExecutor,
-        choices: List[str],
-        sampling_params: SglSamplingParams,
-    ):
-        decision = self.retry_for_expected(
-            s.text_,
-            self.adapt_params(16, [], sampling_params),
-            partial(self._extract_choice, choices),
-        )
-        return decision, [1 if choice == decision else 0 for choice in choices]
-
-    def generate(
-        self,
-        s: StreamExecutor,
-        max_tokens: int,
-        stop: Union[str, List[str]],
-        sampling_params: SglSamplingParams,
-        dtype: Optional[str] = None,
-    ):
-        if dtype is None:
-            res_json = self._make_request(
-                {
-                    "inputs": s.text_,
-                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                }
-            )
-            return self._truncate_to_stop(res_json["generated_text"], stop), {}
-
-        if dtype in [str, "str", "string"]:
-            stop = ['"']
-            res_json = self._make_request(
-                {
-                    "inputs": f'{s.text_}"',
-                    "parameters": self.adapt_params(max_tokens, stop, sampling_params),
-                }
-            )
-            return (
-                '"' + self._truncate_to_stop(res_json["generated_text"], stop) + '"',
-                {},
-            )
-
-        if dtype in [int, "int"]:
-            return (
-                self.retry_for_expected(
-                    s.text_,
-                    self.adapt_params(max_tokens, stop, sampling_params),
-                    self._extract_int,
-                ),
-                {},
-            )
-
-        raise ValueError(f"Unknown dtype: {dtype}")
diff --git a/python/sglang/backend/gemini.py b/python/sglang/backend/vertexai.py
similarity index 80%
rename from python/sglang/backend/gemini.py
rename to python/sglang/backend/vertexai.py
index 3ce10cf4e..5c3c307e2 100644
--- a/python/sglang/backend/gemini.py
+++ b/python/sglang/backend/vertexai.py
@@ -18,13 +18,8 @@ try:
 except ImportError as e:
     GenerativeModel = e
 
-GEMINI_MODEL_NAMES = [
-    "gemini-pro",
-    "gemini-pro-vision",
-]
-
-class Gemini(BaseBackend):
+class VertexAI(BaseBackend):
     def __init__(self, model_name):
         super().__init__()
 
@@ -32,7 +27,7 @@ class Gemini(BaseBackend):
             raise GenerativeModel
 
         project_id = os.environ["GCP_PROJECT_ID"]
-        location = os.environ["GCP_LOCATION"]
+        location = os.environ.get("GCP_LOCATION")
        vertexai.init(project=project_id, location=location)
 
         self.model_name = model_name
@@ -47,17 +42,17 @@ class Gemini(BaseBackend):
         sampling_params: SglSamplingParams,
     ):
         if s.messages_:
-            prompt = self.messages_to_gemini_input(s.messages_)
+            prompt = self.messages_to_vertexai_input(s.messages_)
         else:  # single-turn
             prompt = (
-                self.text_to_gemini_input(s.text_, s.cur_images)
+                self.text_to_vertexai_input(s.text_, s.cur_images)
                 if s.cur_images
                 else s.text_
             )
 
         ret = GenerativeModel(self.model_name).generate_content(
             prompt,
-            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
         )
         comp = ret.text
 
@@ -70,23 +65,23 @@
         sampling_params: SglSamplingParams,
     ):
         if s.messages_:
-            prompt = self.messages_to_gemini_input(s.messages_)
+            prompt = self.messages_to_vertexai_input(s.messages_)
         else:  # single-turn
             prompt = (
-                self.text_to_gemini_input(s.text_, s.cur_images)
+                self.text_to_vertexai_input(s.text_, s.cur_images)
                 if s.cur_images
                 else s.text_
             )
 
         generator = GenerativeModel(self.model_name).generate_content(
             prompt,
             stream=True,
-            generation_config=GenerationConfig(**sampling_params.to_gemini_kwargs()),
+            generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
         )
         for ret in generator:
             yield ret.text, {}
 
-    def text_to_gemini_input(self, text, images):
+    def text_to_vertexai_input(self, text, images):
         input = []
         # split with image token
         text_segs = text.split(self.chat_template.image_token)
@@ -100,9 +95,9 @@ class Gemini(BaseBackend):
             input.append(text_seg)
         return input
 
-    def messages_to_gemini_input(self, messages):
-        gemini_message = []
-        # from openai message format to gemini message format
+    def messages_to_vertexai_input(self, messages):
+        vertexai_message = []
+        # from openai message format to vertexai message format
         for msg in messages:
             if isinstance(msg["content"], str):
                 text = msg["content"]
@@ -110,14 +105,14 @@ class Gemini(BaseBackend):
                 text = msg["content"][0]["text"]
 
             if msg["role"] == "system":
-                warnings.warn("Warning: system prompt is not supported in Gemini.")
-                gemini_message.append(
+                warnings.warn("Warning: system prompt is not supported in VertexAI.")
+                vertexai_message.append(
                     {
                         "role": "user",
                         "parts": [{"text": "System prompt: " + text}],
                     }
                 )
-                gemini_message.append(
+                vertexai_message.append(
                     {
                         "role": "model",
                         "parts": [{"text": "Understood."}],
@@ -125,12 +120,12 @@ class Gemini(BaseBackend):
                 )
                 continue
             if msg["role"] == "user":
-                gemini_msg = {
+                vertexai_msg = {
                     "role": "user",
                     "parts": [{"text": text}],
                 }
             elif msg["role"] == "assistant":
-                gemini_msg = {
+                vertexai_msg = {
                     "role": "model",
                     "parts": [{"text": text}],
                 }
@@ -139,7 +134,7 @@ class Gemini(BaseBackend):
             if isinstance(msg["content"], list) and len(msg["content"]) > 1:
                 for image in msg["content"][1:]:
                     assert image["type"] == "image_url"
-                    gemini_msg["parts"].append(
+                    vertexai_msg["parts"].append(
                         {
                             "inline_data": {
                                 "data": image["image_url"]["url"].split(",")[1],
@@ -148,5 +143,5 @@ class Gemini(BaseBackend):
                         }
                     )
 
-            gemini_message.append(gemini_msg)
-        return gemini_message
+            vertexai_message.append(vertexai_msg)
+        return vertexai_message
diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py
index bafdea43e..b6c1b9b54 100644
--- a/python/sglang/lang/ir.py
+++ b/python/sglang/lang/ir.py
@@ -2,6 +2,7 @@
 import dataclasses
 import inspect
+import warnings
 from typing import List, Optional, Union
 
 from sglang.global_config import global_config
 
@@ -40,6 +41,8 @@ class SglSamplingParams:
 
     def to_openai_kwargs(self):
         # OpenAI does not support top_k, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the OpenAI backend.")
         return {
             "max_tokens": self.max_new_tokens,
             "stop": self.stop or None,
@@ -49,7 +52,9 @@ class SglSamplingParams:
             "presence_penalty": self.presence_penalty,
         }
 
-    def to_gemini_kwargs(self):
+    def to_vertexai_kwargs(self):
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the VertexAI backend.")
         return {
             "candidate_count": 1,
             "max_output_tokens": self.max_new_tokens,
@@ -61,6 +66,8 @@ class SglSamplingParams:
 
     def to_anthropic_kwargs(self):
         # Anthropic does not support frequency_penalty or presence_penalty, so we drop it here
+        if self.regex is not None:
+            warnings.warn("Regular expression is not supported in the Anthropic backend.")
         return {
             "max_tokens_to_sample": self.max_new_tokens,
             "stop_sequences": self.stop,
diff --git a/python/sglang/srt/managers/router/manager.py b/python/sglang/srt/managers/router/manager.py
index 9d848b9a7..0732d0fa8 100644
--- a/python/sglang/srt/managers/router/manager.py
+++ b/python/sglang/srt/managers/router/manager.py
@@ -28,7 +28,7 @@ class RouterManager:
         self.model_client = model_client
         self.recv_reqs = []
 
-        # Init Some Configs
+        # Init some configs
         self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
 
     async def loop_for_forward(self):
@@ -46,7 +46,7 @@ class RouterManager:
             if has_finished:
                 await asyncio.sleep(self.extend_dependency_time)
 
-            await asyncio.sleep(0.001)
+            await asyncio.sleep(0.0006)
 
     async def loop_for_recv_requests(self):
         while True:
diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py
index 09f98d1dd..877afd749 100644
--- a/python/sglang/srt/managers/router/model_rpc.py
+++ b/python/sglang/srt/managers/router/model_rpc.py
@@ -108,7 +108,7 @@ class ModelRpcServer(rpyc.Service):
         self.running_batch: Batch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
-        self.stream_interval = 2
+        self.stream_interval = server_args.stream_interval
 
         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(self.tokenizer)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 754f70f4a..7c2957abc 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -17,6 +17,7 @@ class ServerArgs:
     model_mode: List[str] = ()
     schedule_heuristic: str = "lpm"
    random_seed: int = 42
+    stream_interval: int = 2
     disable_log_stats: bool = False
     log_stats_interval: int = 10
     log_level: str = "info"
@@ -108,6 +109,12 @@ class ServerArgs:
             default=ServerArgs.random_seed,
             help="Random seed.",
         )
+        parser.add_argument(
+            "--stream-interval",
+            type=int,
+            default=ServerArgs.stream_interval,
+            help="The interval (in token decoding steps) for streaming intermediate outputs.",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
diff --git a/test/lang/test_gemini_backend.py b/test/lang/test_vertexai_backend.py
similarity index 81%
rename from test/lang/test_gemini_backend.py
rename to test/lang/test_vertexai_backend.py
index f2e1e83a7..a17ab4ba7 100644
--- a/test/lang/test_gemini_backend.py
+++ b/test/lang/test_vertexai_backend.py
@@ -10,10 +10,10 @@ from sglang.test.test_programs import (
     test_stream,
 )
 
-from sglang import Gemini, set_default_backend
+from sglang import VertexAI, set_default_backend
 
 
-class TestGeminiBackend(unittest.TestCase):
+class TestVertexAIBackend(unittest.TestCase):
     backend = None
     chat_backend = None
     chat_vision_backend = None
@@ -22,9 +22,9 @@ class TestGeminiBackend(unittest.TestCase):
     cls = type(self)
 
         if cls.backend is None:
-            cls.backend = Gemini("gemini-pro")
-            cls.chat_backend = Gemini("gemini-pro")
-            cls.chat_vision_backend = Gemini("gemini-pro-vision")
+            cls.backend = VertexAI("gemini-pro")
+            cls.chat_backend = VertexAI("gemini-pro")
+            cls.chat_vision_backend = VertexAI("gemini-pro-vision")
 
     def test_few_shot_qa(self):
         set_default_backend(self.backend)
@@ -61,6 +61,6 @@ if __name__ == "__main__":
     # from sglang.global_config import global_config
     # global_config.verbosity = 2
 
-    # t = TestGeminiBackend()
+    # t = TestVertexAIBackend()
     # t.setUp()
     # t.test_stream()
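
A minimal usage sketch for the new `--stream-interval` flag, reusing the Llama-2 launch command already shown in the README; the value 4 is purely illustrative, and the behavior description is inferred from the flag's help text and default rather than verified against the runtime:

```
# Launch the SRT server and emit streamed output roughly every 4 decoding steps
# instead of the default of 2 (ServerArgs.stream_interval).
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --stream-interval 4
```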