[Fix] Outputs returned in random order when the input is a list (#1199)

This commit is contained in:
Ying Sheng
2024-08-24 21:43:03 -07:00
committed by GitHub
parent e61d13acdf
commit 1cb4da5c5f
4 changed files with 23 additions and 20 deletions

View File

@@ -437,13 +437,13 @@ class TokenizerManager:
is_stream = hasattr(obj, "stream") and obj.stream
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
output_list = []
output_list = [None] * len(tasks)
while tasks:
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
for task in done:
gen_index = tasks.index(task)
cur_index = tasks.index(task)
try:
result = task.result()
@@ -451,14 +451,14 @@ class TokenizerManager:
if is_stream:
yield result
else:
output_list.append(result)
output_list[result["index"]] = result
tasks[gen_index] = asyncio.create_task(
generators[gen_index].__anext__()
tasks[cur_index] = asyncio.create_task(
generators[cur_index].__anext__()
)
except StopAsyncIteration:
del generators[gen_index]
del tasks[gen_index]
del generators[cur_index]
del tasks[cur_index]
if not is_stream:
yield output_list

View File

@@ -591,7 +591,7 @@ class Runtime:
def generate(
self,
prompt: str,
prompt: Union[str, List[str]],
sampling_params: Optional[Dict] = None,
return_logprob: Optional[Union[List[bool], bool]] = False,
logprob_start_len: Optional[Union[List[int], int]] = None,
@@ -612,7 +612,7 @@ class Runtime:
def encode(
self,
prompt: str,
prompt: Union[str, List[str]],
):
json_data = {
"text": prompt,

View File

@@ -28,10 +28,10 @@ from sglang.srt.server import Runtime
DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented-out prompt below
# "The capital of France is",
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
"AI is a field of computer science focused on",
"Apple is red. Banana is Yellow. " * 800 + "Apple is",
]
dirpath = os.path.dirname(__file__)