Support decode token logprobs (#130)

This commit is contained in:
Cody Yu
2024-02-06 12:24:55 -08:00
committed by GitHub
parent ee1df26a77
commit a7334aeea1
10 changed files with 233 additions and 96 deletions

View File

@@ -9,10 +9,24 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
"""
import argparse
import time
import requests
def test_decode(url, return_logprob):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
"return_logprob": return_logprob,
"logprob_start_len": 0,
},
)
print(response.json())
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
@@ -21,16 +35,5 @@ if __name__ == "__main__":
url = f"{args.host}:{args.port}"
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
},
# "return_logprob": True,
# "logprob_start_len": 0,
},
)
print(response.json())
test_decode(url, False)
test_decode(url, True)

View File

@@ -9,27 +9,20 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
import argparse
import json
import time
import requests
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
def test_decode_stream(url, return_logprob):
response = requests.post(
url + "/generate",
json={
"text": "The capital of France is",
"sampling_params": {
"temperature": 0,
"max_new_tokens": 512,
"max_new_tokens": 128,
},
"stream": True,
"return_logprob": return_logprob,
},
stream=True,
)
@@ -41,7 +34,29 @@ if __name__ == "__main__":
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
if return_logprob:
assert data["meta_info"]["prompt_logprob"] is not None
assert data["meta_info"]["token_logprob"] is not None
assert data["meta_info"]["normalized_prompt_logprob"] is not None
if prev == 0: # Skip prompt logprobs
prev = data["meta_info"]["prompt_tokens"]
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
print(f"{token_txt}\t{logprob}", flush=True)
prev = len(data["meta_info"]["token_logprob"])
else:
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
url = f"{args.host}:{args.port}"
test_decode_stream(url, False)
test_decode_stream(url, True)

View File

@@ -18,15 +18,26 @@ import argparse
import openai
def test_completion(args):
def test_completion(args, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
prompt="The capital of France is",
temperature=0,
max_tokens=32,
echo=echo,
logprobs=logprobs,
)
text = response.choices[0].text
print(response.choices[0].text)
if echo:
assert text.startswith("The capital of France is")
if logprobs:
assert response.choices[0].logprobs
if echo:
assert response.choices[0].logprobs.token_logprobs[0] == None
else:
assert response.choices[0].logprobs.token_logprobs[0] != None
assert response.id
assert response.created
assert response.usage.prompt_tokens > 0
@@ -34,7 +45,7 @@ def test_completion(args):
assert response.usage.total_tokens > 0
def test_completion_stream(args):
def test_completion_stream(args, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=args.base_url)
response = client.completions.create(
model="default",
@@ -42,9 +53,23 @@ def test_completion_stream(args):
temperature=0,
max_tokens=32,
stream=True,
echo=echo,
logprobs=logprobs,
)
first = True
for r in response:
print(r.choices[0].text, end="", flush=True)
if first:
if echo:
assert r.choices[0].text.startswith("The capital of France is")
first = False
if logprobs:
print(
f"{r.choices[0].text:12s}\t"
f"{r.choices[0].logprobs.token_logprobs}",
flush=True
)
else:
print(r.choices[0].text, end="", flush=True)
assert r.id
assert r.usage.prompt_tokens > 0
assert r.usage.completion_tokens > 0
@@ -135,8 +160,14 @@ if __name__ == "__main__":
)
args = parser.parse_args()
test_completion(args)
test_completion_stream(args)
test_completion(args, echo=False, logprobs=False)
test_completion(args, echo=True, logprobs=False)
test_completion(args, echo=False, logprobs=True)
test_completion(args, echo=True, logprobs=True)
test_completion_stream(args, echo=False, logprobs=False)
test_completion_stream(args, echo=True, logprobs=False)
test_completion_stream(args, echo=False, logprobs=True)
test_completion_stream(args, echo=True, logprobs=True)
test_chat_completion(args)
test_chat_completion_stream(args)
if args.test_image: