Simplify stream_output (#2398)
This commit is contained in:
@@ -515,6 +515,9 @@ class Scheduler:
|
||||
recv_req.input_text,
|
||||
recv_req.input_ids,
|
||||
recv_req.sampling_params,
|
||||
return_logprob=recv_req.return_logprob,
|
||||
top_logprobs_num=recv_req.top_logprobs_num,
|
||||
stream=recv_req.stream,
|
||||
lora_path=recv_req.lora_path,
|
||||
input_embeds=recv_req.input_embeds,
|
||||
)
|
||||
@@ -558,9 +561,6 @@ class Scheduler:
|
||||
return
|
||||
|
||||
# Copy more attributes
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||
req.stream = recv_req.stream
|
||||
req.logprob_start_len = recv_req.logprob_start_len
|
||||
|
||||
if req.logprob_start_len == -1:
|
||||
@@ -982,7 +982,6 @@ class Scheduler:
|
||||
continue
|
||||
|
||||
if req.is_being_chunked <= 0:
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
req.output_ids.append(next_token_id)
|
||||
req.check_finished()
|
||||
|
||||
@@ -1035,7 +1034,7 @@ class Scheduler:
|
||||
# being chunked reqs' prefill is not finished
|
||||
req.is_being_chunked -= 1
|
||||
|
||||
self.stream_output(batch.reqs, skip_stream_req)
|
||||
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
|
||||
|
||||
def process_batch_result_decode(self, batch: ScheduleBatch, result):
|
||||
logits_output, next_token_ids, bid = result
|
||||
@@ -1065,7 +1064,6 @@ class Scheduler:
|
||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
||||
continue
|
||||
|
||||
req.completion_tokens_wo_jump_forward += 1
|
||||
req.output_ids.append(next_token_id)
|
||||
req.check_finished()
|
||||
|
||||
@@ -1073,11 +1071,15 @@ class Scheduler:
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
|
||||
if req.return_logprob:
|
||||
req.output_token_logprobs.append(
|
||||
(next_token_logprobs[i], next_token_id)
|
||||
)
|
||||
req.output_token_logprobs_val.append(next_token_logprobs[i])
|
||||
req.output_token_logprobs_idx.append(next_token_id)
|
||||
if req.top_logprobs_num > 0:
|
||||
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
|
||||
req.output_top_logprobs_val.append(
|
||||
logits_output.output_top_logprobs_val[i]
|
||||
)
|
||||
req.output_top_logprobs_idx.append(
|
||||
logits_output.output_top_logprobs_idx[i]
|
||||
)
|
||||
|
||||
if req.grammar is not None:
|
||||
req.grammar.accept_token(next_token_id)
|
||||
@@ -1088,7 +1090,7 @@ class Scheduler:
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
self.stream_output(batch.reqs)
|
||||
self.stream_output(batch.reqs, batch.return_logprob)
|
||||
|
||||
self.token_to_kv_pool.free_group_end()
|
||||
|
||||
@@ -1108,9 +1110,8 @@ class Scheduler:
|
||||
output: LogitsProcessorOutput,
|
||||
):
|
||||
"""Attach logprobs to the return values."""
|
||||
req.output_token_logprobs.append(
|
||||
(output.next_token_logprobs[i], next_token_ids[i])
|
||||
)
|
||||
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
|
||||
req.output_token_logprobs_idx.append(next_token_ids[i])
|
||||
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
|
||||
@@ -1118,173 +1119,195 @@ class Scheduler:
|
||||
if req.normalized_prompt_logprob is None:
|
||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||
|
||||
if req.input_token_logprobs is None:
|
||||
input_token_logprobs = output.input_token_logprobs[
|
||||
if req.input_token_logprobs_val is None:
|
||||
input_token_logprobs_val = output.input_token_logprobs[
|
||||
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
|
||||
]
|
||||
input_token_ids = req.fill_ids[
|
||||
|
||||
input_token_logprobs_idx = req.fill_ids[
|
||||
len(req.fill_ids)
|
||||
- num_input_logprobs
|
||||
+ 1 : len(req.fill_ids)
|
||||
- req.last_update_decode_tokens
|
||||
]
|
||||
|
||||
# Clip the padded hash values from image tokens.
|
||||
# Otherwise, it will lead to detokenization errors.
|
||||
input_token_ids = [
|
||||
input_token_logprobs_idx = [
|
||||
x if x < self.model_config.vocab_size - 1 else 0
|
||||
for x in input_token_ids
|
||||
for x in input_token_logprobs_idx
|
||||
]
|
||||
|
||||
req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
|
||||
|
||||
if (
|
||||
req.logprob_start_len == 0
|
||||
): # The first token does not have logprob, pad it.
|
||||
req.input_token_logprobs = [
|
||||
(None, req.fill_ids[0])
|
||||
] + req.input_token_logprobs
|
||||
input_token_logprobs_val = [None] + input_token_logprobs_val
|
||||
input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx
|
||||
|
||||
req.input_token_logprobs_val = input_token_logprobs_val
|
||||
req.input_token_logprobs_idx = input_token_logprobs_idx
|
||||
|
||||
if req.last_update_decode_tokens != 0:
|
||||
# Some decode tokens are re-computed in an extend batch
|
||||
req.output_token_logprobs.extend(
|
||||
list(
|
||||
zip(
|
||||
output.input_token_logprobs[
|
||||
pt
|
||||
+ num_input_logprobs
|
||||
- 1
|
||||
- req.last_update_decode_tokens : pt
|
||||
+ num_input_logprobs
|
||||
- 1
|
||||
],
|
||||
req.fill_ids[
|
||||
len(req.fill_ids)
|
||||
- req.last_update_decode_tokens : len(req.fill_ids)
|
||||
],
|
||||
)
|
||||
)
|
||||
req.output_token_logprobs_val.extend(
|
||||
output.input_token_logprobs[
|
||||
pt
|
||||
+ num_input_logprobs
|
||||
- 1
|
||||
- req.last_update_decode_tokens : pt
|
||||
+ num_input_logprobs
|
||||
- 1
|
||||
],
|
||||
)
|
||||
req.output_token_logprobs_idx.extend(
|
||||
req.fill_ids[
|
||||
len(req.fill_ids)
|
||||
- req.last_update_decode_tokens : len(req.fill_ids)
|
||||
]
|
||||
)
|
||||
|
||||
if req.top_logprobs_num > 0:
|
||||
if req.input_top_logprobs is None:
|
||||
req.input_top_logprobs = output.input_top_logprobs[i]
|
||||
if req.input_top_logprobs_val is None:
|
||||
req.input_top_logprobs_val = output.input_top_logprobs_val[i]
|
||||
req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
|
||||
if req.logprob_start_len == 0:
|
||||
req.input_top_logprobs = [None] + req.input_top_logprobs
|
||||
req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
|
||||
req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx
|
||||
|
||||
if req.last_update_decode_tokens != 0:
|
||||
req.output_top_logprobs.extend(
|
||||
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
|
||||
req.output_top_logprobs_val.extend(
|
||||
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
|
||||
)
|
||||
req.output_top_logprobs.append(output.output_top_logprobs[i])
|
||||
req.output_top_logprobs_idx.extend(
|
||||
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
|
||||
)
|
||||
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
|
||||
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])
|
||||
|
||||
return num_input_logprobs
|
||||
|
||||
def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
|
||||
def stream_output(
|
||||
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
|
||||
):
|
||||
"""Stream the output to detokenizer."""
|
||||
output_rids = []
|
||||
output_meta_info: List[dict] = []
|
||||
output_finished_reason: List[BaseFinishReason] = []
|
||||
rids = []
|
||||
finished_reasons: List[BaseFinishReason] = []
|
||||
|
||||
if self.is_generation:
|
||||
output_vids = []
|
||||
vids = []
|
||||
decoded_texts = []
|
||||
output_read_ids = []
|
||||
output_read_offsets = []
|
||||
decode_ids_list = []
|
||||
read_offsets = []
|
||||
output_ids = []
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
output_no_stop_trim = []
|
||||
else: # embedding or reward model
|
||||
output_embeddings = []
|
||||
skip_special_tokens = []
|
||||
spaces_between_special_tokens = []
|
||||
no_stop_trim = []
|
||||
prompt_tokens = []
|
||||
completion_tokens = []
|
||||
cached_tokens = []
|
||||
|
||||
is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
|
||||
if return_logprob:
|
||||
input_token_logprobs_val = []
|
||||
input_token_logprobs_idx = []
|
||||
output_token_logprobs_val = []
|
||||
output_token_logprobs_idx = []
|
||||
input_top_logprobs_val = []
|
||||
input_top_logprobs_idx = []
|
||||
output_top_logprobs_val = []
|
||||
output_top_logprobs_idx = []
|
||||
normalized_prompt_logprob = []
|
||||
else:
|
||||
input_token_logprobs_val = input_token_logprobs_idx = (
|
||||
output_token_logprobs_val
|
||||
) = output_token_logprobs_idx = input_top_logprobs_val = (
|
||||
input_top_logprobs_idx
|
||||
) = output_top_logprobs_val = output_top_logprobs_idx = (
|
||||
normalized_prompt_logprob
|
||||
) = None
|
||||
|
||||
for req in reqs:
|
||||
if req is skip_req:
|
||||
continue
|
||||
for req in reqs:
|
||||
if req is skip_req:
|
||||
continue
|
||||
|
||||
# TODO(lianmin): revisit this for overlap + retract + stream
|
||||
if req.finished() or (
|
||||
req.stream and (is_stream_iter or len(req.output_ids) == 1)
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
output_finished_reason.append(req.finished_reason)
|
||||
if self.is_generation:
|
||||
output_vids.append(req.vid)
|
||||
# TODO(lianmin): revisit this for overlap + retract + stream
|
||||
if (
|
||||
req.finished()
|
||||
# If stream, follow the given stream_interval
|
||||
or (req.stream and len(req.output_ids) % self.stream_interval == 0)
|
||||
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
|
||||
or (not req.stream and len(req.output_ids) % 50 == 0)
|
||||
):
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(
|
||||
req.finished_reason.to_json() if req.finished_reason else None
|
||||
)
|
||||
vids.append(req.vid)
|
||||
decoded_texts.append(req.decoded_text)
|
||||
read_ids, read_offset = req.init_incremental_detokenize()
|
||||
output_read_ids.append(read_ids)
|
||||
output_read_offsets.append(read_offset)
|
||||
decode_ids, read_offset = req.init_incremental_detokenize()
|
||||
decode_ids_list.append(decode_ids)
|
||||
read_offsets.append(read_offset)
|
||||
if self.skip_tokenizer_init:
|
||||
output_ids.append(req.output_ids)
|
||||
output_skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens
|
||||
)
|
||||
output_spaces_between_special_tokens.append(
|
||||
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
|
||||
spaces_between_special_tokens.append(
|
||||
req.sampling_params.spaces_between_special_tokens
|
||||
)
|
||||
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||
no_stop_trim.append(req.sampling_params.no_stop_trim)
|
||||
|
||||
meta_info = {
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
"completion_tokens": len(req.output_ids),
|
||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||
"cached_tokens": req.cached_tokens,
|
||||
"finish_reason": (
|
||||
req.finished_reason.to_json()
|
||||
if req.finished_reason is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
if req.return_logprob:
|
||||
(
|
||||
meta_info["input_token_logprobs"],
|
||||
meta_info["output_token_logprobs"],
|
||||
meta_info["input_top_logprobs"],
|
||||
meta_info["output_top_logprobs"],
|
||||
meta_info["normalized_prompt_logprob"],
|
||||
) = (
|
||||
req.input_token_logprobs,
|
||||
req.output_token_logprobs,
|
||||
req.input_top_logprobs,
|
||||
req.output_top_logprobs,
|
||||
req.normalized_prompt_logprob,
|
||||
)
|
||||
output_meta_info.append(meta_info)
|
||||
else: # embedding or reward model
|
||||
output_embeddings.append(req.embedding)
|
||||
meta_info = {
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
}
|
||||
output_meta_info.append(meta_info)
|
||||
prompt_tokens.append(len(req.origin_input_ids))
|
||||
completion_tokens.append(len(req.output_ids))
|
||||
cached_tokens.append(req.cached_tokens)
|
||||
|
||||
# Send to detokenizer
|
||||
if output_rids:
|
||||
if self.is_generation:
|
||||
if return_logprob:
|
||||
input_token_logprobs_val.append(req.input_token_logprobs_val)
|
||||
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
|
||||
output_token_logprobs_val.append(req.output_token_logprobs_val)
|
||||
output_token_logprobs_idx.append(req.output_token_logprobs_idx)
|
||||
input_top_logprobs_val.append(req.input_top_logprobs_val)
|
||||
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
normalized_prompt_logprob.append(req.normalized_prompt_logprob)
|
||||
|
||||
# Send to detokenizer
|
||||
if rids:
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchTokenIDOut(
|
||||
output_rids,
|
||||
output_vids,
|
||||
rids,
|
||||
finished_reasons,
|
||||
vids,
|
||||
decoded_texts,
|
||||
output_read_ids,
|
||||
output_read_offsets,
|
||||
decode_ids_list,
|
||||
read_offsets,
|
||||
output_ids,
|
||||
output_skip_special_tokens,
|
||||
output_spaces_between_special_tokens,
|
||||
output_meta_info,
|
||||
output_finished_reason,
|
||||
output_no_stop_trim,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchEmbeddingOut(
|
||||
output_rids,
|
||||
output_embeddings,
|
||||
output_meta_info,
|
||||
output_finished_reason,
|
||||
skip_special_tokens,
|
||||
spaces_between_special_tokens,
|
||||
no_stop_trim,
|
||||
prompt_tokens,
|
||||
completion_tokens,
|
||||
cached_tokens,
|
||||
input_token_logprobs_val,
|
||||
input_token_logprobs_idx,
|
||||
output_token_logprobs_val,
|
||||
output_token_logprobs_idx,
|
||||
input_top_logprobs_val,
|
||||
input_top_logprobs_idx,
|
||||
output_top_logprobs_val,
|
||||
output_top_logprobs_idx,
|
||||
normalized_prompt_logprob,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
embeddings = []
|
||||
prompt_tokens = []
|
||||
for req in reqs:
|
||||
assert req.finished()
|
||||
rids.append(req.rid)
|
||||
finished_reasons.append(req.finished_reason.to_json())
|
||||
embeddings.append(req.embedding)
|
||||
prompt_tokens.append(len(req.origin_input_ids))
|
||||
self.send_to_detokenizer.send_pyobj(
|
||||
BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
|
||||
)
|
||||
|
||||
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
|
||||
# Check if other DP workers have running batches
|
||||
|
||||
Reference in New Issue
Block a user