Open AI API hidden states (#6716)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
"""
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_batch
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion
|
||||
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_completion_stream
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion
|
||||
python3 -m unittest test_openai_server.TestOpenAIServer.test_chat_completion_stream
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -9,6 +11,7 @@ import re
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
|
||||
@@ -137,27 +140,29 @@ class TestOpenAIServer(CustomTestCase):
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0
|
||||
assert usage.completion_tokens > 0
|
||||
assert usage.total_tokens > 0
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
is_first = is_firsts.get(index, True)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||
assert response.choices[0].logprobs, f"no logprobs in response"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.tokens[0], str
|
||||
), f"{response.choices[0].logprobs.tokens[0]} is not a string"
|
||||
if not (is_first and echo):
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.top_logprobs[0], dict
|
||||
)
|
||||
), f"top_logprobs was not a dictionary"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.top_logprobs[0]
|
||||
)
|
||||
# FIXME: Sometimes, some top_logprobs are missing in the return value. The reason is that some output id maps to the same output token and duplicate in the map
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
assert ret_num_top_logprobs > 0, f"ret_num_top_logprobs was 0"
|
||||
|
||||
if is_first:
|
||||
if echo:
|
||||
@@ -165,8 +170,8 @@ class TestOpenAIServer(CustomTestCase):
|
||||
prompt
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
|
||||
is_firsts[index] = False
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.id, f"no id in response"
|
||||
assert response.created, f"no created in response"
|
||||
|
||||
for index in [i for i in range(parallel_sample_num * num_choices)]:
|
||||
assert not is_firsts.get(
|
||||
@@ -231,27 +236,29 @@ class TestOpenAIServer(CustomTestCase):
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
assert usage.prompt_tokens > 0
|
||||
assert usage.completion_tokens > 0
|
||||
assert usage.total_tokens > 0
|
||||
assert usage.prompt_tokens > 0, f"usage.prompt_tokens was zero"
|
||||
assert usage.completion_tokens > 0, f"usage.completion_tokens was zero"
|
||||
assert usage.total_tokens > 0, f"usage.total_tokens was zero"
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_firsts.get(index, True):
|
||||
assert data.role == "assistant"
|
||||
assert (
|
||||
data.role == "assistant"
|
||||
), f"data.role was not 'assistant' for first chunk"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert response.choices[0].logprobs, f"logprobs was not returned"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs[0].token, str
|
||||
)
|
||||
), f"top_logprobs token was not a string"
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.content[0].top_logprobs, list
|
||||
)
|
||||
), f"top_logprobs was not a list"
|
||||
ret_num_top_logprobs = len(
|
||||
response.choices[0].logprobs.content[0].top_logprobs
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user