Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
This commit is contained in:
@@ -13,7 +13,7 @@ class GenerateReqInput:
|
|||||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||||
# See also python/sglang/srt/utils.py:load_image.
|
# See also python/sglang/srt/utils.py:load_image.
|
||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
# The sampling_params.
|
# The sampling_params. See descriptions below.
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
@@ -23,7 +23,7 @@ class GenerateReqInput:
|
|||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
# The number of top logprobs to return.
|
# The number of top logprobs to return.
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
# Whether to detokenize tokens in logprobs.
|
# Whether to detokenize the tokens into text in the returned logprobs.
|
||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output.
|
# Whether to stream output.
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
@@ -32,27 +32,28 @@ class GenerateReqInput:
|
|||||||
The `sampling_params` follows this format
|
The `sampling_params` follows this format
|
||||||
|
|
||||||
```python
|
```python
|
||||||
class SamplingParams:
|
# The maximum number of output tokens
|
||||||
def __init__(
|
max_new_tokens: int = 16,
|
||||||
self,
|
# Stop when hitting any of the strings in this list.
|
||||||
max_new_tokens: int = 16,
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
stop: Optional[Union[str, List[str]]] = None,
|
# Sampling temperature
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
# Top-p sampling
|
||||||
top_k: int = -1,
|
top_p: float = 1.0,
|
||||||
frequency_penalty: float = 0.0,
|
# Top-k sampling
|
||||||
presence_penalty: float = 0.0,
|
top_k: int = -1,
|
||||||
ignore_eos: bool = False,
|
# Whether to ignore EOS token.
|
||||||
skip_special_tokens: bool = True,
|
ignore_eos: bool = False,
|
||||||
dtype: Optional[str] = None,
|
# Whether to skip the special tokens during detokenization.
|
||||||
regex: Optional[str] = None,
|
skip_special_tokens: bool = True,
|
||||||
) -> None:
|
# Whether to add spaces between special tokens during detokenization.
|
||||||
|
spaces_between_special_tokens: bool = True,
|
||||||
|
# Constrains the output to follow a given regular expression.
|
||||||
|
regex: Optional[str] = None,
|
||||||
|
# Do parallel sampling and return `n` outputs.
|
||||||
|
n: int = 1,
|
||||||
```
|
```
|
||||||
|
|
||||||
- `max_new_tokens`, `stop`, `temperature`, `top_p`, `top_k` are common sampling parameters.
|
|
||||||
- `ignore_eos` means ignoring the EOS token and continue decoding, which is helpful for benchmarking purposes.
|
|
||||||
- `regex` constrains the output to follow a given regular expression.
|
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
### Normal
|
### Normal
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ def main():
|
|||||||
print("questions:", question)
|
print("questions:", question)
|
||||||
print("choice:", state["tool"])
|
print("choice:", state["tool"])
|
||||||
meta_info = state.get_meta_info("tool")
|
meta_info = state.get_meta_info("tool")
|
||||||
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
|
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
|
||||||
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
|
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
# Run a batch
|
# Run a batch
|
||||||
@@ -34,8 +34,8 @@ def main():
|
|||||||
print("questions:", question)
|
print("questions:", question)
|
||||||
print("choice:", state["tool"])
|
print("choice:", state["tool"])
|
||||||
meta_info = state.get_meta_info("tool")
|
meta_info = state.get_meta_info("tool")
|
||||||
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
|
print("logprobs of choice 1", meta_info["input_token_logprobs"][0])
|
||||||
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
|
print("logprobs of choice 2", meta_info["input_token_logprobs"][1])
|
||||||
print("-" * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
|||||||
top_logprobs_num=get_top_k,
|
top_logprobs_num=get_top_k,
|
||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
)
|
)
|
||||||
logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0]
|
logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0]
|
||||||
|
|
||||||
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
|
print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs))
|
||||||
for idx, (f, token) in enumerate(zip(forks, logprobs)):
|
for idx, (f, token) in enumerate(zip(forks, logprobs)):
|
||||||
@@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# calculate probability disparity between the top and secondary tokens
|
# calculate probability disparity between the top and secondary tokens
|
||||||
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||||
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||||
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]]
|
tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]]
|
||||||
delta = (sum(x1s) - sum(x2s)) / len(x1s)
|
delta = (sum(x1s) - sum(x2s)) / len(x1s)
|
||||||
|
|
||||||
# extract the answer span (without the '<|end_of_text|>' token)
|
# extract the answer span (without the '<|end_of_text|>' token)
|
||||||
@@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose):
|
|||||||
answer_tokens = [
|
answer_tokens = [
|
||||||
xt[0][2]
|
xt[0][2]
|
||||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||||
"decode_top_logprobs"
|
"output_top_logprobs"
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
answer_x1s = [
|
answer_x1s = [
|
||||||
exp(xt[0][0])
|
exp(xt[0][0])
|
||||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||||
"decode_top_logprobs"
|
"output_top_logprobs"
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
answer_x2s = [
|
answer_x2s = [
|
||||||
exp(xt[1][0])
|
exp(xt[1][0])
|
||||||
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
for xt in answer_forks[idx].get_meta_info("answer_span")[
|
||||||
"decode_top_logprobs"
|
"output_top_logprobs"
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -56,14 +56,14 @@ def srt_api_request(name):
|
|||||||
# fout.write(json.dumps(res, indent=4))
|
# fout.write(json.dumps(res, indent=4))
|
||||||
|
|
||||||
meta_info = res["meta_info"]
|
meta_info = res["meta_info"]
|
||||||
assert len(meta_info["prefill_token_logprobs"]) == len(
|
assert len(meta_info["input_token_logprobs"]) == len(
|
||||||
meta_info["prefill_top_logprobs"]
|
meta_info["input_top_logprobs"]
|
||||||
)
|
)
|
||||||
assert len(meta_info["decode_token_logprobs"]) == len(
|
assert len(meta_info["output_token_logprobs"]) == len(
|
||||||
meta_info["decode_top_logprobs"]
|
meta_info["output_top_logprobs"]
|
||||||
)
|
)
|
||||||
assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"]
|
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
|
||||||
assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@@ -72,11 +72,11 @@ def pretty_print(res):
|
|||||||
meta_info = res["meta_info"]
|
meta_info = res["meta_info"]
|
||||||
|
|
||||||
print("\n\n", "=" * 30, "Prefill", "=" * 30)
|
print("\n\n", "=" * 30, "Prefill", "=" * 30)
|
||||||
for i in range(len(meta_info["prefill_token_logprobs"])):
|
for i in range(len(meta_info["input_token_logprobs"])):
|
||||||
print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="")
|
print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="")
|
||||||
top_ks = (
|
top_ks = (
|
||||||
[str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]]
|
[str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]]
|
||||||
if meta_info["prefill_top_logprobs"][i]
|
if meta_info["input_top_logprobs"][i]
|
||||||
else []
|
else []
|
||||||
)
|
)
|
||||||
for top_k in top_ks:
|
for top_k in top_ks:
|
||||||
@@ -84,9 +84,9 @@ def pretty_print(res):
|
|||||||
print()
|
print()
|
||||||
|
|
||||||
print("\n\n", "=" * 30, "Decode", "=" * 30)
|
print("\n\n", "=" * 30, "Decode", "=" * 30)
|
||||||
for i in range(len(meta_info["decode_token_logprobs"])):
|
for i in range(len(meta_info["output_token_logprobs"])):
|
||||||
print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="")
|
print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="")
|
||||||
top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]]
|
top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]]
|
||||||
for top_k in top_ks:
|
for top_k in top_ks:
|
||||||
print(f"{top_k: <15}", end="")
|
print(f"{top_k: <15}", end="")
|
||||||
print()
|
print()
|
||||||
|
|||||||
@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||||
]
|
]
|
||||||
decision = choices[np.argmax(normalized_prompt_logprobs)]
|
decision = choices[np.argmax(normalized_prompt_logprobs)]
|
||||||
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
|
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
||||||
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
|
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
||||||
|
|
||||||
return (
|
return (
|
||||||
decision,
|
decision,
|
||||||
normalized_prompt_logprobs,
|
normalized_prompt_logprobs,
|
||||||
prefill_token_logprobs,
|
input_token_logprobs,
|
||||||
decode_token_logprobs,
|
output_token_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
||||||
|
|||||||
@@ -541,16 +541,16 @@ class StreamExecutor:
|
|||||||
(
|
(
|
||||||
decision,
|
decision,
|
||||||
normalized_prompt_logprobs,
|
normalized_prompt_logprobs,
|
||||||
prefill_token_logprobs,
|
input_token_logprobs,
|
||||||
decode_token_logprobs,
|
output_token_logprobs,
|
||||||
) = self.backend.select(self, expr.choices, expr.temperature)
|
) = self.backend.select(self, expr.choices, expr.temperature)
|
||||||
if expr.name is not None:
|
if expr.name is not None:
|
||||||
name = expr.name
|
name = expr.name
|
||||||
self.variables[name] = decision
|
self.variables[name] = decision
|
||||||
self.meta_info[name] = {
|
self.meta_info[name] = {
|
||||||
"normalized_prompt_logprobs": normalized_prompt_logprobs,
|
"normalized_prompt_logprobs": normalized_prompt_logprobs,
|
||||||
"prefill_token_logprobs": prefill_token_logprobs,
|
"input_token_logprobs": input_token_logprobs,
|
||||||
"decode_token_logprobs": decode_token_logprobs,
|
"output_token_logprobs": output_token_logprobs,
|
||||||
}
|
}
|
||||||
self.variable_event[name].set()
|
self.variable_event[name].set()
|
||||||
self.text_ += decision
|
self.text_ += decision
|
||||||
|
|||||||
@@ -22,13 +22,13 @@ class LogitProcessorOutput:
|
|||||||
|
|
||||||
# The normalized logprobs of prompts. shape: [#seq]
|
# The normalized logprobs of prompts. shape: [#seq]
|
||||||
normalized_prompt_logprobs: torch.Tensor
|
normalized_prompt_logprobs: torch.Tensor
|
||||||
# The logprobs of prefill tokens. shape: [#token]
|
# The logprobs of input tokens. shape: [#token]
|
||||||
prefill_token_logprobs: torch.Tensor
|
input_token_logprobs: torch.Tensor
|
||||||
|
|
||||||
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
prefill_top_logprobs: List
|
input_top_logprobs: List
|
||||||
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
|
||||||
decode_top_logprobs: List
|
output_top_logprobs: List
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
def _get_normalized_prompt_logprobs(
|
def _get_normalized_prompt_logprobs(
|
||||||
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
|
self, input_token_logprobs, logits_metadata: LogitsMetadata
|
||||||
):
|
):
|
||||||
logprobs_cumsum = torch.cumsum(
|
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
|
||||||
prefill_token_logprobs, dim=0, dtype=torch.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
start = logits_metadata.extend_start_loc.clone()
|
start = logits_metadata.extend_start_loc.clone()
|
||||||
end = start + logits_metadata.extend_seq_lens - 2
|
end = start + logits_metadata.extend_seq_lens - 2
|
||||||
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
||||||
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
|
||||||
sum_logp = (
|
sum_logp = (
|
||||||
logprobs_cumsum[end]
|
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
|
||||||
- logprobs_cumsum[start]
|
|
||||||
+ prefill_token_logprobs[start]
|
|
||||||
)
|
)
|
||||||
normalized_prompt_logprobs = sum_logp / (
|
normalized_prompt_logprobs = sum_logp / (
|
||||||
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
|
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||||
@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
|
|||||||
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
|
||||||
# TODO: vectorize the code below
|
# TODO: vectorize the code below
|
||||||
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
if logits_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
decode_top_logprobs = []
|
output_top_logprobs = []
|
||||||
for i in range(all_logprobs.shape[0]):
|
for i in range(all_logprobs.shape[0]):
|
||||||
k = logits_metadata.top_logprobs_nums[i]
|
k = logits_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[i].topk(k)
|
t = all_logprobs[i].topk(k)
|
||||||
v_cpu = t.values.tolist()
|
v_cpu = t.values.tolist()
|
||||||
p_cpu = t.indices.tolist()
|
p_cpu = t.indices.tolist()
|
||||||
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||||
return None, decode_top_logprobs
|
return None, output_top_logprobs
|
||||||
else:
|
else:
|
||||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
input_top_logprobs, output_top_logprobs = [], []
|
||||||
pt = 0
|
pt = 0
|
||||||
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
|
||||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||||
if extend_seq_len == 0:
|
if extend_seq_len == 0:
|
||||||
prefill_top_logprobs.append([])
|
input_top_logprobs.append([])
|
||||||
decode_top_logprobs.append([])
|
output_top_logprobs.append([])
|
||||||
continue
|
continue
|
||||||
k = logits_metadata.top_logprobs_nums[i]
|
k = logits_metadata.top_logprobs_nums[i]
|
||||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||||
vs_cpu = t.values.tolist()
|
vs_cpu = t.values.tolist()
|
||||||
ps_cpu = t.indices.tolist()
|
ps_cpu = t.indices.tolist()
|
||||||
prefill_top_logprobs.append(
|
input_top_logprobs.append(
|
||||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||||
)
|
)
|
||||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||||
pt += extend_seq_len
|
pt += extend_seq_len
|
||||||
|
|
||||||
return prefill_top_logprobs, decode_top_logprobs
|
return input_top_logprobs, output_top_logprobs
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
|
|||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=None,
|
next_token_logprobs=None,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
prefill_token_logprobs=None,
|
input_token_logprobs=None,
|
||||||
prefill_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
decode_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
|
|||||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||||
)
|
)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
decode_top_logprobs = self.get_top_logprobs(
|
output_top_logprobs = self.get_top_logprobs(
|
||||||
last_logprobs, logits_metadata
|
last_logprobs, logits_metadata
|
||||||
)[1]
|
)[1]
|
||||||
else:
|
else:
|
||||||
decode_top_logprobs = None
|
output_top_logprobs = None
|
||||||
|
|
||||||
return LogitProcessorOutput(
|
return LogitProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=last_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
prefill_token_logprobs=None,
|
input_token_logprobs=None,
|
||||||
prefill_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
decode_top_logprobs=decode_top_logprobs,
|
output_top_logprobs=output_top_logprobs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
all_logits = torch.matmul(hidden_states, weight.T)
|
all_logits = torch.matmul(hidden_states, weight.T)
|
||||||
@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
|
|||||||
x > 0 for x in logits_metadata.top_logprobs_nums
|
x > 0 for x in logits_metadata.top_logprobs_nums
|
||||||
)
|
)
|
||||||
if return_top_logprob:
|
if return_top_logprob:
|
||||||
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
|
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
|
||||||
all_logprobs, logits_metadata
|
all_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
prefill_top_logprobs = decode_top_logprobs = None
|
input_top_logprobs = output_top_logprobs = None
|
||||||
|
|
||||||
last_logprobs = all_logprobs[last_index]
|
last_logprobs = all_logprobs[last_index]
|
||||||
|
|
||||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||||
# Note that we pad a zero at the end of each sequence for easy computation.
|
# Note that we pad a zero at the end of each sequence for easy computation.
|
||||||
prefill_token_logprobs = all_logprobs[
|
input_token_logprobs = all_logprobs[
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||||
]
|
]
|
||||||
|
|
||||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
prefill_token_logprobs, logits_metadata
|
input_token_logprobs, logits_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
return LogitProcessorOutput(
|
return LogitProcessorOutput(
|
||||||
next_token_logits=last_logits,
|
next_token_logits=last_logits,
|
||||||
next_token_logprobs=last_logprobs,
|
next_token_logprobs=last_logprobs,
|
||||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||||
prefill_token_logprobs=prefill_token_logprobs,
|
input_token_logprobs=input_token_logprobs,
|
||||||
prefill_top_logprobs=prefill_top_logprobs,
|
input_top_logprobs=input_top_logprobs,
|
||||||
decode_top_logprobs=decode_top_logprobs,
|
output_top_logprobs=output_top_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -226,9 +226,9 @@ class CudaGraphRunner:
|
|||||||
next_token_logits=output.next_token_logits[:raw_bs],
|
next_token_logits=output.next_token_logits[:raw_bs],
|
||||||
next_token_logprobs=None,
|
next_token_logprobs=None,
|
||||||
normalized_prompt_logprobs=None,
|
normalized_prompt_logprobs=None,
|
||||||
prefill_token_logprobs=None,
|
input_token_logprobs=None,
|
||||||
prefill_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
decode_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Extract logprobs
|
# Extract logprobs
|
||||||
@@ -242,7 +242,7 @@ class CudaGraphRunner:
|
|||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
top_logprobs_nums=batch.top_logprobs_nums,
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
)
|
)
|
||||||
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
|
output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
||||||
output.next_token_logprobs, logits_metadata
|
output.next_token_logprobs, logits_metadata
|
||||||
)[1]
|
)[1]
|
||||||
|
|
||||||
|
|||||||
@@ -124,10 +124,10 @@ class Req:
|
|||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
self.top_logprobs_num = 0
|
self.top_logprobs_num = 0
|
||||||
self.normalized_prompt_logprob = None
|
self.normalized_prompt_logprob = None
|
||||||
self.prefill_token_logprobs = None
|
self.input_token_logprobs = None
|
||||||
self.prefill_top_logprobs = None
|
self.input_top_logprobs = None
|
||||||
self.decode_token_logprobs = []
|
self.output_token_logprobs = []
|
||||||
self.decode_top_logprobs = []
|
self.output_top_logprobs = []
|
||||||
# The tokens are prefilled but need to be considered as decode tokens
|
# The tokens are prefilled but need to be considered as decode tokens
|
||||||
# and should be updated for the decode logprobs
|
# and should be updated for the decode logprobs
|
||||||
self.last_update_decode_tokens = 0
|
self.last_update_decode_tokens = 0
|
||||||
@@ -244,8 +244,8 @@ class Req:
|
|||||||
k = k + 1
|
k = k + 1
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
self.decode_token_logprobs = self.decode_token_logprobs[:k]
|
self.output_token_logprobs = self.output_token_logprobs[:k]
|
||||||
self.decode_top_logprobs = self.decode_top_logprobs[:k]
|
self.output_top_logprobs = self.output_top_logprobs[:k]
|
||||||
self.logprob_start_len = prompt_tokens + k
|
self.logprob_start_len = prompt_tokens + k
|
||||||
self.last_update_decode_tokens = len(self.output_ids) - k
|
self.last_update_decode_tokens = len(self.output_ids) - k
|
||||||
|
|
||||||
|
|||||||
@@ -455,7 +455,7 @@ class ModelTpServer:
|
|||||||
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
torch.arange(len(next_token_ids), device=next_token_ids.device),
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
].tolist()
|
].tolist()
|
||||||
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
|
output.input_token_logprobs = output.input_token_logprobs.tolist()
|
||||||
output.normalized_prompt_logprobs = (
|
output.normalized_prompt_logprobs = (
|
||||||
output.normalized_prompt_logprobs.tolist()
|
output.normalized_prompt_logprobs.tolist()
|
||||||
)
|
)
|
||||||
@@ -481,24 +481,24 @@ class ModelTpServer:
|
|||||||
if req.normalized_prompt_logprob is None:
|
if req.normalized_prompt_logprob is None:
|
||||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||||
|
|
||||||
if req.prefill_token_logprobs is None:
|
if req.input_token_logprobs is None:
|
||||||
# If logprob_start_len > 0, then the first logprob_start_len prompt tokens will be ignored.
|
# If logprob_start_len > 0, then the first logprob_start_len prompt tokens will be ignored.
|
||||||
req.prefill_token_logprobs = list(
|
req.input_token_logprobs = list(
|
||||||
zip(
|
zip(
|
||||||
output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||||
req.input_ids[-req.extend_input_len + 1 :],
|
req.input_ids[-req.extend_input_len + 1 :],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if req.logprob_start_len == 0:
|
if req.logprob_start_len == 0:
|
||||||
req.prefill_token_logprobs = [
|
req.input_token_logprobs = [
|
||||||
(None, req.input_ids[0])
|
(None, req.input_ids[0])
|
||||||
] + req.prefill_token_logprobs
|
] + req.input_token_logprobs
|
||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
if req.last_update_decode_tokens != 0:
|
||||||
req.decode_token_logprobs.extend(
|
req.output_token_logprobs.extend(
|
||||||
list(
|
list(
|
||||||
zip(
|
zip(
|
||||||
output.prefill_token_logprobs[
|
output.input_token_logprobs[
|
||||||
pt
|
pt
|
||||||
+ req.extend_input_len
|
+ req.extend_input_len
|
||||||
- req.last_update_decode_tokens : pt
|
- req.last_update_decode_tokens : pt
|
||||||
@@ -510,21 +510,21 @@ class ModelTpServer:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
req.decode_token_logprobs.append(
|
req.output_token_logprobs.append(
|
||||||
(output.next_token_logprobs[i], next_token_ids[i])
|
(output.next_token_logprobs[i], next_token_ids[i])
|
||||||
)
|
)
|
||||||
|
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
if req.prefill_top_logprobs is None:
|
if req.input_top_logprobs is None:
|
||||||
req.prefill_top_logprobs = output.prefill_top_logprobs[i]
|
req.input_top_logprobs = output.input_top_logprobs[i]
|
||||||
if req.logprob_start_len == 0:
|
if req.logprob_start_len == 0:
|
||||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
req.input_top_logprobs = [None] + req.input_top_logprobs
|
||||||
|
|
||||||
if req.last_update_decode_tokens != 0:
|
if req.last_update_decode_tokens != 0:
|
||||||
req.decode_top_logprobs.extend(
|
req.output_top_logprobs.extend(
|
||||||
output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
||||||
)
|
)
|
||||||
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
|
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||||
|
|
||||||
def cache_filled_batch(self, batch: Batch):
|
def cache_filled_batch(self, batch: Batch):
|
||||||
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
|
||||||
@@ -589,11 +589,11 @@ class ModelTpServer:
|
|||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
req.decode_token_logprobs.append(
|
req.output_token_logprobs.append(
|
||||||
(next_token_logprobs[i], next_token_id)
|
(next_token_logprobs[i], next_token_id)
|
||||||
)
|
)
|
||||||
if req.top_logprobs_num > 0:
|
if req.top_logprobs_num > 0:
|
||||||
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
|
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
@@ -645,16 +645,16 @@ class ModelTpServer:
|
|||||||
}
|
}
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
(
|
(
|
||||||
meta_info["prefill_token_logprobs"],
|
meta_info["input_token_logprobs"],
|
||||||
meta_info["decode_token_logprobs"],
|
meta_info["output_token_logprobs"],
|
||||||
meta_info["prefill_top_logprobs"],
|
meta_info["input_top_logprobs"],
|
||||||
meta_info["decode_top_logprobs"],
|
meta_info["output_top_logprobs"],
|
||||||
meta_info["normalized_prompt_logprob"],
|
meta_info["normalized_prompt_logprob"],
|
||||||
) = (
|
) = (
|
||||||
req.prefill_token_logprobs,
|
req.input_token_logprobs,
|
||||||
req.decode_token_logprobs,
|
req.output_token_logprobs,
|
||||||
req.prefill_top_logprobs,
|
req.input_top_logprobs,
|
||||||
req.decode_top_logprobs,
|
req.output_top_logprobs,
|
||||||
req.normalized_prompt_logprob,
|
req.normalized_prompt_logprob,
|
||||||
)
|
)
|
||||||
output_meta_info.append(meta_info)
|
output_meta_info.append(meta_info)
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class GenerateReqInput:
|
|||||||
# The image input. It can be a file name, a url, or base64 encoded string.
|
# The image input. It can be a file name, a url, or base64 encoded string.
|
||||||
# See also python/sglang/srt/utils.py:load_image.
|
# See also python/sglang/srt/utils.py:load_image.
|
||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
# The sampling_params.
|
# The sampling_params. See descriptions below.
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
# The request id.
|
# The request id.
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
@@ -30,7 +30,7 @@ class GenerateReqInput:
|
|||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
# The number of top logprobs to return.
|
# The number of top logprobs to return.
|
||||||
top_logprobs_num: Optional[Union[List[int], int]] = None
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
# Whether to detokenize tokens in logprobs.
|
# Whether to detokenize the tokens into text in the returned logprobs.
|
||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output.
|
# Whether to stream output.
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
|
|||||||
@@ -448,23 +448,23 @@ class TokenizerManager:
|
|||||||
return_text_in_logprobs: bool,
|
return_text_in_logprobs: bool,
|
||||||
):
|
):
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
|
ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||||
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
|
ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
|
ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
|
||||||
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
|
ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
|
|
||||||
if top_logprobs_num > 0:
|
if top_logprobs_num > 0:
|
||||||
ret["meta_info"]["prefill_top_logprobs"] = (
|
ret["meta_info"]["input_top_logprobs"] = (
|
||||||
self.detokenize_top_logprobs_tokens(
|
self.detokenize_top_logprobs_tokens(
|
||||||
ret["meta_info"]["prefill_top_logprobs"],
|
ret["meta_info"]["input_top_logprobs"],
|
||||||
return_text_in_logprobs,
|
return_text_in_logprobs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
ret["meta_info"]["decode_top_logprobs"] = (
|
ret["meta_info"]["output_top_logprobs"] = (
|
||||||
self.detokenize_top_logprobs_tokens(
|
self.detokenize_top_logprobs_tokens(
|
||||||
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
|
ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module):
|
|||||||
next_token_logits=scores,
|
next_token_logits=scores,
|
||||||
next_token_logprobs=scores,
|
next_token_logprobs=scores,
|
||||||
normalized_prompt_logprobs=scores,
|
normalized_prompt_logprobs=scores,
|
||||||
prefill_token_logprobs=torch.ones_like(input_ids),
|
input_token_logprobs=torch.ones_like(input_ids),
|
||||||
prefill_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
decode_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
|
|||||||
@@ -140,29 +140,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
if request.logprobs:
|
if request.logprobs:
|
||||||
# The first chunk and echo is enabled.
|
# The first chunk and echo is enabled.
|
||||||
if not stream_buffer and request.echo:
|
if not stream_buffer and request.echo:
|
||||||
prefill_token_logprobs = content["meta_info"][
|
input_token_logprobs = content["meta_info"][
|
||||||
"prefill_token_logprobs"
|
"input_token_logprobs"
|
||||||
]
|
]
|
||||||
prefill_top_logprobs = content["meta_info"][
|
input_top_logprobs = content["meta_info"][
|
||||||
"prefill_top_logprobs"
|
"input_top_logprobs"
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
prefill_token_logprobs = None
|
input_token_logprobs = None
|
||||||
prefill_top_logprobs = None
|
input_top_logprobs = None
|
||||||
|
|
||||||
logprobs = to_openai_style_logprobs(
|
logprobs = to_openai_style_logprobs(
|
||||||
prefill_token_logprobs=prefill_token_logprobs,
|
input_token_logprobs=input_token_logprobs,
|
||||||
prefill_top_logprobs=prefill_top_logprobs,
|
input_top_logprobs=input_top_logprobs,
|
||||||
decode_token_logprobs=content["meta_info"][
|
output_token_logprobs=content["meta_info"][
|
||||||
"decode_token_logprobs"
|
"output_token_logprobs"
|
||||||
][n_prev_token:],
|
][n_prev_token:],
|
||||||
decode_top_logprobs=content["meta_info"][
|
output_top_logprobs=content["meta_info"][
|
||||||
"decode_top_logprobs"
|
"output_top_logprobs"
|
||||||
][n_prev_token:],
|
][n_prev_token:],
|
||||||
)
|
)
|
||||||
|
|
||||||
n_prev_token = len(
|
n_prev_token = len(
|
||||||
content["meta_info"]["decode_token_logprobs"]
|
content["meta_info"]["output_token_logprobs"]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
@@ -218,17 +218,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
|||||||
|
|
||||||
if request.logprobs:
|
if request.logprobs:
|
||||||
if request.echo:
|
if request.echo:
|
||||||
prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
|
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
|
||||||
prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
|
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
|
||||||
else:
|
else:
|
||||||
prefill_token_logprobs = None
|
input_token_logprobs = None
|
||||||
prefill_top_logprobs = None
|
input_top_logprobs = None
|
||||||
|
|
||||||
logprobs = to_openai_style_logprobs(
|
logprobs = to_openai_style_logprobs(
|
||||||
prefill_token_logprobs=prefill_token_logprobs,
|
input_token_logprobs=input_token_logprobs,
|
||||||
prefill_top_logprobs=prefill_top_logprobs,
|
input_top_logprobs=input_top_logprobs,
|
||||||
decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
|
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
|
||||||
decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
|
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
@@ -401,10 +401,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
|||||||
|
|
||||||
|
|
||||||
def to_openai_style_logprobs(
|
def to_openai_style_logprobs(
|
||||||
prefill_token_logprobs=None,
|
input_token_logprobs=None,
|
||||||
decode_token_logprobs=None,
|
output_token_logprobs=None,
|
||||||
prefill_top_logprobs=None,
|
input_top_logprobs=None,
|
||||||
decode_top_logprobs=None,
|
output_top_logprobs=None,
|
||||||
):
|
):
|
||||||
ret_logprobs = LogProbs()
|
ret_logprobs = LogProbs()
|
||||||
|
|
||||||
@@ -425,13 +425,13 @@ def to_openai_style_logprobs(
|
|||||||
else:
|
else:
|
||||||
ret_logprobs.top_logprobs.append(None)
|
ret_logprobs.top_logprobs.append(None)
|
||||||
|
|
||||||
if prefill_token_logprobs is not None:
|
if input_token_logprobs is not None:
|
||||||
append_token_logprobs(prefill_token_logprobs)
|
append_token_logprobs(input_token_logprobs)
|
||||||
if decode_token_logprobs is not None:
|
if output_token_logprobs is not None:
|
||||||
append_token_logprobs(decode_token_logprobs)
|
append_token_logprobs(output_token_logprobs)
|
||||||
if prefill_top_logprobs is not None:
|
if input_top_logprobs is not None:
|
||||||
append_top_logprobs(prefill_top_logprobs)
|
append_top_logprobs(input_top_logprobs)
|
||||||
if decode_top_logprobs is not None:
|
if output_top_logprobs is not None:
|
||||||
append_top_logprobs(decode_top_logprobs)
|
append_top_logprobs(output_top_logprobs)
|
||||||
|
|
||||||
return ret_logprobs
|
return ret_logprobs
|
||||||
|
|||||||
@@ -13,14 +13,15 @@ import json
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def test_decode(url, return_logprob, top_logprobs_num, return_text):
|
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
"text": "The capital of France is",
|
"text": "The capital of France is",
|
||||||
"sampling_params": {
|
"sampling_params": {
|
||||||
"temperature": 0,
|
"temperature": 0 if n == 1 else 0.5,
|
||||||
"max_new_tokens": 32,
|
"max_new_tokens": 32,
|
||||||
|
"n": n,
|
||||||
},
|
},
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
@@ -41,8 +42,14 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
url = f"{args.host}:{args.port}"
|
url = f"{args.host}:{args.port}"
|
||||||
|
|
||||||
test_decode(url, False, 0, False)
|
test_decode(url)
|
||||||
test_decode(url, True, 0, False)
|
test_decode(url, n=3)
|
||||||
test_decode(url, True, 0, True)
|
|
||||||
test_decode(url, True, 3, False)
|
for top_logprobs_num in [0, 3]:
|
||||||
test_decode(url, True, 3, True)
|
for return_text in [True, False]:
|
||||||
|
test_decode(
|
||||||
|
url,
|
||||||
|
return_logprob=True,
|
||||||
|
top_logprobs_num=top_logprobs_num,
|
||||||
|
return_text=return_text,
|
||||||
|
)
|
||||||
|
|||||||
@@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
|
|||||||
data = json.loads(chunk[5:].strip("\n"))
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
assert data["meta_info"]["prefill_token_logprobs"] is not None
|
assert data["meta_info"]["input_token_logprobs"] is not None
|
||||||
assert data["meta_info"]["decode_token_logprobs"] is not None
|
assert data["meta_info"]["output_token_logprobs"] is not None
|
||||||
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
||||||
for logprob, token_id, token_text in data["meta_info"][
|
for logprob, token_id, token_text in data["meta_info"][
|
||||||
"decode_token_logprobs"
|
"output_token_logprobs"
|
||||||
][prev:]:
|
][prev:]:
|
||||||
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
|
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
|
||||||
prev = len(data["meta_info"]["decode_token_logprobs"])
|
prev = len(data["meta_info"]["output_token_logprobs"])
|
||||||
else:
|
else:
|
||||||
output = data["text"].strip()
|
output = data["text"].strip()
|
||||||
print(output[prev:], end="", flush=True)
|
print(output[prev:], end="", flush=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user