From a34dd86a7dd734ef95ba37a86ba929479bbbac64 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Wed, 14 Aug 2024 08:58:07 -0700 Subject: [PATCH] Use `dtype` to control generate (#1082) Co-authored-by: zhyncs --- benchmark/json_decode_regex/bench_other.py | 8 +- benchmark/json_decode_regex/bench_sglang.py | 12 +- python/sglang/api.py | 2 +- python/sglang/bench_latency.py | 2 +- .../sglang/lang/backend/runtime_endpoint.py | 109 ++++++++++-------- python/sglang/lang/ir.py | 6 +- python/sglang/srt/managers/schedule_batch.py | 13 +-- python/sglang/srt/managers/tp_worker.py | 8 +- python/sglang/srt/models/mixtral.py | 1 + python/sglang/srt/sampling_params.py | 4 - python/sglang/test/test_programs.py | 28 ++++- test/lang/test_srt_backend.py | 5 +- 12 files changed, 110 insertions(+), 88 deletions(-) diff --git a/benchmark/json_decode_regex/bench_other.py b/benchmark/json_decode_regex/bench_other.py index bbe22835a..d80ea1de7 100644 --- a/benchmark/json_decode_regex/bench_other.py +++ b/benchmark/json_decode_regex/bench_other.py @@ -6,11 +6,11 @@ from functools import partial from tqdm import tqdm -from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING +from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate from sglang.utils import dump_state_text, read_jsonl -REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" +REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]" # fmt: off @@ -20,9 +20,9 @@ def json_decode(document, generate): s += "Here is the name, country, and symbol of the city in JSON format.\n" s += "{\n" s += ' "name": ' - s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "country": ' - s += generate(s, max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += generate(s, max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "latitude": ' s += generate(s, max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "population": ' diff --git a/benchmark/json_decode_regex/bench_sglang.py b/benchmark/json_decode_regex/bench_sglang.py index 196438722..462c77750 100644 --- a/benchmark/json_decode_regex/bench_sglang.py +++ b/benchmark/json_decode_regex/bench_sglang.py @@ -3,14 +3,14 @@ import json import time import sglang as sgl -from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING +from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) from sglang.utils import dump_state_text, read_jsonl -REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]" +REGEX_LIST = r"\[(" + REGEX_STR + ", )*" + REGEX_STR + r"\]" # fmt: off @sgl.function @@ -18,8 +18,8 @@ def json_warm_up(s): s += "The information about Hogwarts is in the following JSON format.\n" with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n" - s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n" + s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n" s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n" @@ -35,8 +35,8 @@ def json_decode(s, document): s += "Here is the name, country, and symbol of the city in JSON format.\n" with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STRING + ",") + "\n" - s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STRING + ",") + "\n" + s += ' "name": ' + sgl.gen("name", max_tokens=8, regex=REGEX_STR + ",") + "\n" + s += ' "country": ' + sgl.gen("country", max_tokens=8, regex=REGEX_STR + ",") + "\n" s += ' "latitude": ' + sgl.gen("latitude", max_tokens=8, regex=REGEX_FLOAT + ",") + "\n" s += ' "population": ' + sgl.gen("population", max_tokens=8, regex=REGEX_INT + ",") + "\n" s += ' "top 3 landmarks": ' + sgl.gen( "landmarks", max_tokens=24, regex=REGEX_LIST) + "\n" diff --git a/python/sglang/api.py b/python/sglang/api.py index 5a177c36b..2242b4a4c 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -72,7 +72,7 @@ def gen( logprob_start_len: Optional[int] = None, top_logprobs_num: Optional[int] = None, return_text_in_logprobs: Optional[bool] = None, - dtype: Optional[type] = None, + dtype: Optional[Union[type, str]] = None, choices: Optional[List[str]] = None, choices_method: Optional[ChoicesSamplingMethod] = None, regex: Optional[str] = None, diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index e500d30d1..dd86747e3 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -195,7 +195,7 @@ def extend(reqs, model_runner): token_to_kv_pool=model_runner.token_to_kv_pool, tree_cache=None, ) - batch.prepare_for_extend(model_runner.model_config.vocab_size, None) + batch.prepare_for_extend(model_runner.model_config.vocab_size) output = model_runner.forward(batch, ForwardMode.EXTEND) next_token_ids = batch.sample(output.next_token_logits) return next_token_ids, output.next_token_logits, batch diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 7f0db5b35..5012f646e 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,21 +1,23 @@ import json +import warnings from typing import List, Optional from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template_by_model_path -from sglang.lang.choices import ( - ChoicesDecision, - ChoicesSamplingMethod, - token_length_normalized, -) +from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod from sglang.lang.interpreter import StreamExecutor -from sglang.lang.ir import SglSamplingParams +from sglang.lang.ir import ( + REGEX_BOOL, + REGEX_FLOAT, + REGEX_INT, + REGEX_STR, + SglSamplingParams, +) from sglang.utils import http_request class RuntimeEndpoint(BaseBackend): - def __init__( self, base_url: str, @@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend): ) self._assert_success(res) + def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams): + if sampling_params.dtype is None: + return + + if sampling_params.stop == (): + sampling_params.stop = [] + + dtype_regex = None + if sampling_params.dtype in ["int", int]: + + dtype_regex = REGEX_INT + sampling_params.stop.extend([" ", "\n"]) + elif sampling_params.dtype in ["float", float]: + + dtype_regex = REGEX_FLOAT + sampling_params.stop.extend([" ", "\n"]) + elif sampling_params.dtype in ["str", str]: + + dtype_regex = REGEX_STR + elif sampling_params.dtype in ["bool", bool]: + + dtype_regex = REGEX_BOOL + else: + raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + + if dtype_regex is not None and sampling_params.regex is not None: + warnings.warn( + f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}" + ) + + sampling_params.regex = dtype_regex + def generate( self, s: StreamExecutor, sampling_params: SglSamplingParams, ): - if sampling_params.dtype is None: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - **sampling_params.to_srt_kwargs(), - }, - } - elif sampling_params.dtype in [int, "int"]: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - "dtype": "int", - **sampling_params.to_srt_kwargs(), - }, - } - else: - raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + self._handle_dtype_to_regex(sampling_params) + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, + **sampling_params.to_srt_kwargs(), + }, + } for item in [ "return_logprob", @@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend): s: StreamExecutor, sampling_params: SglSamplingParams, ): - if sampling_params.dtype is None: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - **sampling_params.to_srt_kwargs(), - }, - } - elif sampling_params.dtype in [int, "int"]: - data = { - "text": s.text_, - "sampling_params": { - "skip_special_tokens": global_config.skip_special_tokens_in_output, - "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, - "dtype": "int", - **sampling_params.to_srt_kwargs(), - }, - } - else: - raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + self._handle_dtype_to_regex(sampling_params) + + data = { + "text": s.text_, + "sampling_params": { + "skip_special_tokens": global_config.skip_special_tokens_in_output, + "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out, + **sampling_params.to_srt_kwargs(), + }, + } for item in [ "return_logprob", diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index 135110c1e..0166b8687 100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -8,10 +8,10 @@ from typing import List, Optional, Union from sglang.global_config import global_config from sglang.lang.choices import ChoicesSamplingMethod -REGEX_INT = r"[-+]?[0-9]+" -REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+" +REGEX_INT = r"[-+]?[0-9]+[ \n]*" +REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*" REGEX_BOOL = r"(True|False)" -REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg +REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg @dataclasses.dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index a461fa181..9037f5a6e 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -383,7 +383,7 @@ class ScheduleBatch: return out_cache_loc - def batch_sampling_params(self, vocab_size, int_token_logit_bias): + def batch_sampling_params(self, vocab_size): device = "cuda" bs, reqs = self.batch_size(), self.reqs self.temperatures = torch.tensor( @@ -419,15 +419,8 @@ class ScheduleBatch: # Handle logit bias but only allocate when needed self.logit_bias = None - for i in range(bs): - if reqs[i].sampling_params.dtype == "int": - if self.logit_bias is None: - self.logit_bias = torch.zeros( - (bs, vocab_size), dtype=torch.float32, device=device - ) - self.logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias - def prepare_for_extend(self, vocab_size: int, int_token_logit_bias: torch.Tensor): + def prepare_for_extend(self, vocab_size: int): bs = self.batch_size() reqs = self.reqs input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] @@ -466,7 +459,7 @@ class ScheduleBatch: self.out_cache_loc = out_cache_loc self.top_logprobs_nums = [r.top_logprobs_num for r in reqs] - self.batch_sampling_params(vocab_size, int_token_logit_bias) + self.batch_sampling_params(vocab_size) def check_decode_mem(self): bs = self.batch_size() diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index a8b952361..4d869c591 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -54,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import ( - get_int_token_logit_bias, is_multimodal_model, set_random_seed, suppress_other_loggers, @@ -132,9 +131,6 @@ class ModelTpServer: ), self.model_runner.req_to_token_pool.size - 1, ) - self.int_token_logit_bias = torch.tensor( - get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) - ) self.max_req_input_len = min( self.model_config.context_len - 1, self.max_total_num_tokens - 1, @@ -442,9 +438,7 @@ class ModelTpServer: def forward_prefill_batch(self, batch: ScheduleBatch): # Build batch tensors - batch.prepare_for_extend( - self.model_config.vocab_size, self.int_token_logit_bias - ) + batch.prepare_for_extend(self.model_config.vocab_size) if self.model_runner.is_generation: # Forward and sample the next tokens diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 45de85d87..d11f6c951 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) diff --git a/python/sglang/srt/sampling_params.py b/python/sglang/srt/sampling_params.py index 29067dc85..6a8823cc4 100644 --- a/python/sglang/srt/sampling_params.py +++ b/python/sglang/srt/sampling_params.py @@ -36,7 +36,6 @@ class SamplingParams: ignore_eos: bool = False, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - dtype: Optional[str] = None, regex: Optional[str] = None, n: int = 1, ) -> None: @@ -53,7 +52,6 @@ class SamplingParams: self.ignore_eos = ignore_eos self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens - self.dtype = dtype self.regex = regex self.n = n @@ -63,8 +61,6 @@ class SamplingParams: self.top_k = 1 if self.top_k == -1: self.top_k = 1 << 30 # whole vocabulary - if self.dtype == "int": - self.stop_strs = [" ", "\n"] def verify(self): if self.temperature < 0.0: diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 7c7c9bdcb..6e39f0aa9 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -103,13 +103,13 @@ def test_decode_int(): def test_decode_json_regex(): @sgl.function def decode_json(s): - from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING + from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR s += "Generate a JSON object to describe the basic city information of Paris.\n" with s.var_scope("json_output"): s += "{\n" - s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n" + s += ' "name": ' + sgl.gen(regex=REGEX_STR + ",") + "\n" s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n" s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n" @@ -359,6 +359,30 @@ def test_regex(): assert re.match(regex, answer) +def test_dtype_gen(): + @sgl.function + def dtype_gen(s): + s += "Q: What is the full name of DNS?\n" + s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n" + s += "Q: Which year was DNS invented?\n" + s += "A: " + sgl.gen("int_res", dtype=int) + "\n" + s += "Q: What is the value of pi?\n" + s += "A: " + sgl.gen("float_res", dtype=float) + "\n" + s += "Q: Is the sky blue?\n" + s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n" + + state = dtype_gen.run() + + try: + state["int_res"] = int(state["int_res"]) + state["float_res"] = float(state["float_res"]) + state["bool_res"] = bool(state["bool_res"]) + # assert state["str_res"].startswith('"') and state["str_res"].endswith('"') + except ValueError: + print(state) + raise + + def test_completion_speculative(): @sgl.function(num_api_spec_tokens=64) def gen_character_spec(s): diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index b2a07ae36..fcd86ae3d 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,10 +1,10 @@ -import json import unittest import sglang as sgl from sglang.test.test_programs import ( test_decode_int, test_decode_json_regex, + test_dtype_gen, test_expert_answer, test_few_shot_qa, test_mt_bench, @@ -59,6 +59,9 @@ class TestSRTBackend(unittest.TestCase): def test_regex(self): test_regex() + def test_dtype_gen(self): + test_dtype_gen() + if __name__ == "__main__": unittest.main()