Support decode token logprobs (#130)

This commit is contained in:
Cody Yu
2024-02-06 12:24:55 -08:00
committed by GitHub
parent ee1df26a77
commit a7334aeea1
10 changed files with 233 additions and 96 deletions

View File

@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
import argparse
import json
import time
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
def test_decode_stream(url, return_logprob):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 512,
"max_new_tokens": 128,
},
"stream": True,
"return_logprob": return_logprob,
},
stream=True,
)
@@ -41,7 +34,29 @@ if __name__ == "__main__":
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
if return_logprob:
assert data["meta_info"]["prompt_logprob"] is not None
assert data["meta_info"]["token_logprob"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
if prev == 0: # Skip prompt logprobs
prev = data["meta_info"]["prompt_tokens"]
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
print(f"{token_txt}\t{logprob}", flush=True)
prev = len(data["meta_info"]["token_logprob"])
else:
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
test_decode_stream(url, False)
test_decode_stream(url, True)