diff --git a/README.md b/README.md index 8f26d901b..5861c7a1c 100644 --- a/README.md +++ b/README.md @@ -279,8 +279,8 @@ for out in state.text_iter(): ``` ### Tips and Implementation Details -- The `choices` argument in `sgl.gen` is implemented by computing the normalized log probabilities of all choices and selecting the one with the highest probability. -- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. +- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability. +- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`. ## Backend: SGLang Runtime (SRT) The SGLang Runtime (SRT) is designed to work best with the SGLang frontend. @@ -337,7 +337,6 @@ response = client.chat.completions.create( print(response) ``` - By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3. If needed, you can also override the chat template when launching the server: @@ -384,9 +383,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port - Llama - Mistral - Mixtral -- Qwen / Qwen 2 -- Gemma - - Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors. 
+- Qwen / Qwen 2 / Qwen 2 MoE +- Gemma / Gemma 2 - `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32` - LLaVA - `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000` @@ -399,6 +397,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port - StableLM - Command-R - DBRX +- Grok +- ChatGLM - AWQ/GPTQ/Marlin quantization Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md). diff --git a/examples/usage/cot_decoding.py b/examples/usage/cot_decoding.py new file mode 100644 index 000000000..d81a813c8 --- /dev/null +++ b/examples/usage/cot_decoding.py @@ -0,0 +1,121 @@ +from math import exp +from pprint import pformat + +import sglang as sgl + +YELLOW = "\033[1;33m" +GREEN = "\033[1;32m" +BLUE = "\033[1;34m" +CLEAR = "\033[1;0m" + + +@sgl.function +def cot_decoding(s, question, get_top_k, is_chat_model, verbose): + """CoT Decoding: http://arxiv.org/abs/2402.10200""" + + if is_chat_model: + s += sgl.user("Question: " + question + "\nAnswer:") + s += sgl.assistant_begin() + else: + s += "Question: " + question + "\nAnswer:" + + step_0 = s.fork(1)[0] + forks = s.fork(get_top_k) + answer_forks = s.fork(get_top_k) + + # decoding step 0 + step_0 += sgl.gen( + "get_top_k", + max_tokens=0, + return_logprob=True, + top_logprobs_num=get_top_k, + return_text_in_logprobs=True, + ) + logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0] + + print("Decoding step 0:", + ", ".join(pformat(token[2]) for token in logprobs)) + for idx, (f, token) in enumerate(zip(forks, logprobs)): + logprob, token_id, text = token + f += text + + if text == "<|end_of_text|>": + print( + f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}" + ) + continue + + # continue greedy decoding + f += sgl.gen( + 
"answer", + temperature=0, + max_tokens=1024, + return_logprob=True, + top_logprobs_num=2, + return_text_in_logprobs=True, + ) + + # calculate probability disparity between the top and secondary tokens + x1s = [ + exp(xt[0][0]) + for xt in f.get_meta_info("answer")["decode_top_logprobs"] + ] + x2s = [ + exp(xt[1][0]) + for xt in f.get_meta_info("answer")["decode_top_logprobs"] + ] + tokens = [ + xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"] + ] + delta = (sum(x1s) - sum(x2s)) / len(x1s) + + # extract the answer span (without the '<|end_of_text|>' token) + answer_forks[idx] += text + f["answer"] + "\nSo the answer is" + answer_forks[idx] += sgl.gen( + "answer_span", + temperature=0, + max_tokens=64, + return_logprob=True, + top_logprobs_num=2, + return_text_in_logprobs=True, + ) + answer = answer_forks[idx]['answer_span'].replace('\n', ' ').strip(':') + print( + f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}" + ) + generated_text = str(answer_forks[idx])[len("ProgramState("):-1] + print(f"{BLUE}{pformat(generated_text)}{CLEAR}") + + if verbose: + answer_tokens = [ + xt[0][2] for xt in answer_forks[idx].get_meta_info( + "answer_span")["decode_top_logprobs"] + ] + answer_x1s = [ + exp(xt[0][0]) for xt in answer_forks[idx].get_meta_info( + "answer_span")["decode_top_logprobs"] + ] + answer_x2s = [ + exp(xt[1][0]) for xt in answer_forks[idx].get_meta_info( + "answer_span")["decode_top_logprobs"] + ] + + for token, x1, x2 in zip(tokens, x1s, x2s): + print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", + end="") + print("\n===========") + for token, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s): + print(f" {GREEN}{pformat(token)}{CLEAR}({x1:.3f}-{x2:.3f})", + end="") + print() + + +sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) + +state = cot_decoding.run( + question= + r"Claire makes a 3 egg omelet every morning for breakfast. 
How many dozens of eggs will she eat in 4 weeks?", + get_top_k=10, + is_chat_model=True, + verbose=False, +) diff --git a/python/sglang/api.py b/python/sglang/api.py index d21024520..043893568 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -67,10 +67,16 @@ def gen( frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, dtype: Optional[type] = None, choices: Optional[List[str]] = None, regex: Optional[str] = None, ): + """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md""" + if choices: return SglSelect(name, choices, 0.0 if temperature is None else temperature) @@ -91,6 +97,10 @@ def gen( frequency_penalty, presence_penalty, ignore_eos, + return_logprob, + logprob_start_len, + top_logprobs_num, + return_text_in_logprobs, dtype, regex, ) @@ -106,6 +116,10 @@ def gen_int( frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, ): return SglGen( name, @@ -117,6 +131,10 @@ def gen_int( frequency_penalty, presence_penalty, ignore_eos, + return_logprob, + logprob_start_len, + top_logprobs_num, + return_text_in_logprobs, int, None, ) @@ -132,6 +150,10 @@ def gen_string( frequency_penalty: Optional[float] = None, presence_penalty: Optional[float] = None, ignore_eos: Optional[bool] = None, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, ): return SglGen( name, @@ -143,6 +165,10 @@ def gen_string( frequency_penalty, 
presence_penalty, ignore_eos, + return_logprob, + logprob_start_len, + top_logprobs_num, + return_text_in_logprobs, str, None, ) diff --git a/python/sglang/backend/runtime_endpoint.py b/python/sglang/backend/runtime_endpoint.py index 97812941d..da27a57e9 100644 --- a/python/sglang/backend/runtime_endpoint.py +++ b/python/sglang/backend/runtime_endpoint.py @@ -12,6 +12,7 @@ from sglang.utils import http_request class RuntimeEndpoint(BaseBackend): + def __init__( self, base_url: str, @@ -37,8 +38,7 @@ class RuntimeEndpoint(BaseBackend): self.model_info = res.json() self.chat_template = get_chat_template_by_model_path( - self.model_info["model_path"] - ) + self.model_info["model_path"]) def get_model_name(self): return self.model_info["model_path"] @@ -124,6 +124,11 @@ class RuntimeEndpoint(BaseBackend): else: raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]: + value = getattr(sampling_params, item, None) + if value is not None: + data[item] = value + self._add_images(s, data) res = http_request( @@ -166,6 +171,11 @@ class RuntimeEndpoint(BaseBackend): else: raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}") + for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]: + value = getattr(sampling_params, item, None) + if value is not None: + data[item] = value + data["stream"] = True self._add_images(s, data) diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 36418a6cc..31999c400 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -668,6 +668,10 @@ class StreamExecutor: "frequency_penalty", "presence_penalty", "ignore_eos", + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + "return_text_in_logprobs", "dtype", "regex", ]: diff --git a/python/sglang/lang/ir.py b/python/sglang/lang/ir.py index ad2e9fb2b..83c6f79b0 
100644 --- a/python/sglang/lang/ir.py +++ b/python/sglang/lang/ir.py @@ -23,6 +23,10 @@ class SglSamplingParams: frequency_penalty: float = 0.0 presence_penalty: float = 0.0 ignore_eos: bool = False + return_logprob: Optional[bool] = None + logprob_start_len: Optional[int] = None + top_logprobs_num: Optional[int] = None + return_text_in_logprobs: Optional[bool] = None # for constrained generation, not included in to_xxx_kwargs dtype: Optional[str] = None @@ -37,6 +41,11 @@ self.top_k, self.frequency_penalty, self.presence_penalty, + self.ignore_eos, + self.return_logprob, + self.logprob_start_len, + self.top_logprobs_num, + self.return_text_in_logprobs, ) def to_openai_kwargs(self): @@ -139,6 +148,10 @@ class SglFunction: frequency_penalty: float = 0.0, presence_penalty: float = 0.0, ignore_eos: bool = False, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, stream: bool = False, backend=None, **kwargs, @@ -154,6 +167,10 @@ class SglFunction: frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + return_text_in_logprobs=return_text_in_logprobs, ) backend = backend or global_config.default_backend return run_program(self, backend, args, kwargs, default_sampling_para, stream) @@ -170,6 +187,10 @@ class SglFunction: frequency_penalty: float = 0.0, presence_penalty: float = 0.0, ignore_eos: bool = False, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, backend=None, num_threads: Union[str, int] = "auto", progress_bar: bool = False, @@ -203,6 +224,10 @@ class SglFunction: frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, 
ignore_eos=ignore_eos, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + return_text_in_logprobs=return_text_in_logprobs, ) backend = backend or global_config.default_backend return run_program_batch( @@ -350,7 +375,7 @@ class SglArgument(SglExpr): class SglImage(SglExpr): - def __init__(self, path): + def __init__(self, path: str): self.path = path def __repr__(self) -> str: @@ -358,7 +383,7 @@ class SglImage(SglExpr): class SglVideo(SglExpr): - def __init__(self, path, num_frames): + def __init__(self, path: str, num_frames: int): self.path = path self.num_frames = num_frames @@ -369,18 +394,23 @@ class SglVideo(SglExpr): class SglGen(SglExpr): def __init__( self, - name, - max_new_tokens, - stop, - temperature, - top_p, - top_k, - frequency_penalty, - presence_penalty, - ignore_eos, - dtype, - regex, + name: Optional[str] = None, + max_new_tokens: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + ignore_eos: Optional[bool] = None, + return_logprob: Optional[bool] = None, + logprob_start_len: Optional[int] = None, + top_logprobs_num: Optional[int] = None, + return_text_in_logprobs: Optional[bool] = None, + dtype: Optional[type] = None, + regex: Optional[str] = None, ): + """Call the model to generate. 
See the meaning of the arguments in docs/sampling_params.md""" super().__init__() self.name = name self.sampling_params = SglSamplingParams( @@ -392,6 +422,10 @@ class SglGen(SglExpr): frequency_penalty=frequency_penalty, presence_penalty=presence_penalty, ignore_eos=ignore_eos, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + return_text_in_logprobs=return_text_in_logprobs, dtype=dtype, regex=regex, ) @@ -401,7 +435,7 @@ class SglGen(SglExpr): class SglConstantText(SglExpr): - def __init__(self, value): + def __init__(self, value: str): super().__init__() self.value = value @@ -410,7 +444,7 @@ class SglConstantText(SglExpr): class SglRoleBegin(SglExpr): - def __init__(self, role): + def __init__(self, role: str): super().__init__() self.role = role @@ -419,7 +453,7 @@ class SglRoleBegin(SglExpr): class SglRoleEnd(SglExpr): - def __init__(self, role): + def __init__(self, role: str): super().__init__() self.role = role @@ -428,7 +462,7 @@ class SglRoleEnd(SglExpr): class SglSelect(SglExpr): - def __init__(self, name, choices, temperature): + def __init__(self, name: str, choices: List[str], temperature: float): super().__init__() self.name = name self.choices = choices @@ -439,7 +473,7 @@ class SglSelect(SglExpr): class SglFork(SglExpr): - def __init__(self, number, position_ids_offset=None): + def __init__(self, number: int, position_ids_offset=None): super().__init__() self.number = number self.position_ids_offset = position_ids_offset @@ -452,7 +486,7 @@ class SglFork(SglExpr): class SglGetForkItem(SglExpr): - def __init__(self, index): + def __init__(self, index: int): super().__init__() self.index = index @@ -461,7 +495,7 @@ class SglGetForkItem(SglExpr): class SglVariable(SglExpr): - def __init__(self, name, source): + def __init__(self, name: str, source): super().__init__() self.name = name self.source = source @@ -471,7 +505,7 @@ class SglVariable(SglExpr): class SglVarScopeBegin(SglExpr): - def 
__init__(self, name): + def __init__(self, name: str): super().__init__() self.name = name @@ -480,7 +514,7 @@ class SglVarScopeBegin(SglExpr): class SglVarScopeEnd(SglExpr): - def __init__(self, name): + def __init__(self, name: str): super().__init__() self.name = name @@ -502,4 +536,4 @@ class SglCommitLazy(SglExpr): super().__init__() def __repr__(self): - return f"CommitLazy()" + return "CommitLazy()" diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 63cecdca3..bd5012904 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -333,17 +333,18 @@ class TokenizerManager: ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens( ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs ) - if top_logprobs_num > 0: - ret["meta_info"][ - "prefill_top_logprobs" - ] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs - ) - ret["meta_info"][ - "decode_top_logprobs" - ] = self.detokenize_top_logprobs_tokens( - ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs - ) + + if top_logprobs_num > 0: + ret["meta_info"][ + "prefill_top_logprobs" + ] = self.detokenize_top_logprobs_tokens( + ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs + ) + ret["meta_info"][ + "decode_top_logprobs" + ] = self.detokenize_top_logprobs_tokens( + ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ) return ret def detokenize_logprob_tokens(self, token_logprobs, decode_to_text): @@ -383,7 +384,7 @@ def get_pixel_values( try: processor = processor or global_processor image, image_size = load_image(image_data) - if image_size != None: + if image_size is not None: image_hash = hash(image_data) pixel_values = processor.image_processor(image)["pixel_values"] for _ in range(len(pixel_values)):