[Fix] Fix logprob and normalized_logprob (#1428)
This commit is contained in:
@@ -11,16 +11,18 @@ suites = {
|
||||
"test_chunked_prefill.py",
|
||||
"test_embedding_openai_server.py",
|
||||
"test_eval_accuracy_mini.py",
|
||||
"test_json_constrained.py",
|
||||
"test_large_max_new_tokens.py",
|
||||
"test_openai_server.py",
|
||||
"test_json_constrained.py",
|
||||
"test_skip_tokenizer_init.py",
|
||||
"test_torch_compile.py",
|
||||
"test_triton_attn_backend.py",
|
||||
"test_pytorch_sampling_backend.py",
|
||||
"test_server_args.py",
|
||||
"test_skip_tokenizer_init.py",
|
||||
"test_srt_endpoint.py",
|
||||
"test_torch_compile.py",
|
||||
"test_torchao.py",
|
||||
"test_triton_attn_backend.py",
|
||||
"test_update_weights.py",
|
||||
"test_vision_openai_server.py",
|
||||
"test_server_args.py",
|
||||
],
|
||||
"sampling/penaltylib": glob.glob(
|
||||
"sampling/penaltylib/**/test_*.py", recursive=True
|
||||
|
||||
@@ -33,13 +33,13 @@ class TestChunkedPrefill(unittest.TestCase):
|
||||
base_url=base_url,
|
||||
model=model,
|
||||
eval_name="mmlu",
|
||||
num_examples=32,
|
||||
num_examples=64,
|
||||
num_threads=32,
|
||||
)
|
||||
|
||||
try:
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.6
|
||||
assert metrics["score"] >= 0.65
|
||||
finally:
|
||||
kill_child_process(process.pid)
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ class TestJSONConstrained(unittest.TestCase):
|
||||
def setUpClass(cls):
|
||||
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.json_schema = json.dumps(
|
||||
{
|
||||
"type": "object",
|
||||
@@ -28,16 +27,13 @@ class TestJSONConstrained(unittest.TestCase):
|
||||
"required": ["name", "population"],
|
||||
}
|
||||
)
|
||||
cls.process = popen_launch_server(
|
||||
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
|
||||
)
|
||||
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
kill_child_process(cls.process.pid)
|
||||
|
||||
def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1):
|
||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
json={
|
||||
@@ -54,7 +50,6 @@ class TestJSONConstrained(unittest.TestCase):
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
print(json.dumps(response.json()))
|
||||
print("=" * 100)
|
||||
@@ -69,7 +64,7 @@ class TestJSONConstrained(unittest.TestCase):
|
||||
self.run_decode()
|
||||
|
||||
def test_json_openai(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1")
|
||||
client = openai.Client(api_key="EMPTY", base_url=f"{self.base_url}/v1")
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model=self.model,
|
||||
|
||||
@@ -75,11 +75,11 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert isinstance(response.choices[0].logprobs.top_logprobs[1], dict)
|
||||
ret_num_top_logprobs = len(response.choices[0].logprobs.top_logprobs[1])
|
||||
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
|
||||
# FIXME: Sometimes some top_logprobs are missing from the return value, because several output ids can map to the same output token and collapse into a single duplicate key in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
|
||||
assert ret_num_top_logprobs > 0
|
||||
assert response.choices[0].logprobs.token_logprobs[0] != None
|
||||
|
||||
assert response.choices[0].logprobs.token_logprobs[0]
|
||||
|
||||
assert response.id
|
||||
assert response.created
|
||||
@@ -143,7 +143,7 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.top_logprobs[0]
|
||||
)
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some out_put id maps to the same output token and duplicate in the map
|
||||
# FIXME: Sometimes some top_logprobs are missing from the return value, because several output ids can map to the same output token and collapse into a single duplicate key in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
|
||||
@@ -479,6 +479,22 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert isinstance(js_obj["name"], str)
|
||||
assert isinstance(js_obj["population"], int)
|
||||
|
||||
def test_penalty(self):
    """Smoke-test a chat completion issued with ``frequency_penalty`` set.

    Only verifies that the request succeeds and the reply content is a
    string; the effect of the penalty on the text is not inspected.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)

    chat_messages = [
        {"role": "system", "content": "You are a helpful AI assistant"},
        {"role": "user", "content": "Introduce the capital of France."},
    ]
    completion = client.chat.completions.create(
        model=self.model,
        messages=chat_messages,
        temperature=0,
        max_tokens=32,
        frequency_penalty=1.0,
    )

    reply = completion.choices[0].message.content
    assert isinstance(reply, str)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
"""
|
||||
python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_simple_decode
|
||||
"""
|
||||
|
||||
import json
|
||||
import unittest
|
||||
|
||||
@@ -39,7 +43,7 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
"text": "The capital of France is",
|
||||
"sampling_params": {
|
||||
"temperature": 0 if n == 1 else 0.5,
|
||||
"max_new_tokens": 32,
|
||||
"max_new_tokens": 16,
|
||||
"n": n,
|
||||
},
|
||||
"stream": stream,
|
||||
@@ -56,7 +60,8 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
for line in response.iter_lines():
|
||||
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||
response_json.append(json.loads(line[6:]))
|
||||
print(json.dumps(response_json))
|
||||
|
||||
print(json.dumps(response_json, indent=2))
|
||||
print("=" * 100)
|
||||
|
||||
def test_simple_decode(self):
|
||||
@@ -69,13 +74,50 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
self.run_decode(n=3, stream=True)
|
||||
|
||||
def test_logprob(self):
|
||||
for top_logprobs_num in [0, 3]:
|
||||
for return_text in [True, False]:
|
||||
self.run_decode(
|
||||
return_logprob=True,
|
||||
top_logprobs_num=top_logprobs_num,
|
||||
return_text=return_text,
|
||||
)
|
||||
self.run_decode(
|
||||
return_logprob=True,
|
||||
top_logprobs_num=5,
|
||||
return_text=True,
|
||||
)
|
||||
|
||||
def test_logprob_start_len(self):
    """Check logprob bookkeeping when ``logprob_start_len`` is supplied.

    Posts two prompts to ``/generate`` in a single request with
    ``return_logprob`` enabled, then verifies for every returned item:

    * ``prompt_tokens`` equals ``logprob_start_len + 1`` plus the number
      of returned input-token logprob entries;
    * the prompt ends with the concatenation of the last field of each
      input-token logprob entry;
    * exactly ``new_tokens`` completion tokens and output logprob entries
      are reported;
    * the generated text equals the concatenation of the last field of
      each output-token logprob entry.
    """
    logprob_start_len = 4
    new_tokens = 4
    prompts = [
        "I have a very good idea on",
        "Today is a sunndy day and",
    ]

    response = requests.post(
        self.base_url + "/generate",
        json={
            "text": prompts,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": new_tokens,
            },
            "return_logprob": True,
            "top_logprobs_num": 5,
            "return_text_in_logprobs": True,
            "logprob_start_len": logprob_start_len,
        },
    )
    response_json = response.json()
    print(json.dumps(response_json, indent=2))

    for i, res in enumerate(response_json):
        meta_info = res["meta_info"]
        input_logprobs = meta_info["input_token_logprobs"]
        output_logprobs = meta_info["output_token_logprobs"]

        assert meta_info["prompt_tokens"] == logprob_start_len + 1 + len(
            input_logprobs
        )
        # The last field of each entry is presumably the token text, since
        # the request sets return_text_in_logprobs -- confirm against server.
        assert prompts[i].endswith("".join([x[-1] for x in input_logprobs]))

        assert meta_info["completion_tokens"] == new_tokens
        assert len(output_logprobs) == new_tokens
        # This comparison was previously a bare expression whose result was
        # discarded; it must be asserted to actually check anything.
        assert res["text"] == "".join([x[-1] for x in output_logprobs])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -39,7 +39,7 @@ class TestTorchCompile(unittest.TestCase):
|
||||
)
|
||||
|
||||
metrics = run_eval(args)
|
||||
assert metrics["score"] >= 0.65
|
||||
assert metrics["score"] >= 0.60
|
||||
|
||||
def run_decode(self, max_new_tokens):
|
||||
response = requests.post(
|
||||
|
||||
@@ -127,7 +127,6 @@ class TestExtendAttention(unittest.TestCase):
|
||||
|
||||
def _test_context_attention_once(self, head_dim):
|
||||
# Set up a simple test case
|
||||
batch_size = 2
|
||||
num_heads = 4
|
||||
seq_lens = [8, 12]
|
||||
max_seq_len = max(seq_lens)
|
||||
|
||||
Reference in New Issue
Block a user