feat: allow streaming for multi-prompt and/or parallel sampling (#1134)
This commit is contained in:
@@ -153,9 +153,6 @@ class TokenizerManager:
|
||||
async for response in self._handle_single_request(obj, request):
|
||||
yield response
|
||||
else:
|
||||
if hasattr(obj, "stream") and obj.stream:
|
||||
raise ValueError("Do not support stream for batch mode.")
|
||||
|
||||
async for response in self._handle_batch_request(obj, request):
|
||||
yield response
|
||||
|
||||
@@ -311,6 +308,7 @@ class TokenizerManager:
|
||||
parallel_sample_num = 1
|
||||
|
||||
# First send out all requests
|
||||
generators = []
|
||||
for i in range(batch_size):
|
||||
for j in range(parallel_sample_num):
|
||||
if j == 0 and parallel_sample_num != 1:
|
||||
@@ -371,42 +369,48 @@ class TokenizerManager:
|
||||
state = ReqState([], False, event)
|
||||
self.rid_to_state[rid] = state
|
||||
|
||||
# Then wait for all responses
|
||||
output_list = []
|
||||
for i in range(batch_size):
|
||||
for j in range(parallel_sample_num):
|
||||
if j == 0 and parallel_sample_num != 1:
|
||||
continue
|
||||
index = i * parallel_sample_num + j
|
||||
if parallel_sample_num != 1:
|
||||
index += batch_size - 1 - i
|
||||
rid = obj.rid[index]
|
||||
state = self.rid_to_state[rid]
|
||||
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(state.event.wait(), timeout=4)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
for rid in obj.rid:
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
if self.is_generation:
|
||||
output_list.append(
|
||||
self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
obj.return_logprob[index],
|
||||
obj.top_logprobs_num[index],
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
generators.append(
|
||||
self._wait_for_response(
|
||||
event,
|
||||
state,
|
||||
obj,
|
||||
rid,
|
||||
request,
|
||||
index=index,
|
||||
response_index=len(generators),
|
||||
)
|
||||
else:
|
||||
output_list.append(state.out_list[-1])
|
||||
assert state.finished
|
||||
del self.rid_to_state[rid]
|
||||
yield output_list
|
||||
)
|
||||
|
||||
# Then process the responses based on streaming option
|
||||
|
||||
is_stream = hasattr(obj, "stream") and obj.stream
|
||||
|
||||
tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
|
||||
output_list = []
|
||||
|
||||
while tasks:
|
||||
done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
for task in done:
|
||||
gen_index = tasks.index(task)
|
||||
|
||||
try:
|
||||
result = task.result()
|
||||
|
||||
if is_stream:
|
||||
yield result
|
||||
else:
|
||||
output_list.append(result)
|
||||
|
||||
tasks[gen_index] = asyncio.create_task(
|
||||
generators[gen_index].__anext__()
|
||||
)
|
||||
except StopAsyncIteration:
|
||||
del generators[gen_index]
|
||||
del tasks[gen_index]
|
||||
|
||||
if not is_stream:
|
||||
yield output_list
|
||||
|
||||
def _validate_input_length(self, input_ids: List[int]):
|
||||
if len(input_ids) >= self.context_len:
|
||||
@@ -437,26 +441,35 @@ class TokenizerManager:
|
||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||
rid: str,
|
||||
request,
|
||||
index: int = None,
|
||||
response_index: int = 0,
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout=4)
|
||||
except asyncio.TimeoutError:
|
||||
if request is not None and await request.is_disconnected():
|
||||
self.abort_request(rid)
|
||||
for rid in [obj.rid] if obj.is_single else obj.rid:
|
||||
self.abort_request(rid)
|
||||
raise ValueError(f"Abort request {rid}")
|
||||
continue
|
||||
|
||||
if self.is_generation:
|
||||
out = self.convert_logprob_style(
|
||||
state.out_list[-1],
|
||||
obj.return_logprob,
|
||||
obj.top_logprobs_num,
|
||||
obj.return_logprob if index is None else obj.return_logprob[index],
|
||||
(
|
||||
obj.top_logprobs_num
|
||||
if index is None
|
||||
else obj.top_logprobs_num[index]
|
||||
),
|
||||
obj.return_text_in_logprobs,
|
||||
)
|
||||
else: # isinstance(obj, EmbeddingReqInput)
|
||||
out = state.out_list[-1]
|
||||
|
||||
out["index"] = response_index
|
||||
|
||||
# Log requests
|
||||
if self.server_args.log_requests and state.finished:
|
||||
if obj.text is None:
|
||||
|
||||
@@ -277,6 +277,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
||||
request_data = json.loads(line)
|
||||
file_request_list.append(request_data)
|
||||
body = request_data["body"]
|
||||
|
||||
# Although streaming is supported for standalone completions, it is not supported in
|
||||
# batch mode (multiple completions in single request).
|
||||
if body.get("stream", False):
|
||||
raise ValueError("Streaming requests are not supported in batch mode")
|
||||
|
||||
if end_point == "/v1/chat/completions":
|
||||
all_requests.append(ChatCompletionRequest(**body))
|
||||
elif end_point == "/v1/completions":
|
||||
@@ -592,27 +598,45 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
if adapted_request.stream:
|
||||
|
||||
async def generate_stream_resp():
|
||||
stream_buffer = ""
|
||||
n_prev_token = 0
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
try:
|
||||
async for content in tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content["index"]
|
||||
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
n_prev_token = n_prev_tokens.get(index, 0)
|
||||
|
||||
text = content["text"]
|
||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
|
||||
if not stream_buffer: # The first chunk
|
||||
if request.echo:
|
||||
if isinstance(request.prompt, str):
|
||||
# for the case of single str prompts
|
||||
prompts = request.prompt
|
||||
elif isinstance(request.prompt, list) and isinstance(
|
||||
request.prompt[0], int
|
||||
):
|
||||
prompts = tokenizer_manager.tokenizer.decode(
|
||||
request.prompt, skip_special_tokens=True
|
||||
)
|
||||
elif isinstance(request.prompt, list):
|
||||
if isinstance(request.prompt[0], str):
|
||||
# for the case of multiple str prompts
|
||||
prompts = request.prompt[index // request.n]
|
||||
elif isinstance(request.prompt[0], int):
|
||||
# for the case of single token ids prompt
|
||||
prompts = tokenizer_manager.tokenizer.decode(
|
||||
request.prompt, skip_special_tokens=True
|
||||
)
|
||||
elif isinstance(request.prompt[0], list) and isinstance(
|
||||
request.prompt[0][0], int
|
||||
):
|
||||
# for the case of multiple token ids prompts
|
||||
prompts = tokenizer_manager.tokenizer.decode(
|
||||
request.prompt[index // request.n],
|
||||
skip_special_tokens=True,
|
||||
)
|
||||
|
||||
# Prepend prompt in response text.
|
||||
text = prompts + text
|
||||
@@ -649,7 +673,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = stream_buffer + delta
|
||||
choice_data = CompletionResponseStreamChoice(
|
||||
index=0,
|
||||
index=index,
|
||||
text=delta,
|
||||
logprobs=logprobs,
|
||||
finish_reason=format_finish_reason(
|
||||
@@ -662,12 +686,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
stream_buffers[index] = stream_buffer
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
total_prompt_tokens = sum(
|
||||
tokens
|
||||
for i, tokens in prompt_tokens.items()
|
||||
if i % request.n == 0
|
||||
)
|
||||
total_completion_tokens = sum(
|
||||
tokens for tokens in completion_tokens.values()
|
||||
)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = CompletionStreamResponse(
|
||||
@@ -914,16 +950,23 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
if adapted_request.stream:
|
||||
|
||||
async def generate_stream_resp():
|
||||
is_first = True
|
||||
|
||||
stream_buffer = ""
|
||||
n_prev_token = 0
|
||||
is_firsts = {}
|
||||
stream_buffers = {}
|
||||
n_prev_tokens = {}
|
||||
prompt_tokens = {}
|
||||
completion_tokens = {}
|
||||
try:
|
||||
async for content in tokenizer_manager.generate_request(
|
||||
adapted_request, raw_request
|
||||
):
|
||||
prompt_tokens = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens = content["meta_info"]["completion_tokens"]
|
||||
index = content["index"]
|
||||
|
||||
is_first = is_firsts.get(index, True)
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
n_prev_token = n_prev_tokens.get(index, 0)
|
||||
|
||||
prompt_tokens[index] = content["meta_info"]["prompt_tokens"]
|
||||
completion_tokens[index] = content["meta_info"]["completion_tokens"]
|
||||
if request.logprobs:
|
||||
logprobs = to_openai_style_logprobs(
|
||||
output_token_logprobs=content["meta_info"][
|
||||
@@ -973,7 +1016,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
# First chunk with role
|
||||
is_first = False
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
index=index,
|
||||
delta=DeltaMessage(role="assistant"),
|
||||
finish_reason=format_finish_reason(
|
||||
content["meta_info"]["finish_reason"]
|
||||
@@ -991,7 +1034,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = stream_buffer + delta
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=format_finish_reason(
|
||||
content["meta_info"]["finish_reason"]
|
||||
@@ -1003,12 +1046,25 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
is_firsts[index] = is_first
|
||||
stream_buffers[index] = stream_buffer
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
total_prompt_tokens = sum(
|
||||
tokens
|
||||
for i, tokens in prompt_tokens.items()
|
||||
if i % request.n == 0
|
||||
)
|
||||
total_completion_tokens = sum(
|
||||
tokens for tokens in completion_tokens.values()
|
||||
)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
prompt_tokens=total_prompt_tokens,
|
||||
completion_tokens=total_completion_tokens,
|
||||
total_tokens=total_prompt_tokens + total_completion_tokens,
|
||||
)
|
||||
|
||||
final_usage_chunk = ChatCompletionStreamResponse(
|
||||
|
||||
Reference in New Issue
Block a user