[Fix] Fix logprob and normalized_logprob (#1428)

This commit is contained in:
Lianmin Zheng
2024-09-15 06:36:06 -07:00
committed by GitHub
parent 282681b8a1
commit 9ba1f09760
22 changed files with 314 additions and 215 deletions

View File

@@ -11,16 +11,18 @@ suites = {
"test_chunked_prefill.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_json_constrained.py",
"test_large_max_new_tokens.py",
"test_openai_server.py",
"test_json_constrained.py",
"test_skip_tokenizer_init.py",
"test_torch_compile.py",
"test_triton_attn_backend.py",
"test_pytorch_sampling_backend.py",
"test_server_args.py",
"test_skip_tokenizer_init.py",
"test_srt_endpoint.py",
"test_torch_compile.py",
"test_torchao.py",
"test_triton_attn_backend.py",
"test_update_weights.py",
"test_vision_openai_server.py",
"test_server_args.py",
],
"sampling/penaltylib": glob.glob(
"sampling/penaltylib/**/test_*.py", recursive=True

View File

@@ -33,13 +33,13 @@ class TestChunkedPrefill(unittest.TestCase):
base_url=base_url,
model=model,
eval_name="mmlu",
num_examples=32,
num_examples=64,
num_threads=32,
)
try:
metrics = run_eval(args)
assert metrics["score"] >= 0.6
assert metrics["score"] >= 0.65
finally:
kill_child_process(process.pid)

View File

@@ -17,7 +17,6 @@ class TestJSONConstrained(unittest.TestCase):
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-123456"
cls.json_schema = json.dumps(
{
"type": "object",
@@ -28,16 +27,13 @@ class TestJSONConstrained(unittest.TestCase):
"required": ["name", "population"],
}
)
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
)
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
@classmethod
def tearDownClass(cls):
kill_child_process(cls.process.pid)
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
headers = {"Authorization": f"Bearer {self.api_key}"}
response = requests.post(
self.base_url + "/generate",
json={
@@ -54,7 +50,6 @@ class TestJSONConstrained(unittest.TestCase):
"top_logprobs_num": top_logprobs_num,
"logprob_start_len": 0,
},
headers=headers,
)
print(json.dumps(response.json()))
print("=" * 100)
@@ -69,7 +64,7 @@ class TestJSONConstrained(unittest.TestCase):
self.run_decode()
def test_json_openai(self):
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
response = client.chat.completions.create(
model=self.model,

View File

@@ -75,11 +75,11 @@ class TestOpenAIServer(unittest.TestCase):
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
assert response.choices[0].logprobs.token_logprobs[0] != None
assert response.choices[0].logprobs.token_logprobs[0]
assert response.id
assert response.created
@@ -143,7 +143,7 @@ class TestOpenAIServer(unittest.TestCase):
ret_num_top_logprobs = len(
response.choices[0].logprobs.top_logprobs[0]
)
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
assert ret_num_top_logprobs > 0
@@ -479,6 +479,22 @@ class TestOpenAIServer(unittest.TestCase):
assert isinstance(js_obj["name"], str)
assert isinstance(js_obj["population"], int)
def test_penalty(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": "You are a helpful AI assistant"},
{"role": "user", "content": "Introduce the capital of France."},
],
temperature=0,
max_tokens=32,
frequency_penalty=1.0,
)
text = response.choices[0].message.content
assert isinstance(text, str)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,3 +1,7 @@
"""
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
"""
import json
import unittest
@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase):
"text": "The capital of France is",
"sampling_params": {
"temperature": 0 if n == 1 else 0.5,
"max_new_tokens": 32,
"max_new_tokens": 16,
"n": n,
},
"stream": stream,
@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase):
for line in response.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_json.append(json.loads(line[6:]))
print(json.dumps(response_json))
print(json.dumps(response_json, indent=2))
print("=" * 100)
def test_simple_decode(self):
@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase):
self.run_decode(n=3, stream=True)
def test_logprob(self):
for top_logprobs_num in [0, 3]:
for return_text in [True, False]:
self.run_decode(
return_logprob=True,
top_logprobs_num=top_logprobs_num,
return_text=return_text,
)
self.run_decode(
return_logprob=True,
top_logprobs_num=5,
return_text=True,
)
def test_logprob_start_len(self):
logprob_start_len = 4
new_tokens = 4
prompts = [
"I have a very good idea on",
"Today is a sunndy day and",
]
response = requests.post(
self.base_url + "/generate",
json={
"text": prompts,
"sampling_params": {
"temperature": 0,
"max_new_tokens": new_tokens,
},
"return_logprob": True,
"top_logprobs_num": 5,
"return_text_in_logprobs": True,
"logprob_start_len": logprob_start_len,
},
)
response_json = response.json()
print(json.dumps(response_json, indent=2))
for i, res in enumerate(response_json):
assert res["meta_info"]["prompt_tokens"] == logprob_start_len + 1 + len(
res["meta_info"]["input_token_logprobs"]
)
assert prompts[i].endswith(
"".join([x[-1] for x in res["meta_info"]["input_token_logprobs"]])
)
assert res["meta_info"]["completion_tokens"] == new_tokens
assert len(res["meta_info"]["output_token_logprobs"]) == new_tokens
res["text"] == "".join(
[x[-1] for x in res["meta_info"]["output_token_logprobs"]]
)
if __name__ == "__main__":

View File

@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.65
assert metrics["score"] >= 0.60
def run_decode(self, max_new_tokens):
response = requests.post(

View File

@@ -127,7 +127,6 @@ class TestExtendAttention(unittest.TestCase):
def _test_context_attention_once(self, head_dim):
# Set up a simple test case
batch_size = 2
num_heads = 4
seq_lens = [8, 12]
max_seq_len = max(seq_lens)