Simplify stream_output (#2398)
@@ -39,10 +39,12 @@ class LogitsProcessorOutput:
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor = None

# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List = None
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List = None
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
input_top_logprobs_val: List = None
input_top_logprobs_idx: List = None
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
output_top_logprobs_val: List = None
output_top_logprobs_idx: List = None


@dataclasses.dataclass
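The LogitsProcessorOutput fields drop the per-position (logprob, token_id) tuples in favor of two parallel lists, one for values and one for token ids, which is the shape the rest of this change passes around. A minimal sketch of the two layouts (the literal numbers are illustrative only):

# Old layout: one list of (logprob, token_id) tuples per position.
old_top_logprobs = [[(-0.1, 42), (-2.3, 7)], [(-0.5, 11), (-1.9, 99)]]

# New layout: two parallel lists, split into values and indices.
top_logprobs_val = [[-0.1, -2.3], [-0.5, -1.9]]
top_logprobs_idx = [[42, 7], [11, 99]]

# The tuple view can still be recovered when needed:
recovered = [list(zip(v, i)) for v, i in zip(top_logprobs_val, top_logprobs_idx)]
assert recovered == old_top_logprobs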
@@ -125,12 +127,15 @@ class LogitsProcessor(nn.Module):
indices = ret.indices.tolist()

if logits_metadata.forward_mode.is_decode():
output_top_logprobs = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
for i, k in enumerate(logits_metadata.top_logprobs_nums):
output_top_logprobs.append(list(zip(values[i][:k], indices[i][:k])))
return None, output_top_logprobs
output_top_logprobs_val.append(values[i][:k])
output_top_logprobs_idx.append(indices[i][:k])
return None, None, output_top_logprobs_val, output_top_logprobs_idx
else:
input_top_logprobs, output_top_logprobs = [], []
input_top_logprobs_val, input_top_logprobs_idx = [], []
output_top_logprobs_val, output_top_logprobs_idx = [], []

pt = 0
for k, pruned_len in zip(
@@ -138,27 +143,36 @@ class LogitsProcessor(nn.Module):
logits_metadata.extend_logprob_pruned_lens_cpu,
):
if pruned_len <= 0:
input_top_logprobs.append([])
output_top_logprobs.append([])
input_top_logprobs_val.append([])
input_top_logprobs_idx.append([])
output_top_logprobs_val.append([])
output_top_logprobs_idx.append([])
continue

input_top_logprobs.append(
[
list(zip(values[pt + j][:k], indices[pt + j][:k]))
for j in range(pruned_len - 1)
]
input_top_logprobs_val.append(
[values[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs.append(
input_top_logprobs_idx.append(
[indices[pt + j][:k] for j in range(pruned_len - 1)]
)
output_top_logprobs_val.append(
list(
zip(
values[pt + pruned_len - 1][:k],
indices[pt + pruned_len - 1][:k],
)
values[pt + pruned_len - 1][:k],
)
)
output_top_logprobs_idx.append(
list(
indices[pt + pruned_len - 1][:k],
)
)
pt += pruned_len

return input_top_logprobs, output_top_logprobs
return (
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
)
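get_top_logprobs now returns the top-k values and indices as separate per-request lists instead of tuple lists. A rough, self-contained sketch of the decode branch (the random tensor and the plain top_logprobs_nums list stand in for the real logprobs and logits_metadata):

import torch

logprobs = torch.log_softmax(torch.randn(2, 32), dim=-1)  # [#seq, vocab]
top_logprobs_nums = [3, 5]  # per-request k

# Take one top-k over the max k, then slice per request.
max_k = max(top_logprobs_nums)
ret = logprobs.topk(max_k, dim=1)
values = ret.values.tolist()
indices = ret.indices.tolist()

output_top_logprobs_val, output_top_logprobs_idx = [], []
for i, k in enumerate(top_logprobs_nums):
    output_top_logprobs_val.append(values[i][:k])
    output_top_logprobs_idx.append(indices[i][:k])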

def forward(
self,
@@ -193,29 +207,22 @@ class LogitsProcessor(nn.Module):
if not logits_metadata.return_logprob:
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
else:
last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

if logits_metadata.forward_mode.is_decode():
if logits_metadata.return_top_logprob:
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
output_top_logprobs_val, output_top_logprobs_idx = (
self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
)
else:
output_top_logprobs = None
output_top_logprobs_val = output_top_logprobs_idx = None
return LogitsProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=output_top_logprobs,
output_top_logprobs_val=output_top_logprobs_val,
output_top_logprobs_idx=output_top_logprobs_idx,
)
else:
# Slice the requested tokens to compute logprob
@@ -246,11 +253,16 @@ class LogitsProcessor(nn.Module):

# Get the logprob of top-k tokens
if logits_metadata.return_top_logprob:
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
(
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
) = self.get_top_logprobs(all_logprobs, logits_metadata)
else:
input_top_logprobs = output_top_logprobs = None
input_top_logprobs_val = input_top_logprobs_idx = (
output_top_logprobs_val
) = output_top_logprobs_idx = None

# Compute the normalized logprobs for the requested tokens.
# Note that we pad a zero at the end for easy batching.
@@ -273,8 +285,10 @@ class LogitsProcessor(nn.Module):
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_top_logprobs=output_top_logprobs,
input_top_logprobs_val=input_top_logprobs_val,
input_top_logprobs_idx=input_top_logprobs_idx,
output_top_logprobs_val=output_top_logprobs_val,
output_top_logprobs_idx=output_top_logprobs_idx,
)

def _get_logits(
@@ -17,7 +17,7 @@ import dataclasses
import logging
import signal
from collections import OrderedDict
from typing import List, Union
from typing import Dict, List, Union

import psutil
import setproctitle
@@ -76,17 +76,25 @@ class DetokenizerManager:

self.decode_status = LimitedCapacityDict()

def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
if no_stop_trim:
def trim_matched_stop(
self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
if no_stop_trim or not finished_reason:
return output

# Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
pos = output.find(finished_reason.matched)
matched = finished_reason.get("matched", None)
if not matched:
return output

# TODO(lmzheng): handle the case where multiple stop strs are hit

# Trim stop str.
if isinstance(matched, str) and isinstance(output, str):
pos = output.find(matched)
return output[:pos] if pos != -1 else output
if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
output, list
):

# Trim stop token.
if isinstance(matched, int) and isinstance(output, list):
assert len(output) > 0
return output[:-1]
return output
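trim_matched_stop now receives the JSON-style finish-reason dict rather than FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN objects. A standalone sketch of the same trimming logic (the dict shape is inferred from the diff and is only illustrative):

from typing import Dict, List, Union

def trim_matched_stop(output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool):
    # Nothing to trim while the request is unfinished or trimming is disabled.
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    # A matched stop string: cut the text at the first occurrence.
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    # A matched stop token id: drop the last generated token.
    if isinstance(matched, int) and isinstance(output, list):
        assert len(output) > 0
        return output[:-1]
    return output

print(trim_matched_stop("Hello!<eos>", {"matched": "<eos>"}, no_stop_trim=False))  # "Hello!"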
@@ -125,9 +133,9 @@ class DetokenizerManager:
s.decode_ids = recv_obj.decode_ids[i]

read_ids.append(
self.trim_eos(
self.trim_matched_stop(
s.decode_ids[s.surr_offset :],
recv_obj.finished_reason[i],
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
)
@@ -150,7 +158,7 @@ class DetokenizerManager:
for i in range(bs):
s = self.decode_status[recv_obj.rids[i]]
new_text = read_texts[i][len(surr_texts[i]) :]
if recv_obj.finished_reason[i] is None:
if recv_obj.finished_reasons[i] is None:
# Streaming chunk: update the decode status
if len(new_text) > 0 and not new_text.endswith("�"):
s.decoded_text = s.decoded_text + new_text
@@ -161,9 +169,9 @@ class DetokenizerManager:
new_text = find_printable_text(new_text)

output_strs.append(
self.trim_eos(
self.trim_matched_stop(
s.decoded_text + new_text,
recv_obj.finished_reason[i],
recv_obj.finished_reasons[i],
recv_obj.no_stop_trim[i],
)
)
@@ -171,9 +179,20 @@ class DetokenizerManager:
self.send_to_tokenizer.send_pyobj(
BatchStrOut(
rids=recv_obj.rids,
finished_reasons=recv_obj.finished_reasons,
output_strs=output_strs,
meta_info=recv_obj.meta_info,
finished_reason=recv_obj.finished_reason,
prompt_tokens=recv_obj.prompt_tokens,
completion_tokens=recv_obj.completion_tokens,
cached_tokens=recv_obj.cached_tokens,
input_token_logprobs_val=recv_obj.input_token_logprobs_val,
input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
output_token_logprobs_val=recv_obj.output_token_logprobs_val,
output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
input_top_logprobs_val=recv_obj.input_top_logprobs_val,
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
)
)

@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
class BatchTokenIDOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
# The version id to sync decode status with in detokenizer_manager
vids: List[int]
decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
read_offsets: List[int]
# Only used when `--skip-tokenizer-init`
output_ids: Optional[List[int]]
# Detokenization configs
skip_special_tokens: List[bool]
spaces_between_special_tokens: List[bool]
meta_info: List[Dict]
finished_reason: List[BaseFinishReason]
no_stop_trim: List[bool]
# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
output_token_logprobs_val: List[float]
output_token_logprobs_idx: List[int]
input_top_logprobs_val: List[List]
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]


@dataclass
class BatchStrOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[dict]
# The output decoded strings
output_strs: List[str]
# The meta info
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]

# Token counts
prompt_tokens: List[int]
completion_tokens: List[int]
cached_tokens: List[int]
# Logprobs
input_token_logprobs_val: List[float]
input_token_logprobs_idx: List[int]
output_token_logprobs_val: List[float]
output_token_logprobs_idx: List[int]
input_top_logprobs_val: List[List]
input_top_logprobs_idx: List[List]
output_top_logprobs_val: List[List]
output_top_logprobs_idx: List[List]
normalized_prompt_logprob: List[float]


@dataclass
class BatchEmbeddingOut:
# The request id
rids: List[str]
# The finish reason
finished_reasons: List[BaseFinishReason]
# The output embedding
embeddings: List[List[float]]
# The meta info
meta_info: List[Dict]
# The finish reason
finished_reason: List[BaseFinishReason]
# Token counts
prompt_tokens: List[int]


@dataclass
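BatchTokenIDOut and BatchStrOut now carry one flat, index-aligned list per field instead of a per-request meta_info dict. A toy sketch of this structure-of-arrays pattern (the class and field subset below are illustrative, not the real definitions):

import dataclasses
from typing import List, Optional

@dataclasses.dataclass
class ToyBatchStrOut:
    rids: List[str]
    finished_reasons: List[Optional[dict]]
    output_strs: List[str]
    prompt_tokens: List[int]
    completion_tokens: List[int]

batch = ToyBatchStrOut(
    rids=["req-0", "req-1"],
    finished_reasons=[None, {"type": "length"}],
    output_strs=["Hello", "World"],
    prompt_tokens=[12, 7],
    completion_tokens=[3, 16],
)
# Everything for request i is read by index; no nested per-request dict is needed.
i = 1
print(batch.rids[i], batch.finished_reasons[i], batch.completion_tokens[i])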
@@ -200,6 +200,9 @@ class Req:
origin_input_text: str,
origin_input_ids: Tuple[int],
sampling_params: SamplingParams,
return_logprob: bool = False,
top_logprobs_num: int = 0,
stream: bool = False,
origin_input_ids_unpadded: Optional[Tuple[int]] = None,
lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None,
@@ -217,10 +220,11 @@ class Req:
self.output_ids = [] # Each decode stage's output ids
self.fill_ids = None # fill_ids = origin_input_ids + output_ids
self.session_id = session_id
self.input_embeds = input_embeds

# Sampling info
self.sampling_params = sampling_params
self.lora_path = lora_path
self.input_embeds = input_embeds

# Memory pool info
self.req_pool_idx = None
@@ -228,8 +232,8 @@ class Req:
# Check finish
self.tokenizer = None
self.finished_reason = None
self.stream = False
self.to_abort = False
self.stream = stream

# For incremental decoding
# ----- | --------- read_ids -------|
@@ -241,13 +245,9 @@ class Req:
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.decoded_text = ""
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None

# The number of decoded tokens for token usage report. Note that
# this does not include the jump forward tokens.
self.completion_tokens_wo_jump_forward = 0
self.decoded_text = ""

# For multimodal inputs
self.image_inputs: Optional[ImageInputs] = None
@@ -256,22 +256,34 @@ class Req:
self.prefix_indices = []
self.extend_input_len = 0
self.last_node = None

# Chunked prefill
self.is_being_chunked = 0

# For retraction
self.is_retracted = False

# Logprobs (arguments)
self.return_logprob = False
self.return_logprob = return_logprob
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.top_logprobs_num = top_logprobs_num

# Logprobs (return value)
self.normalized_prompt_logprob = None
self.input_token_logprobs = None
self.input_top_logprobs = None
self.output_token_logprobs = []
self.output_top_logprobs = []
self.input_token_logprobs_val = None
self.input_token_logprobs_idx = None
self.input_top_logprobs_val = None
self.input_top_logprobs_idx = None

if return_logprob:
self.output_token_logprobs_val = []
self.output_token_logprobs_idx = []
self.output_top_logprobs_val = []
self.output_top_logprobs_idx = []
else:
self.output_token_logprobs_val = self.output_token_logprobs_idx = (
self.output_top_logprobs_val
) = self.output_top_logprobs_idx = None

# Logprobs (internal values)
# The tokens is prefilled but need to be considered as decode tokens
@@ -295,8 +307,8 @@ class Req:
else:
self.image_inputs.merge(image_inputs)

# whether request reached finished condition
def finished(self) -> bool:
# Whether request reached finished condition
return self.finished_reason is not None

def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
@@ -454,8 +466,10 @@ class Req:
k = k + 1
else:
break
self.output_token_logprobs = self.output_token_logprobs[:k]
self.output_top_logprobs = self.output_top_logprobs[:k]
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
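Because Req.__init__ now accepts return_logprob, top_logprobs_num, and stream, the scheduler can build a request in one step and only allocate the output logprob containers when they will be used. A reduced sketch of that idea (a toy class, not the real Req):

from typing import List, Optional

class ToyReq:
    def __init__(self, rid: str, return_logprob: bool = False, top_logprobs_num: int = 0, stream: bool = False):
        self.rid = rid
        self.return_logprob = return_logprob
        self.top_logprobs_num = top_logprobs_num
        self.stream = stream
        # Output-side logprob containers exist only when they will be filled.
        if return_logprob:
            self.output_token_logprobs_val: Optional[List[float]] = []
            self.output_token_logprobs_idx: Optional[List[int]] = []
        else:
            self.output_token_logprobs_val = self.output_token_logprobs_idx = None

req = ToyReq("req-0", return_logprob=True, top_logprobs_num=5, stream=True)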

@@ -515,6 +515,9 @@ class Scheduler:
recv_req.input_text,
recv_req.input_ids,
recv_req.sampling_params,
return_logprob=recv_req.return_logprob,
top_logprobs_num=recv_req.top_logprobs_num,
stream=recv_req.stream,
lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds,
)
@@ -558,9 +561,6 @@ class Scheduler:
return

# Copy more attributes
req.return_logprob = recv_req.return_logprob
req.top_logprobs_num = recv_req.top_logprobs_num
req.stream = recv_req.stream
req.logprob_start_len = recv_req.logprob_start_len

if req.logprob_start_len == -1:
@@ -982,7 +982,6 @@ class Scheduler:
continue

if req.is_being_chunked <= 0:
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()

@@ -1035,7 +1034,7 @@ class Scheduler:
# being chunked reqs' prefill is not finished
req.is_being_chunked -= 1

self.stream_output(batch.reqs, skip_stream_req)
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids, bid = result
@@ -1065,7 +1064,6 @@ class Scheduler:
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
continue

req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()

@@ -1073,11 +1071,15 @@ class Scheduler:
self.tree_cache.cache_finished_req(req)

if req.return_logprob:
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
req.output_token_logprobs_val.append(next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_id)
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
req.output_top_logprobs_val.append(
logits_output.output_top_logprobs_val[i]
)
req.output_top_logprobs_idx.append(
logits_output.output_top_logprobs_idx[i]
)

if req.grammar is not None:
req.grammar.accept_token(next_token_id)
@@ -1088,7 +1090,7 @@ class Scheduler:
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()

self.stream_output(batch.reqs)
self.stream_output(batch.reqs, batch.return_logprob)

self.token_to_kv_pool.free_group_end()

@@ -1108,9 +1110,8 @@ class Scheduler:
output: LogitsProcessorOutput,
):
"""Attach logprobs to the return values."""
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
req.output_token_logprobs_val.append(output.next_token_logprobs[i])
req.output_token_logprobs_idx.append(next_token_ids[i])

# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
@@ -1118,173 +1119,195 @@ class Scheduler:
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

if req.input_token_logprobs is None:
input_token_logprobs = output.input_token_logprobs[
if req.input_token_logprobs_val is None:
input_token_logprobs_val = output.input_token_logprobs[
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
]
input_token_ids = req.fill_ids[

input_token_logprobs_idx = req.fill_ids[
len(req.fill_ids)
- num_input_logprobs
+ 1 : len(req.fill_ids)
- req.last_update_decode_tokens
]

# Clip the padded hash values from image tokens.
# Otherwise, it will lead to detokenization errors.
input_token_ids = [
input_token_logprobs_idx = [
x if x < self.model_config.vocab_size - 1 else 0
for x in input_token_ids
for x in input_token_logprobs_idx
]

req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))

if (
req.logprob_start_len == 0
): # The first token does not have logprob, pad it.
req.input_token_logprobs = [
(None, req.fill_ids[0])
] + req.input_token_logprobs
input_token_logprobs_val = [None] + input_token_logprobs_val
input_token_logprobs_idx = [req.fill_ids[0]] + input_token_logprobs_idx

req.input_token_logprobs_val = input_token_logprobs_val
req.input_token_logprobs_idx = input_token_logprobs_idx

if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs.extend(
list(
zip(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
],
)
)
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
]
)

if req.top_logprobs_num > 0:
if req.input_top_logprobs is None:
req.input_top_logprobs = output.input_top_logprobs[i]
if req.input_top_logprobs_val is None:
req.input_top_logprobs_val = output.input_top_logprobs_val[i]
req.input_top_logprobs_idx = output.input_top_logprobs_idx[i]
if req.logprob_start_len == 0:
req.input_top_logprobs = [None] + req.input_top_logprobs
req.input_top_logprobs_val = [None] + req.input_top_logprobs_val
req.input_top_logprobs_idx = [None] + req.input_top_logprobs_idx

if req.last_update_decode_tokens != 0:
req.output_top_logprobs.extend(
output.input_top_logprobs[i][-req.last_update_decode_tokens :]
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs.append(output.output_top_logprobs[i])
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_val.append(output.output_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.output_top_logprobs_idx[i])

return num_input_logprobs

def stream_output(self, reqs: List[Req], skip_req: Optional[Req] = None):
def stream_output(
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
):
"""Stream the output to detokenizer."""
output_rids = []
output_meta_info: List[dict] = []
output_finished_reason: List[BaseFinishReason] = []
rids = []
finished_reasons: List[BaseFinishReason] = []

if self.is_generation:
output_vids = []
vids = []
decoded_texts = []
output_read_ids = []
output_read_offsets = []
decode_ids_list = []
read_offsets = []
output_ids = []
output_skip_special_tokens = []
output_spaces_between_special_tokens = []
output_no_stop_trim = []
else: # embedding or reward model
output_embeddings = []
skip_special_tokens = []
spaces_between_special_tokens = []
no_stop_trim = []
prompt_tokens = []
completion_tokens = []
cached_tokens = []

is_stream_iter = self.forward_ct_decode % self.stream_interval == 0
if return_logprob:
input_token_logprobs_val = []
input_token_logprobs_idx = []
output_token_logprobs_val = []
output_token_logprobs_idx = []
input_top_logprobs_val = []
input_top_logprobs_idx = []
output_top_logprobs_val = []
output_top_logprobs_idx = []
normalized_prompt_logprob = []
else:
input_token_logprobs_val = input_token_logprobs_idx = (
output_token_logprobs_val
) = output_token_logprobs_idx = input_top_logprobs_val = (
input_top_logprobs_idx
) = output_top_logprobs_val = output_top_logprobs_idx = (
normalized_prompt_logprob
) = None

for req in reqs:
if req is skip_req:
continue
for req in reqs:
if req is skip_req:
continue

# TODO(lianmin): revisit this for overlap + retract + stream
if req.finished() or (
req.stream and (is_stream_iter or len(req.output_ids) == 1)
):
output_rids.append(req.rid)
output_finished_reason.append(req.finished_reason)
if self.is_generation:
output_vids.append(req.vid)
# TODO(lianmin): revisit this for overlap + retract + stream
if (
req.finished()
# If stream, follow the given stream_interval
or (req.stream and len(req.output_ids) % self.stream_interval == 0)
# If not stream, we still want to output some tokens to get the benefit of incremental decoding.
or (not req.stream and len(req.output_ids) % 50 == 0)
):
rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
vids.append(req.vid)
decoded_texts.append(req.decoded_text)
read_ids, read_offset = req.init_incremental_detokenize()
output_read_ids.append(read_ids)
output_read_offsets.append(read_offset)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
read_offsets.append(read_offset)
if self.skip_tokenizer_init:
output_ids.append(req.output_ids)
output_skip_special_tokens.append(
req.sampling_params.skip_special_tokens
)
output_spaces_between_special_tokens.append(
skip_special_tokens.append(req.sampling_params.skip_special_tokens)
spaces_between_special_tokens.append(
req.sampling_params.spaces_between_special_tokens
)
output_no_stop_trim.append(req.sampling_params.no_stop_trim)
no_stop_trim.append(req.sampling_params.no_stop_trim)

meta_info = {
"prompt_tokens": len(req.origin_input_ids),
"completion_tokens": len(req.output_ids),
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
"cached_tokens": req.cached_tokens,
"finish_reason": (
req.finished_reason.to_json()
if req.finished_reason is not None
else None
),
}
if req.return_logprob:
(
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)
else: # embedding or reward model
output_embeddings.append(req.embedding)
meta_info = {
"prompt_tokens": len(req.origin_input_ids),
}
output_meta_info.append(meta_info)
prompt_tokens.append(len(req.origin_input_ids))
completion_tokens.append(len(req.output_ids))
cached_tokens.append(req.cached_tokens)

# Send to detokenizer
if output_rids:
if self.is_generation:
if return_logprob:
input_token_logprobs_val.append(req.input_token_logprobs_val)
input_token_logprobs_idx.append(req.input_token_logprobs_idx)
output_token_logprobs_val.append(req.output_token_logprobs_val)
output_token_logprobs_idx.append(req.output_token_logprobs_idx)
input_top_logprobs_val.append(req.input_top_logprobs_val)
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
output_top_logprobs_val.append(req.output_top_logprobs_val)
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
normalized_prompt_logprob.append(req.normalized_prompt_logprob)

# Send to detokenizer
if rids:
self.send_to_detokenizer.send_pyobj(
BatchTokenIDOut(
output_rids,
output_vids,
rids,
finished_reasons,
vids,
decoded_texts,
output_read_ids,
output_read_offsets,
decode_ids_list,
read_offsets,
output_ids,
output_skip_special_tokens,
output_spaces_between_special_tokens,
output_meta_info,
output_finished_reason,
output_no_stop_trim,
)
)
else: # embedding or reward model
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(
output_rids,
output_embeddings,
output_meta_info,
output_finished_reason,
skip_special_tokens,
spaces_between_special_tokens,
no_stop_trim,
prompt_tokens,
completion_tokens,
cached_tokens,
input_token_logprobs_val,
input_token_logprobs_idx,
output_token_logprobs_val,
output_token_logprobs_idx,
input_top_logprobs_val,
input_top_logprobs_idx,
output_top_logprobs_val,
output_top_logprobs_idx,
normalized_prompt_logprob,
)
)
else: # embedding or reward model
embeddings = []
prompt_tokens = []
for req in reqs:
assert req.finished()
rids.append(req.rid)
finished_reasons.append(req.finished_reason.to_json())
embeddings.append(req.embedding)
prompt_tokens.append(len(req.origin_input_ids))
self.send_to_detokenizer.send_pyobj(
BatchEmbeddingOut(rids, finished_reasons, embeddings, prompt_tokens)
)

def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
# Check if other DP workers have running batches
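The simplified stream_output collects one flat list per field across the batch and sends a single message per iteration; the logprob lists are only materialized when return_logprob is set. A stripped-down sketch of that control flow (ToyReq and the returned dict are illustrative stand-ins for the real Req and BatchTokenIDOut):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class ToyReq:
    rid: str
    stream: bool = False
    output_ids: List[int] = field(default_factory=list)
    finished_reason: Optional[dict] = None
    output_token_logprobs_val: Optional[List[float]] = None

def toy_stream_output(reqs, return_logprob, stream_interval=8):
    rids, finished_reasons = [], []
    output_token_logprobs_val = [] if return_logprob else None
    for req in reqs:
        # Emit when finished, on the stream interval, or periodically for non-stream requests.
        if req.finished_reason or (req.stream and len(req.output_ids) % stream_interval == 0) \
                or (not req.stream and len(req.output_ids) % 50 == 0):
            rids.append(req.rid)
            finished_reasons.append(req.finished_reason)
            if return_logprob:
                output_token_logprobs_val.append(req.output_token_logprobs_val)
    if rids:
        return {"rids": rids, "finished_reasons": finished_reasons,
                "output_token_logprobs_val": output_token_logprobs_val}
    return None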

@@ -22,7 +22,7 @@ import signal
import sys
import time
import uuid
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import fastapi
import uvloop
@@ -76,6 +76,7 @@ class ReqState:
out_list: List
finished: bool
event: asyncio.Event
obj: Any

# For metrics
created_time: float
@@ -283,7 +284,7 @@ class TokenizerManager:
):
"""Wait for the response of one request."""
event = asyncio.Event()
state = ReqState([], False, event, created_time=created_time)
state = ReqState([], False, event, obj, created_time=created_time)
self.rid_to_state[obj.rid] = state

while True:
@@ -295,15 +296,7 @@ class TokenizerManager:
raise ValueError(f"Abort request {obj.rid}")
continue

if isinstance(obj, GenerateReqInput):
out = self.convert_logprob_style(
state.out_list[-1],
obj.return_logprob,
obj.top_logprobs_num,
obj.return_text_in_logprobs,
)
else: # isinstance(obj, (EmbeddingReqInput,))
out = state.out_list[-1]
out = state.out_list[-1]

state.out_list = []
if state.finished:
@@ -315,7 +308,13 @@ class TokenizerManager:
break

state.event.clear()
yield out

if obj.stream:
yield out
else:
if request is not None and await request.is_disconnected():
self.abort_request(obj.rid)
raise ValueError(f"Abort request {obj.rid}")

async def _handle_batch_request(
self,
@@ -609,29 +608,55 @@ class TokenizerManager:
if state is None:
continue

recv_obj.meta_info[i]["id"] = rid
meta_info = {
"id": rid,
"finish_reason": recv_obj.finished_reasons[i],
"prompt_tokens": recv_obj.prompt_tokens[i],
}

if getattr(state.obj, "return_logprob", False):
self.convert_logprob_style(
meta_info,
state.obj.top_logprobs_num,
state.obj.return_text_in_logprobs,
recv_obj,
i,
)

if isinstance(recv_obj, BatchStrOut):
out_dict = {
"text": recv_obj.output_strs[i],
"meta_info": recv_obj.meta_info[i],
"meta_info": {
**meta_info,
"completion_tokens": recv_obj.completion_tokens[i],
"cached_tokens": recv_obj.cached_tokens[i],
},
}
elif isinstance(recv_obj, BatchTokenIDOut):
out_dict = {
"token_ids": recv_obj.output_ids[i],
"meta_info": recv_obj.meta_info[i],
"meta_info": {
**meta_info,
"completion_tokens": recv_obj.completion_tokens[i],
"cached_tokens": recv_obj.cached_tokens[i],
},
}
else:
assert isinstance(recv_obj, BatchEmbeddingOut)
out_dict = {
"embedding": recv_obj.embeddings[i],
"meta_info": recv_obj.meta_info[i],
"meta_info": meta_info,
}
state.out_list.append(out_dict)
state.finished = recv_obj.finished_reason[i] is not None
state.finished = recv_obj.finished_reasons[i] is not None
state.event.set()

if self.enable_metrics:
completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
completion_tokens = (
recv_obj.completion_tokens[i]
if recv_obj.completion_tokens
else 0
)

if state.first_token_time is None:
state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:

if state.finished:
self.metrics_collector.inc_prompt_tokens(
recv_obj.meta_info[i]["prompt_tokens"]
recv_obj.prompt_tokens[i]
)
self.metrics_collector.inc_generation_tokens(
completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:

def convert_logprob_style(
self,
ret: dict,
return_logprob: bool,
meta_info: dict,
top_logprobs_num: int,
return_text_in_logprobs: bool,
recv_obj: BatchStrOut,
recv_obj_index: int,
):
if return_logprob:
ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
)

if top_logprobs_num > 0:
ret["meta_info"]["input_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["input_top_logprobs"],
return_text_in_logprobs,
)
)
ret["meta_info"]["output_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
)
)
return ret

def detokenize_logprob_tokens(
self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
):
# TODO(lianmin): This should run on DetokenizerManager
if not decode_to_text:
return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

assert self.tokenizer is not None
token_ids = [tid for _, tid in token_logprobs]
token_texts = self.tokenizer.batch_decode(token_ids)
return [
(logprob, token_id, token_text)
for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
recv_obj.input_token_logprobs_val[recv_obj_index],
recv_obj.input_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
recv_obj.output_token_logprobs_val[recv_obj_index],
recv_obj.output_token_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
recv_obj_index
]

def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
if top_logprobs_num > 0:
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.input_top_logprobs_val[recv_obj_index],
recv_obj.input_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)
meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
recv_obj.output_top_logprobs_val[recv_obj_index],
recv_obj.output_top_logprobs_idx[recv_obj_index],
return_text_in_logprobs,
)

def detokenize_logprob_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
if not decode_to_text:
return [
(logprob, token_id, None)
for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
]
else:
assert self.tokenizer is not None
token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))

def detokenize_top_logprobs_tokens(
self,
token_logprobs_val: List[float],
token_logprobs_idx: List[int],
decode_to_text: bool,
):
# TODO: The current implementation only batches the detokenization for top-k tokens per single position.
# We should batch all top-k tokens in all positions.
for i, token_top_logprobs in enumerate(top_logprobs):
if token_top_logprobs:
top_logprobs[i] = self.detokenize_logprob_tokens(
token_top_logprobs, decode_to_text
ret = []
for i in range(len(token_logprobs_val)):
if token_logprobs_val[i]:
ret.append(
self.detokenize_logprob_tokens(
token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
)
)
return top_logprobs
else:
ret.append(None)
return ret


class SignalHandler:
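The detokenize helpers now take the parallel value/index lists directly and zip them, attaching decoded text only when requested. A minimal sketch of the same idea, with a trivial fake tokenizer standing in for the real one (an assumption for illustration):

from typing import List, Optional

class FakeTokenizer:
    def batch_decode(self, ids: List[int]) -> List[str]:
        return [f"<tok{t}>" for t in ids]

def detokenize_logprob_tokens(
    tokenizer: Optional[FakeTokenizer],
    token_logprobs_val: List[float],
    token_logprobs_idx: List[int],
    decode_to_text: bool,
):
    if not decode_to_text:
        return [(lp, tid, None) for lp, tid in zip(token_logprobs_val, token_logprobs_idx)]
    texts = tokenizer.batch_decode(token_logprobs_idx)
    return list(zip(token_logprobs_val, token_logprobs_idx, texts))

print(detokenize_logprob_tokens(FakeTokenizer(), [-0.1, -1.2], [42, 7], True))
# [(-0.1, 42, '<tok42>'), (-1.2, 7, '<tok7>')]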

@@ -400,9 +400,14 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=forward_batch.top_logprobs_nums,
)
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
(
logits_output.output_top_logprobs_val,
logits_output.output_top_logprobs_idx,
) = LogitsProcessor.get_top_logprobs(
next_token_logprobs, logits_metadata
)[1]
)[
2:4
]
else:
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits,

@@ -720,13 +720,13 @@ def run_and_check_memory_leak(

# Clean up everything
kill_process_tree(process.pid)
kill_process_tree(process.pid)
stdout.close()
stderr.close()
if os.path.exists(STDOUT_FILENAME):
os.remove(STDOUT_FILENAME)
if os.path.exists(STDERR_FILENAME):
os.remove(STDERR_FILENAME)
kill_process_tree(process.pid)
t.join()

# Assert success
@@ -734,7 +734,7 @@ def run_and_check_memory_leak(
has_leak = False
has_abort = False
for line in output_lines:
if "The server is fired" in line:
if "Uvicorn running" in line:
has_new_server = True
if "leak" in line:
has_leak = True