feat: allow streaming for multi-prompt and/or parallel sampling (#1134)
This commit is contained in:
@@ -23,7 +23,12 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
kill_child_process(cls.process.pid)
|
||||
|
||||
def run_decode(
|
||||
self, return_logprob=False, top_logprobs_num=0, return_text=False, n=1
|
||||
self,
|
||||
return_logprob=False,
|
||||
top_logprobs_num=0,
|
||||
return_text=False,
|
||||
n=1,
|
||||
stream=False,
|
||||
):
|
||||
response = requests.post(
|
||||
self.base_url + "/generate",
|
||||
@@ -34,14 +39,21 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
"max_new_tokens": 32,
|
||||
"n": n,
|
||||
},
|
||||
"stream": False,
|
||||
"stream": stream,
|
||||
"return_logprob": return_logprob,
|
||||
"top_logprobs_num": top_logprobs_num,
|
||||
"return_text_in_logprobs": return_text,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
print(json.dumps(response.json()))
|
||||
if not stream:
|
||||
response_json = response.json()
|
||||
else:
|
||||
response_json = []
|
||||
for line in response.iter_lines():
|
||||
if line.startswith(b"data: ") and line[6:] != b"[DONE]":
|
||||
response_json.append(json.loads(line[6:]))
|
||||
print(json.dumps(response_json))
|
||||
print("=" * 100)
|
||||
|
||||
def test_simple_decode(self):
|
||||
@@ -50,6 +62,9 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
def test_parallel_sample(self):
|
||||
self.run_decode(n=3)
|
||||
|
||||
def test_parallel_sample_stream(self):
|
||||
self.run_decode(n=3, stream=True)
|
||||
|
||||
def test_logprob(self):
|
||||
for top_logprobs_num in [0, 3]:
|
||||
for return_text in [True, False]:
|
||||
|
||||
Reference in New Issue
Block a user