feat: allow streaming for multi-prompt and/or parallel sampling (#1134)

This commit is contained in:
Juwan Yoo
2024-08-20 08:06:55 -07:00
committed by GitHub
parent df191254ab
commit d8476818ef
4 changed files with 211 additions and 86 deletions

View File

@@ -277,6 +277,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
request_data = json.loads(line)
file_request_list.append(request_data)
body = request_data["body"]
# Although streaming is supported for standalone completions, it is not supported in
# batch mode (multiple completions in single request).
if body.get("stream", False):
raise ValueError("Streaming requests are not supported in batch mode")
if end_point == "/v1/chat/completions":
all_requests.append(ChatCompletionRequest(**body))
elif end_point == "/v1/completions":
@@ -592,27 +598,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if adapted_request.stream:
async def generate_stream_resp():
stream_buffer = ""
n_prev_token = 0
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
index = content["index"]
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
text = content["text"]
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
if not stream_buffer: # The first chunk
if request.echo:
if isinstance(request.prompt, str):
# for the case of single str prompts
prompts = request.prompt
elif isinstance(request.prompt, list) and isinstance(
request.prompt[0], int
):
prompts = tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
elif isinstance(request.prompt, list):
if isinstance(request.prompt[0], str):
# for the case of multiple str prompts
prompts = request.prompt[index // request.n]
elif isinstance(request.prompt[0], int):
# for the case of single token ids prompt
prompts = tokenizer_manager.tokenizer.decode(
request.prompt, skip_special_tokens=True
)
elif isinstance(request.prompt[0], list) and isinstance(
request.prompt[0][0], int
):
# for the case of multiple token ids prompts
prompts = tokenizer_manager.tokenizer.decode(
request.prompt[index // request.n],
skip_special_tokens=True,
)
# Prepend prompt in response text.
text = prompts + text
@@ -649,7 +673,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
choice_data = CompletionResponseStreamChoice(
index=0,
index=index,
text=delta,
logprobs=logprobs,
finish_reason=format_finish_reason(
@@ -662,12 +686,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
choices=[choice_data],
model=request.model,
)
stream_buffers[index] = stream_buffer
n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n"
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
for i, tokens in prompt_tokens.items()
if i % request.n == 0
)
total_completion_tokens = sum(
tokens for tokens in completion_tokens.values()
)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
final_usage_chunk = CompletionStreamResponse(
@@ -914,16 +950,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
if adapted_request.stream:
async def generate_stream_resp():
is_first = True
stream_buffer = ""
n_prev_token = 0
is_firsts = {}
stream_buffers = {}
n_prev_tokens = {}
prompt_tokens = {}
completion_tokens = {}
try:
async for content in tokenizer_manager.generate_request(
adapted_request, raw_request
):
prompt_tokens = content["meta_info"]["prompt_tokens"]
completion_tokens = content["meta_info"]["completion_tokens"]
index = content["index"]
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
n_prev_token = n_prev_tokens.get(index, 0)
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
completion_tokens[index] = content["meta_info"]["completion_tokens"]
if request.logprobs:
logprobs = to_openai_style_logprobs(
output_token_logprobs=content["meta_info"][
@@ -973,7 +1016,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
# First chunk with role
is_first = False
choice_data = ChatCompletionResponseStreamChoice(
index=0,
index=index,
delta=DeltaMessage(role="assistant"),
finish_reason=format_finish_reason(
content["meta_info"]["finish_reason"]
@@ -991,7 +1034,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
choice_data = ChatCompletionResponseStreamChoice(
index=0,
index=index,
delta=DeltaMessage(content=delta),
finish_reason=format_finish_reason(
content["meta_info"]["finish_reason"]
@@ -1003,12 +1046,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
choices=[choice_data],
model=request.model,
)
is_firsts[index] = is_first
stream_buffers[index] = stream_buffer
n_prev_tokens[index] = n_prev_token
yield f"data: {chunk.model_dump_json()}\n\n"
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
for i, tokens in prompt_tokens.items()
if i % request.n == 0
)
total_completion_tokens = sum(
tokens for tokens in completion_tokens.values()
)
usage = UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens=total_prompt_tokens,
completion_tokens=total_completion_tokens,
total_tokens=total_prompt_tokens + total_completion_tokens,
)
final_usage_chunk = ChatCompletionStreamResponse(