Simplify stream_output (#2398)

This commit is contained in:
Lianmin Zheng
2024-12-08 12:27:13 -08:00
committed by GitHub
parent f62055b528
commit a6ca736c8e
9 changed files with 426 additions and 290 deletions

View File

@@ -515,6 +515,9 @@ class Scheduler:
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
return_logprob=recv_req.return_logprob,
top_logprobs_num=recv_req.top_logprobs_num,
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
)
@@ -558,9 +561,6 @@ class Scheduler:
return
# Copy more attributes
req.return_logprob = recv_req.return_logprob
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.logprob_start_len = recv_req.logprob_start_len
if req.logprob_start_len == -1:
@@ -982,7 +982,6 @@ class Scheduler:
continue
if req.is_being_chunked <= 0:
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -1035,7 +1034,7 @@ class Scheduler:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1
self.stream_output(batch.reqs, skip_stream_req)
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids, bid = result
@@ -1065,7 +1064,6 @@ class Scheduler:
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -1073,11 +1071,15 @@ class Scheduler:
self.tree_cache.cache_finished_req(req)
if req.return_logprob:
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
req.output_token_logprobs_val.append(next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
req.output_top_logprobs_val.append(
logits_output.output_top_logprobs_val[i]
)
req.output_top_logprobs_idx.append(
logits_output.output_top_logprobs_idx[i]
)
if req.grammar is not None:
req.grammar.accept_token(next_token_id)
@@ -1088,7 +1090,7 @@ class Scheduler:
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs)
self.stream_output(batch.reqs, batch.return_logprob)
self.token_to_kv_pool.free_group_end()
@@ -1108,9 +1110,8 @@ class Scheduler:
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
@@ -1118,173 +1119,195 @@ class Scheduler:
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.input_token_logprobs is None:
input_token_logprobs = output.input_token_logprobs[
if req.input_token_logprobs_val is None:
input_token_logprobs_val = output.input_token_logprobs[
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
]
input_token_ids = req.fill_ids[
input_token_logprobs_idx = req.fill_ids[
len(req.fill_ids)
- num_input_logprobs
+ 1 : len(req.fill_ids)
- req.last_update_decode_tokens
]
# Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors.
input_token_ids = [
input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_ids
for x in input_token_logprobs_idx
]
req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
if (
req.logprob_start_len == 0
): # The first token does not have logprob, pad it.
req.input_token_logprobs = [
(None, req.fill_ids[0])
] + req.input_token_logprobs
input_token_logprobs_val = [None] + input_token_logprobs_val
input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
req.input_token_logprobs_val = input_token_logprobs_val
req.input_token_logprobs_idx = input_token_logprobs_idx
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs.extend(
list(
zip(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
],
)
)
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
]
)
if req.top_logprobs_num > 0:
if req.input_top_logprobs is None:
req.input_top_logprobs = output.input_top_logprobs[i]
if req.input_top_logprobs_val is None:
req.input_top_logprobs_val = output.input_top_logprobs_val[i]
req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
if req.logprob_start_len == 0:
req.input_top_logprobs = [None] + req.input_top_logprobs
req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
if req.last_update_decode_tokens != 0:
req.output_top_logprobs.extend(
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs.append(output.output_top_logprobs[i])
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
return num_input_logprobs
def stream_output(
    self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
):
    """Stream the output to detokenizer.

    Collects per-request fields into parallel lists (one entry per emitted
    request) and sends them in a single message through
    ``self.send_to_detokenizer``.

    Args:
        reqs: Requests of the batch that was just processed.
        return_logprob: Batch-level flag. When true, the logprob value/index
            lists of every emitted request are included; when false, all
            logprob fields of the outgoing message are ``None``.
        skip_req: A request to leave out of the output, compared by identity.
            Only consulted in the generation branch.
    """
    rids = []
    finished_reasons: List[BaseFinishReason] = []

    if self.is_generation:
        vids = []
        decoded_texts = []
        decode_ids_list = []
        read_offsets = []
        output_ids = []

        skip_special_tokens = []
        spaces_between_special_tokens = []
        no_stop_trim = []
        prompt_tokens = []
        completion_tokens = []
        cached_tokens = []

        if return_logprob:
            input_token_logprobs_val = []
            input_token_logprobs_idx = []
            output_token_logprobs_val = []
            output_token_logprobs_idx = []
            input_top_logprobs_val = []
            input_top_logprobs_idx = []
            output_top_logprobs_val = []
            output_top_logprobs_idx = []
            normalized_prompt_logprob = []
        else:
            # No request in this batch asked for logprobs: send None for every
            # logprob field instead of lists.
            input_token_logprobs_val = input_token_logprobs_idx = (
                output_token_logprobs_val
            ) = output_token_logprobs_idx = input_top_logprobs_val = (
                input_top_logprobs_idx
            ) = output_top_logprobs_val = output_top_logprobs_idx = (
                normalized_prompt_logprob
            ) = None

        for req in reqs:
            if req is skip_req:
                continue

            # TODO(lianmin): revisit this for overlap + retract + stream
            if (
                req.finished()
                # If stream, follow the given stream_interval
                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
                # If not stream, we still want to output some tokens to get the
                # benefit of incremental decoding.
                or (not req.stream and len(req.output_ids) % 50 == 0)
            ):
                rids.append(req.rid)
                # Unfinished requests have no finish reason yet.
                finished_reasons.append(
                    req.finished_reason.to_json() if req.finished_reason else None
                )
                vids.append(req.vid)
                decoded_texts.append(req.decoded_text)
                decode_ids, read_offset = req.init_incremental_detokenize()
                decode_ids_list.append(decode_ids)
                read_offsets.append(read_offset)
                if self.skip_tokenizer_init:
                    output_ids.append(req.output_ids)
                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                spaces_between_special_tokens.append(
                    req.sampling_params.spaces_between_special_tokens
                )
                no_stop_trim.append(req.sampling_params.no_stop_trim)

                prompt_tokens.append(len(req.origin_input_ids))
                completion_tokens.append(len(req.output_ids))
                cached_tokens.append(req.cached_tokens)

                if return_logprob:
                    input_token_logprobs_val.append(req.input_token_logprobs_val)
                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
                    output_token_logprobs_val.append(req.output_token_logprobs_val)
                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
                    input_top_logprobs_val.append(req.input_top_logprobs_val)
                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                    output_top_logprobs_val.append(req.output_top_logprobs_val)
                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
                    normalized_prompt_logprob.append(req.normalized_prompt_logprob)

        # Send to detokenizer
        if rids:
            self.send_to_detokenizer.send_pyobj(
                BatchTokenIDOut(
                    rids,
                    finished_reasons,
                    vids,
                    decoded_texts,
                    decode_ids_list,
                    read_offsets,
                    output_ids,
                    skip_special_tokens,
                    spaces_between_special_tokens,
                    no_stop_trim,
                    prompt_tokens,
                    completion_tokens,
                    cached_tokens,
                    input_token_logprobs_val,
                    input_token_logprobs_idx,
                    output_token_logprobs_val,
                    output_token_logprobs_idx,
                    input_top_logprobs_val,
                    input_top_logprobs_idx,
                    output_top_logprobs_val,
                    output_top_logprobs_idx,
                    normalized_prompt_logprob,
                )
            )
    else:  # embedding or reward model
        embeddings = []
        prompt_tokens = []
        # NOTE(review): this branch does not check skip_req, asserts that every
        # request is finished, and sends even when reqs is empty - confirm
        # callers never pass an unfinished request here.
        for req in reqs:
            assert req.finished()
            rids.append(req.rid)
            finished_reasons.append(req.finished_reason.to_json())
            embeddings.append(req.embedding)
            prompt_tokens.append(len(req.origin_input_ids))
        self.send_to_detokenizer.send_pyobj(
            BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
        )
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches