feat: allow streaming for multi-prompt and/or parallel sampling (#1134)

Juwan Yoo
2024-08-20 08:06:55 -07:00
committed by GitHub
parent df191254ab
commit d8476818ef
4 changed files with 211 additions and 86 deletions

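In short: completion and chat-completion streaming now also works when a request batches multiple prompts (list input) and/or asks for several samples per prompt (n > 1); the tests below exercise both axes. A minimal sketch of the request shape being tested, assuming a locally served OpenAI-compatible endpoint (the address and model name here are placeholders, not taken from this commit):

    import openai

    # Placeholder server address and model; the tests read these from fixtures.
    client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

    # Two prompts (list input) times two parallel samples (n=2), streamed:
    # chunks arrive interleaved, each tagged with its choice index (0..3).
    generator = client.completions.create(
        model="default",
        prompt=["The capital of France is", "The capital of France is"],
        n=2,
        stream=True,
    )
    for chunk in generator:
        print(chunk.choices[0].index, repr(chunk.choices[0].text))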

@@ -85,13 +85,26 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_completion_stream(self, echo, logprobs, token_input):
+    def run_completion_stream(
+        self, echo, logprobs, use_list_input, parallel_sample_num, token_input
+    ):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         prompt = "The capital of France is"
         if token_input:
-            prompt_arg = self.tokenizer.encode(prompt)
+            prompt_input = self.tokenizer.encode(prompt)
+            num_prompt_tokens = len(prompt_input)
         else:
-            prompt_arg = prompt
+            prompt_input = prompt
+            num_prompt_tokens = len(self.tokenizer.encode(prompt))
+
+        if use_list_input:
+            prompt_arg = [prompt_input, prompt_input]
+            num_choices = len(prompt_arg)
+            num_prompt_tokens *= 2
+        else:
+            prompt_arg = prompt_input
+            num_choices = 1
+
         generator = client.completions.create(
             model=self.model,
             prompt=prompt_arg,
@@ -101,9 +114,10 @@ class TestOpenAIServer(unittest.TestCase):
             logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )
-        first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -111,10 +125,14 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.completion_tokens > 0
                 assert usage.total_tokens > 0
                 continue
+
+            index = response.choices[0].index
+            is_first = is_firsts.get(index, True)
+
             if logprobs:
                 assert response.choices[0].logprobs
                 assert isinstance(response.choices[0].logprobs.tokens[0], str)
-                if not (first and echo):
+                if not (is_first and echo):
                     assert isinstance(
                         response.choices[0].logprobs.top_logprobs[0], dict
                     )
@@ -125,15 +143,20 @@ class TestOpenAIServer(unittest.TestCase):
                     # assert ret_num_top_logprobs == logprobs, f"{ret_num_top_logprobs} vs {logprobs}"
                     assert ret_num_top_logprobs > 0

-            if first:
+            if is_first:
                 if echo:
                     assert response.choices[0].text.startswith(
                         prompt
-                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {first}"
-                first = False
+                    ), f"{response.choices[0].text} and all args {echo} {logprobs} {token_input} {is_first}"
+                is_firsts[index] = False
             assert response.id
             assert response.created

+        for index in [i for i in range(parallel_sample_num * num_choices)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_chat_completion(self, logprobs, parallel_sample_num):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         response = client.chat.completions.create(
@@ -172,7 +195,7 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.usage.completion_tokens > 0
         assert response.usage.total_tokens > 0

-    def run_chat_completion_stream(self, logprobs):
+    def run_chat_completion_stream(self, logprobs, parallel_sample_num=1):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         generator = client.chat.completions.create(
             model=self.model,
@@ -185,9 +208,10 @@ class TestOpenAIServer(unittest.TestCase):
             top_logprobs=logprobs,
             stream=True,
             stream_options={"include_usage": True},
+            n=parallel_sample_num,
         )
-        is_first = True
+        is_firsts = {}
         for response in generator:
             usage = response.usage
             if usage is not None:
@@ -196,11 +220,12 @@ class TestOpenAIServer(unittest.TestCase):
                 assert usage.total_tokens > 0
                 continue

+            index = response.choices[0].index
             data = response.choices[0].delta
-            if is_first:
-                data.role == "assistant"
-                is_first = False
+            if is_firsts.get(index, True):
+                assert data.role == "assistant"
+                is_firsts[index] = False
                 continue

             if logprobs:
@@ -222,6 +247,11 @@ class TestOpenAIServer(unittest.TestCase):
         assert response.id
         assert response.created

+        for index in [i for i in range(parallel_sample_num)]:
+            assert not is_firsts.get(
+                index, True
+            ), f"index {index} is not found in the response"
+
     def run_batch(self, mode):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
         if mode == "completion":
@@ -320,7 +350,9 @@ class TestOpenAIServer(unittest.TestCase):
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
batch_job = client.batches.retrieve(batch_job.id)
assert batch_job.status == "completed"
assert (
batch_job.status == "completed"
), f"Batch job status is not completed: {batch_job.status}"
assert batch_job.request_counts.completed == len(content)
assert batch_job.request_counts.failed == 0
assert batch_job.request_counts.total == len(content)
@@ -353,8 +385,16 @@ class TestOpenAIServer(unittest.TestCase):
         # parallel sampling and list input are not supported in streaming mode
         for echo in [False, True]:
             for logprobs in [None, 5]:
-                for token_input in [False, True]:
-                    self.run_completion_stream(echo, logprobs, token_input)
+                for use_list_input in [True, False]:
+                    for parallel_sample_num in [1, 2]:
+                        for token_input in [False, True]:
+                            self.run_completion_stream(
+                                echo,
+                                logprobs,
+                                use_list_input,
+                                parallel_sample_num,
+                                token_input,
+                            )

     def test_chat_completion(self):
         for logprobs in [None, 5]:
@@ -363,7 +403,8 @@ class TestOpenAIServer(unittest.TestCase):
     def test_chat_completion_stream(self):
         for logprobs in [None, 5]:
-            self.run_chat_completion_stream(logprobs)
+            for parallel_sample_num in [1, 2]:
+                self.run_chat_completion_stream(logprobs, parallel_sample_num)

     def test_batch(self):
         for mode in ["completion", "chat"]:

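The pattern these tests introduce, an is_firsts dict keyed by choice index, replaces the single `first` flag because chunks from different samples interleave in one stream. A self-contained sketch of the bookkeeping with stubbed chunks (no server assumed; Choice and Chunk are stand-ins for the OpenAI response types):

    from dataclasses import dataclass

    @dataclass
    class Choice:
        index: int
        text: str

    @dataclass
    class Chunk:
        choices: list

    # Stubbed interleaved stream: two choice indices, two chunks each.
    stream = [
        Chunk([Choice(0, "Paris")]),
        Chunk([Choice(1, " Paris")]),
        Chunk([Choice(0, ".")]),
        Chunk([Choice(1, ".")]),
    ]

    is_firsts = {}
    for chunk in stream:
        index = chunk.choices[0].index
        if is_firsts.get(index, True):
            # First chunk seen for this choice: per-choice checks
            # (echoed prompt, assistant role, ...) belong here.
            is_firsts[index] = False

    # Afterwards, every expected index must have streamed at least one chunk.
    for index in range(2):
        assert not is_firsts.get(index, True), f"index {index} is not found in the response"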

@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
         kill_child_process(cls.process.pid)

     def run_decode(
-        self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
+        self,
+        return_logprob=False,
+        top_logprobs_num=0,
+        return_text=False,
+        n=1,
+        stream=False,
     ):
         response = requests.post(
             self.base_url + "/generate",
@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
"max_new_tokens": 32,
"n": n,
},
"stream": False,
"stream": stream,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
if not stream:
response_json = response.json()
else:
response_json = []
for line in response.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_json.append(json.loads(line[6:]))
print(json.dumps(response_json))
print("=" * 100)
def test_simple_decode(self):
@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
     def test_parallel_sample(self):
         self.run_decode(n=3)

+    def test_parallel_sample_stream(self):
+        self.run_decode(n=3, stream=True)
+
     def test_logprob(self):
         for top_logprobs_num in [0, 3]:
             for return_text in [True, False]:
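For reference, the streaming branch added to run_decode parses server-sent events: each payload line carries a `data: ` prefix, and the stream terminates with a `data: [DONE]` sentinel. A minimal standalone sketch, assuming a local /generate endpoint and a `text` field in the request body (the field name is an assumption from sglang's native API, not shown in the hunks above):

    import json
    import requests

    # Placeholder address; the tests take base_url from the class fixture.
    response = requests.post(
        "http://127.0.0.1:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 32, "n": 3},
            "stream": True,
        },
        stream=True,  # let requests yield lines as they arrive
    )

    chunks = []
    for line in response.iter_lines():
        # SSE frames look like b"data: {...}"; skip the terminating sentinel.
        if line.startswith(b"data: ") and line[6:] != b"[DONE]":
            chunks.append(json.loads(line[6:]))
    print(json.dumps(chunks))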