feat: allow streaming for multi-prompt and/or parallel sampling (#1134)
This commit is contained in:
@@ -85,13 +85,26 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_completion_stream(self, echo, logprobs, token_input):
|
||||
def run_completion_stream(
|
||||
self, echo, logprobs, use_list_input, parallel_sample_num, token_input
|
||||
):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
prompt = "The capital of France is"
|
||||
if token_input:
|
||||
prompt_arg = self.tokenizer.encode(prompt)
|
||||
prompt_input = self.tokenizer.encode(prompt)
|
||||
num_prompt_tokens = len(prompt_input)
|
||||
else:
|
||||
prompt_arg = prompt
|
||||
prompt_input = prompt
|
||||
num_prompt_tokens = len(self.tokenizer.encode(prompt))
|
||||
|
||||
if use_list_input:
|
||||
prompt_arg = [prompt_input, prompt_input]
|
||||
num_choices = len(prompt_arg)
|
||||
num_prompt_tokens *= 2
|
||||
else:
|
||||
prompt_arg = prompt_input
|
||||
num_choices = 1
|
||||
|
||||
generator = client.completions.create(
|
||||
model=self.model,
|
||||
prompt=prompt_arg,
|
||||
@@ -101,9 +114,10 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
first = True
|
||||
is_firsts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
@@ -111,10 +125,14 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert usage.completion_tokens > 0
|
||||
assert usage.total_tokens > 0
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
is_first = is_firsts.get(index, True)
|
||||
|
||||
if logprobs:
|
||||
assert response.choices[0].logprobs
|
||||
assert isinstance(response.choices[0].logprobs.tokens[0], str)
|
||||
if not (first and echo):
|
||||
if not (is_first and echo):
|
||||
assert isinstance(
|
||||
response.choices[0].logprobs.top_logprobs[0], dict
|
||||
)
|
||||
@@ -125,15 +143,20 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
# assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
|
||||
assert ret_num_top_logprobs > 0
|
||||
|
||||
if first:
|
||||
if is_first:
|
||||
if echo:
|
||||
assert response.choices[0].text.startswith(
|
||||
prompt
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
|
||||
first = False
|
||||
), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
|
||||
is_firsts[index] = False
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
for index in [i for i in range(parallel_sample_num * num_choices)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def run_chat_completion(self, logprobs, parallel_sample_num):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
response = client.chat.completions.create(
|
||||
@@ -172,7 +195,7 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert response.usage.completion_tokens > 0
|
||||
assert response.usage.total_tokens > 0
|
||||
|
||||
def run_chat_completion_stream(self, logprobs):
|
||||
def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
generator = client.chat.completions.create(
|
||||
model=self.model,
|
||||
@@ -185,9 +208,10 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
top_logprobs=logprobs,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
n=parallel_sample_num,
|
||||
)
|
||||
|
||||
is_first = True
|
||||
is_firsts = {}
|
||||
for response in generator:
|
||||
usage = response.usage
|
||||
if usage is not None:
|
||||
@@ -196,11 +220,12 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert usage.total_tokens > 0
|
||||
continue
|
||||
|
||||
index = response.choices[0].index
|
||||
data = response.choices[0].delta
|
||||
|
||||
if is_first:
|
||||
data.role == "assistant"
|
||||
is_first = False
|
||||
if is_firsts.get(index, True):
|
||||
assert data.role == "assistant"
|
||||
is_firsts[index] = False
|
||||
continue
|
||||
|
||||
if logprobs:
|
||||
@@ -222,6 +247,11 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
assert response.id
|
||||
assert response.created
|
||||
|
||||
for index in [i for i in range(parallel_sample_num)]:
|
||||
assert not is_firsts.get(
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def run_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
if mode == "completion":
|
||||
@@ -320,7 +350,9 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
batch_job = client.batches.retrieve(batch_job.id)
|
||||
assert batch_job.status == "completed"
|
||||
assert (
|
||||
batch_job.status == "completed"
|
||||
), f"Batch job status is not completed: {batch_job.status}"
|
||||
assert batch_job.request_counts.completed == len(content)
|
||||
assert batch_job.request_counts.failed == 0
|
||||
assert batch_job.request_counts.total == len(content)
|
||||
@@ -353,8 +385,16 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
# parallel sampling and list input are both supported in streaming mode
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion_stream(echo, logprobs, token_input)
|
||||
for use_list_input in [True, False]:
|
||||
for parallel_sample_num in [1, 2]:
|
||||
for token_input in [False, True]:
|
||||
self.run_completion_stream(
|
||||
echo,
|
||||
logprobs,
|
||||
use_list_input,
|
||||
parallel_sample_num,
|
||||
token_input,
|
||||
)
|
||||
|
||||
def test_chat_completion(self):
|
||||
for logprobs in [None, 5]:
|
||||
@@ -363,7 +403,8 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
|
||||
def test_chat_completion_stream(self):
|
||||
for logprobs in [None, 5]:
|
||||
self.run_chat_completion_stream(logprobs)
|
||||
for parallel_sample_num in [1, 2]:
|
||||
self.run_chat_completion_stream(logprobs, parallel_sample_num)
|
||||
|
||||
def test_batch(self):
|
||||
for mode in ["completion", "chat"]:
|
||||
|
||||
Reference in New Issue
Block a user