feat: allow streaming for multi-prompt and/or parallel sampling (#1134)

This commit is contained in:
Juwan Yoo
2024-08-20 08:06:55 -07:00
committed by GitHub
parent df191254ab
commit d8476818ef
4 changed files with 211 additions and 86 deletions

View File

@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
kill_child_process(cls.process.pid)
def run_decode(
self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
self,
return_logprob=False,
top_logprobs_num=0,
return_text=False,
n=1,
stream=False,
):
response = requests.post(
self.base_url + "/generate",
@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
"max_new_tokens": 32,
"n": n,
},
"stream": False,
"stream": stream,
"return_logprob": return_logprob,
"top_logprobs_num": top_logprobs_num,
"return_text_in_logprobs": return_text,
"logprob_start_len": 0,
},
)
print(json.dumps(response.json()))
if not stream:
response_json = response.json()
else:
response_json = []
for line in response.iter_lines():
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
response_json.append(json.loads(line[6:]))
print(json.dumps(response_json))
print("=" * 100)
def test_simple_decode(self):
@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
def test_parallel_sample(self):
self.run_decode(n=3)
def test_parallel_sample_stream(self):
self.run_decode(n=3, stream=True)
def test_logprob(self):
for top_logprobs_num in [0, 3]:
for return_text in [True, False]: