Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:50:34 -07:00
committed by GitHub
parent 0a409bd438
commit 30db99b3d9
16 changed files with 188 additions and 184 deletions

View File

@@ -13,14 +13,15 @@ import json
import requests
def test_decode(url, return_logprob, top_logprobs_num, return_text):
def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32,
"n": n,
},
"stream": False,
"return_logprob": return_logprob,
@@ -41,8 +42,14 @@ if __name__ == "__main__":
url = f"{args.host}:{args.port}"
test_decode(url, False, 0, False)
test_decode(url, True, 0, False)
test_decode(url, True, 0, True)
test_decode(url, True, 3, False)
test_decode(url, True, 3, True)
test_decode(url)
test_decode(url, n=3)
for top_logprobs_num in [0, 3]:
for return_text in [True, False]:
test_decode(
url,
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)

View File

@@ -40,14 +40,14 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
data = json.loads(chunk[5:].strip("\n"))
if return_logprob:
assert data["meta_info"]["prefill_token_logprobs"] is not None
assert data["meta_info"]["decode_token_logprobs"] is not None
assert data["meta_info"]["input_token_logprobs"] is not None
assert data["meta_info"]["output_token_logprobs"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
for logprob, token_id, token_text in data["meta_info"][
"decode_token_logprobs"
"output_token_logprobs"
][prev:]:
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
prev = len(data["meta_info"]["decode_token_logprobs"])
prev = len(data["meta_info"]["output_token_logprobs"])
else:
output = data["text"].strip()
print(output[prev:], end="", flush=True)