Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)
```diff
@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
             r["meta_info"]["normalized_prompt_logprob"] for r in obj
         ]
         decision = choices[np.argmax(normalized_prompt_logprobs)]
-        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
-        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
+        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

         return (
             decision,
             normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
         )

     def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
```

```diff
@@ -541,16 +541,16 @@ class StreamExecutor:
         (
             decision,
             normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
         ) = self.backend.select(self, expr.choices, expr.temperature)
         if expr.name is not None:
             name = expr.name
             self.variables[name] = decision
             self.meta_info[name] = {
                 "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "prefill_token_logprobs": prefill_token_logprobs,
-                "decode_token_logprobs": decode_token_logprobs,
+                "input_token_logprobs": input_token_logprobs,
+                "output_token_logprobs": output_token_logprobs,
             }
             self.variable_event[name].set()
         self.text_ += decision
```

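At the frontend API level, the rename surfaces through the meta info recorded for each `select` expression. A usage sketch, assuming a running SGLang server behind a `RuntimeEndpoint` (the endpoint URL and prompt are illustrative):

```python
import sglang as sgl

sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

@sgl.function
def yes_no(s, question):
    s += question
    s += sgl.select("answer", choices=["Yes", "No"])

state = yes_no.run(question="Is the sky blue? ")
info = state.get_meta_info("answer")
print(info["input_token_logprobs"])   # was: prefill_token_logprobs
print(info["output_token_logprobs"])  # was: decode_token_logprobs
```
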
```diff
@@ -22,13 +22,13 @@ class LogitProcessorOutput:

     # The normalized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of prefill tokens. shape: [#token, vocab_size]
-    prefill_token_logprobs: torch.Tensor
+    # The logprobs of input tokens. shape: [#token, vocab_size]
+    input_token_logprobs: torch.Tensor

-    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    prefill_top_logprobs: List
-    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    decode_top_logprobs: List
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    input_top_logprobs: List
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    output_top_logprobs: List


 @dataclasses.dataclass
```

```diff
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()

     def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+        self, input_token_logprobs, logits_metadata: LogitsMetadata
     ):
-        logprobs_cumsum = torch.cumsum(
-            prefill_token_logprobs, dim=0, dtype=torch.float32
-        )
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

         start = logits_metadata.extend_start_loc.clone()
         end = start + logits_metadata.extend_seq_lens - 2
-        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
-        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
         sum_logp = (
-            logprobs_cumsum[end]
-            - logprobs_cumsum[start]
-            + prefill_token_logprobs[start]
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
         )
         normalized_prompt_logprobs = sum_logp / (
             (logits_metadata.extend_seq_lens - 1).clamp(min=1)
```

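The cumulative sum above averages each prompt's token logprobs in one vectorized pass over the flattened batch. A minimal, self-contained sketch of the same arithmetic with made-up numbers (two sequences of lengths 3 and 4 packed back to back):

```python
import torch

# Token logprobs of two packed sequences; position i scores token i + 1,
# so only the first len - 1 positions of each sequence score prompt tokens.
token_logprobs = torch.tensor([-0.5, -1.0, -2.0, -0.1, -0.2, -0.3, -0.4])
seq_lens = torch.tensor([3, 4])
start = torch.tensor([0, 3])      # analogue of extend_start_loc
end = start + seq_lens - 2        # last position that scores a prompt token

cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
sum_logp = cumsum[end] - cumsum[start] + token_logprobs[start]
normalized = sum_logp / (seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -0.2000])
```
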
```diff
@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
     def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
         # TODO: vectorize the code below
         if logits_metadata.forward_mode == ForwardMode.DECODE:
-            decode_top_logprobs = []
+            output_top_logprobs = []
             for i in range(all_logprobs.shape[0]):
                 k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[i].topk(k)
                 v_cpu = t.values.tolist()
                 p_cpu = t.indices.tolist()
-                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
-            return None, decode_top_logprobs
+                output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, output_top_logprobs
         else:
-            prefill_top_logprobs, decode_top_logprobs = [], []
+            input_top_logprobs, output_top_logprobs = [], []
             pt = 0
             extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
             for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                 if extend_seq_len == 0:
-                    prefill_top_logprobs.append([])
-                    decode_top_logprobs.append([])
+                    input_top_logprobs.append([])
+                    output_top_logprobs.append([])
                     continue
                 k = logits_metadata.top_logprobs_nums[i]
                 t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                 vs_cpu = t.values.tolist()
                 ps_cpu = t.indices.tolist()
-                prefill_top_logprobs.append(
+                input_top_logprobs.append(
                     [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                 )
-                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                 pt += extend_seq_len

-        return prefill_top_logprobs, decode_top_logprobs
+        return input_top_logprobs, output_top_logprobs

     def forward(
         self,
```

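In the extend branch, each sequence's full-vocabulary logprobs go through one `topk` call; the first `len - 1` rows describe input positions and the last row describes the first output token. A toy run, assuming a single sequence of 4 positions over a vocabulary of 5:

```python
import torch

all_logprobs = torch.randn(4, 5).log_softmax(dim=-1)
t = all_logprobs.topk(2)                    # k = 2 alternatives per position
vs, ps = t.values.tolist(), t.indices.tolist()
input_top = [list(zip(vs[j], ps[j])) for j in range(len(vs) - 1)]
output_top = list(zip(vs[-1], ps[-1]))
print(input_top)   # per input position: [(logprob, token_id), ...]
print(output_top)  # top-k for the next output token
```
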
```diff
@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=None,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )
         else:
             # When logprob is requested, compute the logits for all tokens.
```

```diff
@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
                 x > 0 for x in logits_metadata.top_logprobs_nums
             )
             if return_top_logprob:
-                decode_top_logprobs = self.get_top_logprobs(
+                output_top_logprobs = self.get_top_logprobs(
                     last_logprobs, logits_metadata
                 )[1]
             else:
-                decode_top_logprobs = None
+                output_top_logprobs = None

             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=decode_top_logprobs,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=output_top_logprobs,
             )
         else:
             all_logits = torch.matmul(hidden_states, weight.T)
```

```diff
@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
                 x > 0 for x in logits_metadata.top_logprobs_nums
             )
             if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
+                input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
                     all_logprobs, logits_metadata
                 )
             else:
-                prefill_top_logprobs = decode_top_logprobs = None
+                input_top_logprobs = output_top_logprobs = None

             last_logprobs = all_logprobs[last_index]

             # Compute the logprobs and normalized logprobs for the prefill tokens.
             # Note that we pad a zero at the end of each sequence for easy computation.
-            prefill_token_logprobs = all_logprobs[
+            input_token_logprobs = all_logprobs[
                 torch.arange(all_logprobs.shape[0], device="cuda"),
                 torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
             ]

             normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                prefill_token_logprobs, logits_metadata
+                input_token_logprobs, logits_metadata
             )

             return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,
-                prefill_token_logprobs=prefill_token_logprobs,
-                prefill_top_logprobs=prefill_top_logprobs,
-                decode_top_logprobs=decode_top_logprobs,
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_top_logprobs=output_top_logprobs,
             )
```

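The gather above relies on the shift-by-one alignment between logits and targets: row i of `all_logprobs` scores the token at position i + 1, so the target column is `input_ids` shifted left by one with a dummy 0 padded at the end (the padded entry is never used by the normalization, which stops after `len - 1` positions). A CPU sketch of the same indexing:

```python
import torch

n, vocab = 4, 10
all_logprobs = torch.randn(n, vocab).log_softmax(dim=-1)
input_ids = torch.tensor([5, 2, 7, 1])

# Pair row i with token i + 1; pad a 0 so the shapes line up.
targets = torch.cat([input_ids[1:], torch.tensor([0])])
token_logprobs = all_logprobs[torch.arange(n), targets]
print(token_logprobs)  # logprob assigned to each next prompt token
```
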
```diff
@@ -226,9 +226,9 @@ class CudaGraphRunner:
                 next_token_logits=output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=None,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
             )

         # Extract logprobs
```

```diff
@@ -242,7 +242,7 @@ class CudaGraphRunner:
                 forward_mode=ForwardMode.DECODE,
                 top_logprobs_nums=batch.top_logprobs_nums,
             )
-            output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
+            output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
                 output.next_token_logprobs, logits_metadata
             )[1]

```

```diff
@@ -124,10 +124,10 @@ class Req:
         self.logprob_start_len = 0
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
-        self.prefill_token_logprobs = None
-        self.prefill_top_logprobs = None
-        self.decode_token_logprobs = []
-        self.decode_top_logprobs = []
+        self.input_token_logprobs = None
+        self.input_top_logprobs = None
+        self.output_token_logprobs = []
+        self.output_top_logprobs = []
         # These tokens are prefilled but need to be considered as decode tokens
         # and should be updated for the decode logprobs
         self.last_update_decode_tokens = 0
```

```diff
@@ -244,8 +244,8 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.decode_token_logprobs = self.decode_token_logprobs[:k]
-        self.decode_top_logprobs = self.decode_top_logprobs[:k]
+        self.output_token_logprobs = self.output_token_logprobs[:k]
+        self.output_top_logprobs = self.output_top_logprobs[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k

```

```diff
@@ -455,7 +455,7 @@ class ModelTpServer:
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
-            output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+            output.input_token_logprobs = output.input_token_logprobs.tolist()
             output.normalized_prompt_logprobs = (
                 output.normalized_prompt_logprobs.tolist()
             )
```

```diff
@@ -481,24 +481,24 @@ class ModelTpServer:
             if req.normalized_prompt_logprob is None:
                 req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-            if req.prefill_token_logprobs is None:
+            if req.input_token_logprobs is None:
                 # If logprob_start_len > 0, the first logprob_start_len prompt tokens will be ignored.
-                req.prefill_token_logprobs = list(
+                req.input_token_logprobs = list(
                     zip(
-                        output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                        output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                         req.input_ids[-req.extend_input_len + 1 :],
                     )
                 )
                 if req.logprob_start_len == 0:
-                    req.prefill_token_logprobs = [
+                    req.input_token_logprobs = [
                         (None, req.input_ids[0])
-                    ] + req.prefill_token_logprobs
+                    ] + req.input_token_logprobs

             if req.last_update_decode_tokens != 0:
-                req.decode_token_logprobs.extend(
+                req.output_token_logprobs.extend(
                     list(
                         zip(
-                            output.prefill_token_logprobs[
+                            output.input_token_logprobs[
                                 pt
                                 + req.extend_input_len
                                 - req.last_update_decode_tokens : pt
```

```diff
@@ -510,21 +510,21 @@ class ModelTpServer:
                     )
                 )

-            req.decode_token_logprobs.append(
+            req.output_token_logprobs.append(
                 (output.next_token_logprobs[i], next_token_ids[i])
             )

             if req.top_logprobs_num > 0:
-                if req.prefill_top_logprobs is None:
-                    req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.input_top_logprobs is None:
+                    req.input_top_logprobs = output.input_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                    req.input_top_logprobs = [None] + req.input_top_logprobs

                 if req.last_update_decode_tokens != 0:
-                    req.decode_top_logprobs.extend(
-                        output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    req.output_top_logprobs.extend(
+                        output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                     )
-                req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+                req.output_top_logprobs.append(output.output_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
```

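The bookkeeping in these two hunks pairs each flattened logprob with the token id it scores and prepends a `(None, first_token)` sentinel, since the first prompt token is never scored. A toy version for a single request, with host-side values made up for illustration:

```python
input_ids = [101, 42, 7, 9]
flat_logprobs = [-0.3, -1.2, -0.7]  # scores for tokens 42, 7, 9

input_token_logprobs = [(None, input_ids[0])] + list(
    zip(flat_logprobs, input_ids[1:])
)
print(input_token_logprobs)
# [(None, 101), (-0.3, 42), (-1.2, 7), (-0.7, 9)]
```
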
```diff
@@ -589,11 +589,11 @@ class ModelTpServer:
            req.check_finished()

            if req.return_logprob:
-                req.decode_token_logprobs.append(
+                req.output_token_logprobs.append(
                    (next_token_logprobs[i], next_token_id)
                )
                if req.top_logprobs_num > 0:
-                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])

        self.handle_finished_requests(batch)
```

```diff
@@ -645,16 +645,16 @@ class ModelTpServer:
                 }
                 if req.return_logprob:
                     (
-                        meta_info["prefill_token_logprobs"],
-                        meta_info["decode_token_logprobs"],
-                        meta_info["prefill_top_logprobs"],
-                        meta_info["decode_top_logprobs"],
+                        meta_info["input_token_logprobs"],
+                        meta_info["output_token_logprobs"],
+                        meta_info["input_top_logprobs"],
+                        meta_info["output_top_logprobs"],
                         meta_info["normalized_prompt_logprob"],
                     ) = (
-                        req.prefill_token_logprobs,
-                        req.decode_token_logprobs,
-                        req.prefill_top_logprobs,
-                        req.decode_top_logprobs,
+                        req.input_token_logprobs,
+                        req.output_token_logprobs,
+                        req.input_top_logprobs,
+                        req.output_top_logprobs,
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
```

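After this hunk, the logprob-related keys a client sees in `meta_info` are the renamed ones. A sketch of the resulting shape (all values made up; entries are `(logprob, token_id)` pairs, with token text appended only when `return_text_in_logprobs` is set):

```python
meta_info = {
    "input_token_logprobs": [(None, 101), (-0.3, 42)],
    "output_token_logprobs": [(-0.1, 13), (-0.9, 7)],
    "input_top_logprobs": [None, [(-0.3, 42), (-1.1, 99)]],
    "output_top_logprobs": [[(-0.1, 13), (-2.0, 5)], [(-0.9, 7), (-1.2, 8)]],
    "normalized_prompt_logprob": -0.3,
}
print(meta_info["output_token_logprobs"])
```
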
```diff
@@ -20,7 +20,7 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params.
+    # The sampling_params. See descriptions below.
     sampling_params: Union[List[Dict], Dict] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
```

```diff
@@ -30,7 +30,7 @@ class GenerateReqInput:
     logprob_start_len: Optional[Union[List[int], int]] = None
     # The number of top logprobs to return.
    top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs.
+    # Whether to detokenize tokens into text in the returned logprobs.
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
```

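These fields are exercised through the server's `/generate` route. A request sketch, assuming a server launched at `http://localhost:30000` (the prompt and sampling parameters are illustrative):

```python
import requests

resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        "return_logprob": True,
        "top_logprobs_num": 2,
        "return_text_in_logprobs": True,
    },
)
meta = resp.json()["meta_info"]
print(meta["output_token_logprobs"])  # was: decode_token_logprobs
```
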
```diff
@@ -448,23 +448,23 @@ class TokenizerManager:
         return_text_in_logprobs: bool,
     ):
         if return_logprob:
-            ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
             )
-            ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
             )

             if top_logprobs_num > 0:
-                ret["meta_info"]["prefill_top_logprobs"] = (
+                ret["meta_info"]["input_top_logprobs"] = (
                     self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["prefill_top_logprobs"],
+                        ret["meta_info"]["input_top_logprobs"],
                         return_text_in_logprobs,
                     )
                 )
-                ret["meta_info"]["decode_top_logprobs"] = (
+                ret["meta_info"]["output_top_logprobs"] = (
                     self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                     )
                 )
         return ret
```

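The detokenization helpers themselves are not part of this diff; below is a plausible sketch of what `detokenize_logprob_tokens` does, assuming a HuggingFace-style tokenizer (the exact return layout is an assumption, not taken from the source):

```python
def detokenize_logprob_tokens(tokenizer, token_logprobs, decode_to_text):
    # Without text decoding, keep the (logprob, token_id) pairs as-is.
    if not decode_to_text:
        return token_logprobs
    token_texts = tokenizer.batch_decode(
        [[token_id] for _, token_id in token_logprobs]
    )
    return [
        (logprob, token_id, text)
        for (logprob, token_id), text in zip(token_logprobs, token_texts)
    ]
```
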
```diff
@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module):
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
-            prefill_token_logprobs=torch.ones_like(input_ids),
-            prefill_top_logprobs=None,
-            decode_top_logprobs=None,
+            input_token_logprobs=torch.ones_like(input_ids),
+            input_top_logprobs=None,
+            output_top_logprobs=None,
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
```

```diff
@@ -140,29 +140,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     if request.logprobs:
                         # The first chunk, and echo is enabled.
                         if not stream_buffer and request.echo:
-                            prefill_token_logprobs = content["meta_info"][
-                                "prefill_token_logprobs"
+                            input_token_logprobs = content["meta_info"][
+                                "input_token_logprobs"
                             ]
-                            prefill_top_logprobs = content["meta_info"][
-                                "prefill_top_logprobs"
+                            input_top_logprobs = content["meta_info"][
+                                "input_top_logprobs"
                             ]
                         else:
-                            prefill_token_logprobs = None
-                            prefill_top_logprobs = None
+                            input_token_logprobs = None
+                            input_top_logprobs = None

                         logprobs = to_openai_style_logprobs(
-                            prefill_token_logprobs=prefill_token_logprobs,
-                            prefill_top_logprobs=prefill_top_logprobs,
-                            decode_token_logprobs=content["meta_info"][
-                                "decode_token_logprobs"
+                            input_token_logprobs=input_token_logprobs,
+                            input_top_logprobs=input_top_logprobs,
+                            output_token_logprobs=content["meta_info"][
+                                "output_token_logprobs"
                             ][n_prev_token:],
-                            decode_top_logprobs=content["meta_info"][
-                                "decode_top_logprobs"
+                            output_top_logprobs=content["meta_info"][
+                                "output_top_logprobs"
                             ][n_prev_token:],
                         )

                         n_prev_token = len(
-                            content["meta_info"]["decode_token_logprobs"]
+                            content["meta_info"]["output_token_logprobs"]
                         )
                     else:
                         logprobs = None
```

```diff
@@ -218,17 +218,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

         if request.logprobs:
             if request.echo:
-                prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
-                prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
+                input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
+                input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
             else:
-                prefill_token_logprobs = None
-                prefill_top_logprobs = None
+                input_token_logprobs = None
+                input_top_logprobs = None

             logprobs = to_openai_style_logprobs(
-                prefill_token_logprobs=prefill_token_logprobs,
-                prefill_top_logprobs=prefill_top_logprobs,
-                decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
-                decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs=input_top_logprobs,
+                output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
+                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )
         else:
             logprobs = None
```

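From the client's perspective nothing changes in the OpenAI-compatible API; the renamed fields only feed `to_openai_style_logprobs`. A usage sketch against SGLang's completions endpoint, assuming a server at `localhost:30000`:

```python
import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="default",
    prompt="The capital of France is",
    max_tokens=8,
    echo=True,   # include prompt (input) logprobs
    logprobs=2,  # top-2 alternatives per token
)
print(resp.choices[0].logprobs.token_logprobs)
```
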
```diff
@@ -401,10 +401,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):


 def to_openai_style_logprobs(
-    prefill_token_logprobs=None,
-    decode_token_logprobs=None,
-    prefill_top_logprobs=None,
-    decode_top_logprobs=None,
+    input_token_logprobs=None,
+    output_token_logprobs=None,
+    input_top_logprobs=None,
+    output_top_logprobs=None,
 ):
     ret_logprobs = LogProbs()

```

```diff
@@ -425,13 +425,13 @@ def to_openai_style_logprobs(
         else:
             ret_logprobs.top_logprobs.append(None)

-    if prefill_token_logprobs is not None:
-        append_token_logprobs(prefill_token_logprobs)
-    if decode_token_logprobs is not None:
-        append_token_logprobs(decode_token_logprobs)
-    if prefill_top_logprobs is not None:
-        append_top_logprobs(prefill_top_logprobs)
-    if decode_top_logprobs is not None:
-        append_top_logprobs(decode_top_logprobs)
+    if input_token_logprobs is not None:
+        append_token_logprobs(input_token_logprobs)
+    if output_token_logprobs is not None:
+        append_token_logprobs(output_token_logprobs)
+    if input_top_logprobs is not None:
+        append_top_logprobs(input_top_logprobs)
+    if output_top_logprobs is not None:
+        append_top_logprobs(output_top_logprobs)

     return ret_logprobs
```

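A toy illustration of the conversion, assuming the OpenAI completions `LogProbs` layout (parallel `tokens` / `token_logprobs` lists, with `None` marking the unscored first prompt token; the triples below assume text was requested):

```python
input_token_logprobs = [(None, 101, "The"), (-0.3, 42, " capital")]
output_token_logprobs = [(-0.1, 13, " Paris")]

tokens, token_logprobs = [], []
for logprob, _token_id, text in input_token_logprobs + output_token_logprobs:
    tokens.append(text)
    token_logprobs.append(logprob)
print(tokens)          # ['The', ' capital', ' Paris']
print(token_logprobs)  # [None, -0.3, -0.1]
```
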