[Feat] Expose logprob options to sgl.gen API (#503)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
This commit is contained in:
12
README.md
12
README.md
@@ -279,8 +279,8 @@ for out in state.text_iter():
|
|||||||
```
|
```
|
||||||
|
|
||||||
### Tips and Implementation Details
|
### Tips and Implementation Details
|
||||||
- The `choices` argument in `sgl.gen` is implemented by computing the normalized log probabilities of all choices and selecting the one with the highest probability.
|
- The `choices` argument in `sgl.gen` is implemented by computing the [token-length normalized log probabilities](https://blog.eleuther.ai/multiple-choice-normalization/) of all choices and selecting the one with the highest probability.
|
||||||
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex.
|
- The `regex` argument in `sgl.gen` is implemented through autoregressive decoding with logit bias masking, according to the constraints set by the regex. It is compatible with `temperature=0` and `temperature != 0`.
|
||||||
|
|
||||||
## Backend: SGLang Runtime (SRT)
|
## Backend: SGLang Runtime (SRT)
|
||||||
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
|
The SGLang Runtime (SRT) is designed to work best with the SGLang frontend.
|
||||||
@@ -337,7 +337,6 @@ response = client.chat.completions.create(
|
|||||||
print(response)
|
print(response)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3.
|
By default, the server uses the chat template specified in the model tokenizer from Hugging Face. It should just work for most official models such as Llama-2/Llama-3.
|
||||||
|
|
||||||
If needed, you can also override the chat template when launching the server:
|
If needed, you can also override the chat template when launching the server:
|
||||||
@@ -384,9 +383,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
- Llama
|
- Llama
|
||||||
- Mistral
|
- Mistral
|
||||||
- Mixtral
|
- Mixtral
|
||||||
- Qwen / Qwen 2
|
- Qwen / Qwen 2 / Qwen 2 MoE
|
||||||
- Gemma
|
- Gemma / Gemma 2
|
||||||
- Please add a new flag `--attention-reduce-in-fp32` to avoid some precision errors.
|
|
||||||
- `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
|
- `python -m sglang.launch_server --model-path google/gemma-7b-it --port 30000 --attention-reduce-in-fp32`
|
||||||
- LLaVA
|
- LLaVA
|
||||||
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
|
- `python3 -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
|
||||||
@@ -399,6 +397,8 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
|
|||||||
- StableLM
|
- StableLM
|
||||||
- Command-R
|
- Command-R
|
||||||
- DBRX
|
- DBRX
|
||||||
|
- Grok
|
||||||
|
- ChatGLM
|
||||||
- AWQ/GPTQ/Marlin quantization
|
- AWQ/GPTQ/Marlin quantization
|
||||||
|
|
||||||
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
|
Instructions for supporting a new model are [here](https://github.com/sgl-project/sglang/blob/main/docs/model_support.md).
|
||||||
|
|||||||
121
examples/usage/cot_decoding.py
Normal file
121
examples/usage/cot_decoding.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
from math import exp
|
||||||
|
from pprint import pformat
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
YELLOW = "\033[1;33m"
|
||||||
|
GREEN = "\033[1;32m"
|
||||||
|
BLUE = "\033[1;34m"
|
||||||
|
CLEAR = "\033[1;0m"
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
    """CoT Decoding: http://arxiv.org/abs/2402.10200

    Branch on the top-k candidates for the first answer token, greedily
    decode each branch, ask the model for a short answer span, and print a
    confidence score for every path.

    Args:
        s: Prompt state that text and generation calls are appended to.
        question: Question text inserted into the prompt.
        get_top_k: Number of first-token candidates (paths) to explore.
        is_chat_model: If true, wrap the prompt in user/assistant roles;
            otherwise append it as plain text.
        verbose: If true, also print the top-1/top-2 probabilities of every
            decoded token for each path.
    """

    if is_chat_model:
        s += sgl.user("Question: " + question + "\nAnswer:")
        s += sgl.assistant_begin()
    else:
        s += "Question: " + question + "\nAnswer:"

    step_0 = s.fork(1)[0]
    forks = s.fork(get_top_k)
    answer_forks = s.fork(get_top_k)

    # decoding step 0: request no new tokens, only the top-k logprobs
    step_0 += sgl.gen(
        "get_top_k",
        max_tokens=0,
        return_logprob=True,
        top_logprobs_num=get_top_k,
        return_text_in_logprobs=True,
    )
    logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]

    print("Decoding step 0:",
          ", ".join(pformat(token[2]) for token in logprobs))
    for idx, (f, token) in enumerate(zip(forks, logprobs)):
        # each entry is (logprob, token id, token text); the id is unused
        logprob, _token_id, text = token
        f += text

        if text == "<|end_of_text|>":
            print(
                f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score=nan, answer=nan){CLEAR}"
            )
            continue

        # continue greedy decoding
        f += sgl.gen(
            "answer",
            temperature=0,
            max_tokens=1024,
            return_logprob=True,
            top_logprobs_num=2,
            return_text_in_logprobs=True,
        )

        # calculate probability disparity between the top and secondary tokens
        answer_top_logprobs = f.get_meta_info("answer")["decode_top_logprobs"]
        x1s = [exp(xt[0][0]) for xt in answer_top_logprobs]
        x2s = [exp(xt[1][0]) for xt in answer_top_logprobs]
        tokens = [xt[0][2] for xt in answer_top_logprobs]
        # An empty decode has no disparity to average; report nan (as the
        # end-of-text path above does) instead of dividing by zero.
        delta = (sum(x1s) - sum(x2s)) / len(x1s) if x1s else float("nan")

        # extract the answer span (without the '<|end_of_text|>' token)
        answer_forks[idx] += text + f["answer"] + "\nSo the answer is"
        answer_forks[idx] += sgl.gen(
            "answer_span",
            temperature=0,
            max_tokens=64,
            return_logprob=True,
            top_logprobs_num=2,
            return_text_in_logprobs=True,
        )
        answer = answer_forks[idx]['answer_span'].replace('\n', ' ').strip(':')
        print(
            f"{YELLOW}Path #{idx} {pformat(text)}[{exp(logprob):.3f}] (score={delta}, answer={answer}){CLEAR}"
        )
        # strip the "ProgramState(" prefix and ")" suffix from the state's str()
        generated_text = str(answer_forks[idx])[len("ProgramState("):-1]
        print(f"{BLUE}{pformat(generated_text)}{CLEAR}")

        if verbose:
            span_top_logprobs = answer_forks[idx].get_meta_info(
                "answer_span")["decode_top_logprobs"]
            answer_tokens = [xt[0][2] for xt in span_top_logprobs]
            answer_x1s = [exp(xt[0][0]) for xt in span_top_logprobs]
            answer_x2s = [exp(xt[1][0]) for xt in span_top_logprobs]

            # distinct loop names so the outer `token` is not rebound
            for tok, x1, x2 in zip(tokens, x1s, x2s):
                print(f" {GREEN}{pformat(tok)}{CLEAR}({x1:.3f}-{x2:.3f})",
                      end="")
            print("\n===========")
            for tok, x1, x2 in zip(answer_tokens, answer_x1s, answer_x2s):
                print(f" {GREEN}{pformat(tok)}{CLEAR}({x1:.3f}-{x2:.3f})",
                      end="")
            print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the example only when executed as a script, so importing this
    # module does not contact the server. The backend URL below must point
    # at a reachable SGLang runtime endpoint.
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    state = cot_decoding.run(
        question=
        r"Claire makes a 3 egg omelet every morning for breakfast. How many dozens of eggs will she eat in 4 weeks?",
        get_top_k=10,
        is_chat_model=True,
        verbose=False,
    )
|
||||||
@@ -67,10 +67,16 @@ def gen(
|
|||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
|
return_logprob: Optional[bool] = None,
|
||||||
|
logprob_start_len: Optional[int] = None,
|
||||||
|
top_logprobs_num: Optional[int] = None,
|
||||||
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
dtype: Optional[type] = None,
|
dtype: Optional[type] = None,
|
||||||
choices: Optional[List[str]] = None,
|
choices: Optional[List[str]] = None,
|
||||||
regex: Optional[str] = None,
|
regex: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
|
||||||
|
|
||||||
if choices:
|
if choices:
|
||||||
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
|
return SglSelect(name, choices, 0.0 if temperature is None else temperature)
|
||||||
|
|
||||||
@@ -91,6 +97,10 @@ def gen(
|
|||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
ignore_eos,
|
ignore_eos,
|
||||||
|
return_logprob,
|
||||||
|
logprob_start_len,
|
||||||
|
top_logprobs_num,
|
||||||
|
return_text_in_logprobs,
|
||||||
dtype,
|
dtype,
|
||||||
regex,
|
regex,
|
||||||
)
|
)
|
||||||
@@ -106,6 +116,10 @@ def gen_int(
|
|||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
|
return_logprob: Optional[bool] = None,
|
||||||
|
logprob_start_len: Optional[int] = None,
|
||||||
|
top_logprobs_num: Optional[int] = None,
|
||||||
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
return SglGen(
|
return SglGen(
|
||||||
name,
|
name,
|
||||||
@@ -117,6 +131,10 @@ def gen_int(
|
|||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
ignore_eos,
|
ignore_eos,
|
||||||
|
return_logprob,
|
||||||
|
logprob_start_len,
|
||||||
|
top_logprobs_num,
|
||||||
|
return_text_in_logprobs,
|
||||||
int,
|
int,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
@@ -132,6 +150,10 @@ def gen_string(
|
|||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty: Optional[float] = None,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos: Optional[bool] = None,
|
ignore_eos: Optional[bool] = None,
|
||||||
|
return_logprob: Optional[bool] = None,
|
||||||
|
logprob_start_len: Optional[int] = None,
|
||||||
|
top_logprobs_num: Optional[int] = None,
|
||||||
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
return SglGen(
|
return SglGen(
|
||||||
name,
|
name,
|
||||||
@@ -143,6 +165,10 @@ def gen_string(
|
|||||||
frequency_penalty,
|
frequency_penalty,
|
||||||
presence_penalty,
|
presence_penalty,
|
||||||
ignore_eos,
|
ignore_eos,
|
||||||
|
return_logprob,
|
||||||
|
logprob_start_len,
|
||||||
|
top_logprobs_num,
|
||||||
|
return_text_in_logprobs,
|
||||||
str,
|
str,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from sglang.utils import http_request
|
|||||||
|
|
||||||
|
|
||||||
class RuntimeEndpoint(BaseBackend):
|
class RuntimeEndpoint(BaseBackend):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
@@ -37,8 +38,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
self.model_info = res.json()
|
self.model_info = res.json()
|
||||||
|
|
||||||
self.chat_template = get_chat_template_by_model_path(
|
self.chat_template = get_chat_template_by_model_path(
|
||||||
self.model_info["model_path"]
|
self.model_info["model_path"])
|
||||||
)
|
|
||||||
|
|
||||||
def get_model_name(self):
|
def get_model_name(self):
|
||||||
return self.model_info["model_path"]
|
return self.model_info["model_path"]
|
||||||
@@ -124,6 +124,11 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||||
|
|
||||||
|
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
|
||||||
|
value = getattr(sampling_params, item, None)
|
||||||
|
if value is not None:
|
||||||
|
data[item] = value
|
||||||
|
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
|
|
||||||
res = http_request(
|
res = http_request(
|
||||||
@@ -166,6 +171,11 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
|
||||||
|
|
||||||
|
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
|
||||||
|
value = getattr(sampling_params, item, None)
|
||||||
|
if value is not None:
|
||||||
|
data[item] = value
|
||||||
|
|
||||||
data["stream"] = True
|
data["stream"] = True
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
|
|
||||||
|
|||||||
@@ -668,6 +668,10 @@ class StreamExecutor:
|
|||||||
"frequency_penalty",
|
"frequency_penalty",
|
||||||
"presence_penalty",
|
"presence_penalty",
|
||||||
"ignore_eos",
|
"ignore_eos",
|
||||||
|
"return_logprob",
|
||||||
|
"logprob_start_len",
|
||||||
|
"top_logprobs_num",
|
||||||
|
"return_text_in_logprobs",
|
||||||
"dtype",
|
"dtype",
|
||||||
"regex",
|
"regex",
|
||||||
]:
|
]:
|
||||||
|
|||||||
@@ -23,6 +23,10 @@ class SglSamplingParams:
|
|||||||
frequency_penalty: float = 0.0
|
frequency_penalty: float = 0.0
|
||||||
presence_penalty: float = 0.0
|
presence_penalty: float = 0.0
|
||||||
ignore_eos: bool = False
|
ignore_eos: bool = False
|
||||||
|
# Logprob options. They default to None (not set) so the backend forwards
# them only when a caller supplies a value. The defaults must not end in a
# trailing comma: `= None,` would make the default the tuple `(None,)`,
# which is not None and would be sent to the server as a value.
return_logprob: Optional[bool] = None
logprob_start_len: Optional[int] = None
top_logprobs_num: Optional[int] = None
return_text_in_logprobs: Optional[bool] = None
|
||||||
|
|
||||||
# for constrained generation, not included in to_xxx_kwargs
|
# for constrained generation, not included in to_xxx_kwargs
|
||||||
dtype: Optional[str] = None
|
dtype: Optional[str] = None
|
||||||
@@ -37,6 +41,11 @@ class SglSamplingParams:
|
|||||||
self.top_k,
|
self.top_k,
|
||||||
self.frequency_penalty,
|
self.frequency_penalty,
|
||||||
self.presence_penalty,
|
self.presence_penalty,
|
||||||
|
self.ignore_eos,
|
||||||
|
self.return_logprob,
|
||||||
|
self.logprob_start_len,
|
||||||
|
self.top_logprobs_num,
|
||||||
|
self.return_text_in_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def to_openai_kwargs(self):
|
def to_openai_kwargs(self):
|
||||||
@@ -139,6 +148,10 @@ class SglFunction:
|
|||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
|
return_logprob: Optional[bool] = None,
|
||||||
|
logprob_start_len: Optional[int] = None,
|
||||||
|
top_logprobs_num: Optional[int] = None,
|
||||||
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
backend=None,
|
backend=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -154,6 +167,10 @@ class SglFunction:
|
|||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
return_text_in_logprobs=return_text_in_logprobs,
|
||||||
)
|
)
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
return run_program(self, backend, args, kwargs, default_sampling_para, stream)
|
||||||
@@ -170,6 +187,10 @@ class SglFunction:
|
|||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
ignore_eos: bool = False,
|
ignore_eos: bool = False,
|
||||||
|
return_logprob: Optional[bool] = None,
|
||||||
|
logprob_start_len: Optional[int] = None,
|
||||||
|
top_logprobs_num: Optional[int] = None,
|
||||||
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
backend=None,
|
backend=None,
|
||||||
num_threads: Union[str, int] = "auto",
|
num_threads: Union[str, int] = "auto",
|
||||||
progress_bar: bool = False,
|
progress_bar: bool = False,
|
||||||
@@ -203,6 +224,10 @@ class SglFunction:
|
|||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
return_text_in_logprobs=return_text_in_logprobs,
|
||||||
)
|
)
|
||||||
backend = backend or global_config.default_backend
|
backend = backend or global_config.default_backend
|
||||||
return run_program_batch(
|
return run_program_batch(
|
||||||
@@ -350,7 +375,7 @@ class SglArgument(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglImage(SglExpr):
|
class SglImage(SglExpr):
|
||||||
def __init__(self, path):
|
def __init__(self, path: str):
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -358,7 +383,7 @@ class SglImage(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglVideo(SglExpr):
|
class SglVideo(SglExpr):
|
||||||
def __init__(self, path, num_frames):
|
def __init__(self, path: str, num_frames: int):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.num_frames = num_frames
|
self.num_frames = num_frames
|
||||||
|
|
||||||
@@ -369,18 +394,23 @@ class SglVideo(SglExpr):
|
|||||||
class SglGen(SglExpr):
|
class SglGen(SglExpr):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name,
|
name: Optional[str] = None,
|
||||||
max_new_tokens,
|
max_new_tokens: Optional[int] = None,
|
||||||
stop,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
temperature,
|
temperature: Optional[float] = None,
|
||||||
top_p,
|
top_p: Optional[float] = None,
|
||||||
top_k,
|
top_k: Optional[int] = None,
|
||||||
frequency_penalty,
|
frequency_penalty: Optional[float] = None,
|
||||||
presence_penalty,
|
presence_penalty: Optional[float] = None,
|
||||||
ignore_eos,
|
ignore_eos: Optional[bool] = None,
|
||||||
dtype,
|
return_logprob: Optional[bool] = None,
|
||||||
regex,
|
logprob_start_len: Optional[int] = None,
|
||||||
|
top_logprobs_num: Optional[int] = None,
|
||||||
|
return_text_in_logprobs: Optional[bool] = None,
|
||||||
|
dtype: Optional[type] = None,
|
||||||
|
regex: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
"""Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.sampling_params = SglSamplingParams(
|
self.sampling_params = SglSamplingParams(
|
||||||
@@ -392,6 +422,10 @@ class SglGen(SglExpr):
|
|||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
ignore_eos=ignore_eos,
|
ignore_eos=ignore_eos,
|
||||||
|
return_logprob=return_logprob,
|
||||||
|
logprob_start_len=logprob_start_len,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
return_text_in_logprobs=return_text_in_logprobs,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
regex=regex,
|
regex=regex,
|
||||||
)
|
)
|
||||||
@@ -401,7 +435,7 @@ class SglGen(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglConstantText(SglExpr):
|
class SglConstantText(SglExpr):
|
||||||
def __init__(self, value):
|
def __init__(self, value: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
@@ -410,7 +444,7 @@ class SglConstantText(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglRoleBegin(SglExpr):
|
class SglRoleBegin(SglExpr):
|
||||||
def __init__(self, role):
|
def __init__(self, role: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.role = role
|
self.role = role
|
||||||
|
|
||||||
@@ -419,7 +453,7 @@ class SglRoleBegin(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglRoleEnd(SglExpr):
|
class SglRoleEnd(SglExpr):
|
||||||
def __init__(self, role):
|
def __init__(self, role: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.role = role
|
self.role = role
|
||||||
|
|
||||||
@@ -428,7 +462,7 @@ class SglRoleEnd(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglSelect(SglExpr):
|
class SglSelect(SglExpr):
|
||||||
def __init__(self, name, choices, temperature):
|
def __init__(self, name: str, choices: List[str], temperature: float):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.choices = choices
|
self.choices = choices
|
||||||
@@ -439,7 +473,7 @@ class SglSelect(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglFork(SglExpr):
|
class SglFork(SglExpr):
|
||||||
def __init__(self, number, position_ids_offset=None):
|
def __init__(self, number: int, position_ids_offset=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.number = number
|
self.number = number
|
||||||
self.position_ids_offset = position_ids_offset
|
self.position_ids_offset = position_ids_offset
|
||||||
@@ -452,7 +486,7 @@ class SglFork(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglGetForkItem(SglExpr):
|
class SglGetForkItem(SglExpr):
|
||||||
def __init__(self, index):
|
def __init__(self, index: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.index = index
|
self.index = index
|
||||||
|
|
||||||
@@ -461,7 +495,7 @@ class SglGetForkItem(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglVariable(SglExpr):
|
class SglVariable(SglExpr):
|
||||||
def __init__(self, name, source):
|
def __init__(self, name: str, source):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
self.source = source
|
self.source = source
|
||||||
@@ -471,7 +505,7 @@ class SglVariable(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglVarScopeBegin(SglExpr):
|
class SglVarScopeBegin(SglExpr):
|
||||||
def __init__(self, name):
|
def __init__(self, name: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
@@ -480,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
|
|||||||
|
|
||||||
|
|
||||||
class SglVarScopeEnd(SglExpr):
|
class SglVarScopeEnd(SglExpr):
|
||||||
def __init__(self, name):
|
def __init__(self, name: str):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
@@ -502,4 +536,4 @@ class SglCommitLazy(SglExpr):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"CommitLazy()"
|
return "CommitLazy()"
|
||||||
|
|||||||
@@ -333,17 +333,18 @@ class TokenizerManager:
|
|||||||
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||||
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
if top_logprobs_num > 0:
|
|
||||||
ret["meta_info"][
|
if top_logprobs_num > 0:
|
||||||
"prefill_top_logprobs"
|
ret["meta_info"][
|
||||||
] = self.detokenize_top_logprobs_tokens(
|
"prefill_top_logprobs"
|
||||||
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
] = self.detokenize_top_logprobs_tokens(
|
||||||
)
|
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
|
||||||
ret["meta_info"][
|
)
|
||||||
"decode_top_logprobs"
|
ret["meta_info"][
|
||||||
] = self.detokenize_top_logprobs_tokens(
|
"decode_top_logprobs"
|
||||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
] = self.detokenize_top_logprobs_tokens(
|
||||||
)
|
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
||||||
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
|
||||||
@@ -383,7 +384,7 @@ def get_pixel_values(
|
|||||||
try:
|
try:
|
||||||
processor = processor or global_processor
|
processor = processor or global_processor
|
||||||
image, image_size = load_image(image_data)
|
image, image_size = load_image(image_data)
|
||||||
if image_size != None:
|
if image_size is not None:
|
||||||
image_hash = hash(image_data)
|
image_hash = hash(image_data)
|
||||||
pixel_values = processor.image_processor(image)["pixel_values"]
|
pixel_values = processor.image_processor(image)["pixel_values"]
|
||||||
for _ in range(len(pixel_values)):
|
for _ in range(len(pixel_values)):
|
||||||
|
|||||||
Reference in New Issue
Block a user