Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs (#776)

This commit is contained in:
Lianmin Zheng
2024-07-27 19:50:34 -07:00
committed by GitHub
parent 0a409bd438
commit 30db99b3d9
16 changed files with 188 additions and 184 deletions

View File

@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
decision = choices[np.argmax(normalized_prompt_logprobs)]
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
return (
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
input_token_logprobs,
output_token_logprobs,
)
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):

View File

@@ -541,16 +541,16 @@ class StreamExecutor:
(
decision,
normalized_prompt_logprobs,
prefill_token_logprobs,
decode_token_logprobs,
input_token_logprobs,
output_token_logprobs,
) = self.backend.select(self, expr.choices, expr.temperature)
if expr.name is not None:
name = expr.name
self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprobs": normalized_prompt_logprobs,
"prefill_token_logprobs": prefill_token_logprobs,
"decode_token_logprobs": decode_token_logprobs,
"input_token_logprobs": input_token_logprobs,
"output_token_logprobs": output_token_logprobs,
}
self.variable_event[name].set()
self.text_ += decision

View File

@@ -22,13 +22,13 @@ class LogitProcessorOutput:
# The normalized logprobs of prompts. shape: [#seq]
normalized_prompt_logprobs: torch.Tensor
# The logprobs of prefill tokens. shape: [#token, vocab_size]
prefill_token_logprobs: torch.Tensor
# The logprobs of input tokens. shape: [#token, vocab_size]
input_token_logprobs: torch.Tensor
# The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
prefill_top_logprobs: List
# The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
decode_top_logprobs: List
# The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
input_top_logprobs: List
# The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
output_top_logprobs: List
@dataclasses.dataclass
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
def _get_normalized_prompt_logprobs(
self, prefill_token_logprobs, logits_metadata: LogitsMetadata
self, input_token_logprobs, logits_metadata: LogitsMetadata
):
logprobs_cumsum = torch.cumsum(
prefill_token_logprobs, dim=0, dtype=torch.float32
)
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = logits_metadata.extend_start_loc.clone()
end = start + logits_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
sum_logp = (
logprobs_cumsum[end]
- logprobs_cumsum[start]
+ prefill_token_logprobs[start]
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
)
normalized_prompt_logprobs = sum_logp / (
(logits_metadata.extend_seq_lens - 1).clamp(min=1)
@@ -83,34 +79,34 @@ class LogitsProcessor(nn.Module):
def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
# TODO: vectorize the code below
if logits_metadata.forward_mode == ForwardMode.DECODE:
decode_top_logprobs = []
output_top_logprobs = []
for i in range(all_logprobs.shape[0]):
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[i].topk(k)
v_cpu = t.values.tolist()
p_cpu = t.indices.tolist()
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, decode_top_logprobs
output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
return None, output_top_logprobs
else:
prefill_top_logprobs, decode_top_logprobs = [], []
input_top_logprobs, output_top_logprobs = [], []
pt = 0
extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
if extend_seq_len == 0:
prefill_top_logprobs.append([])
decode_top_logprobs.append([])
input_top_logprobs.append([])
output_top_logprobs.append([])
continue
k = logits_metadata.top_logprobs_nums[i]
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
vs_cpu = t.values.tolist()
ps_cpu = t.indices.tolist()
prefill_top_logprobs.append(
input_top_logprobs.append(
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
)
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
pt += extend_seq_len
return prefill_top_logprobs, decode_top_logprobs
return input_top_logprobs, output_top_logprobs
def forward(
self,
@@ -150,9 +146,9 @@ class LogitsProcessor(nn.Module):
next_token_logits=last_logits,
next_token_logprobs=None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
else:
# When logprob is requested, compute the logits for all tokens.
@@ -164,19 +160,19 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
decode_top_logprobs = self.get_top_logprobs(
output_top_logprobs = self.get_top_logprobs(
last_logprobs, logits_metadata
)[1]
else:
decode_top_logprobs = None
output_top_logprobs = None
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=output_top_logprobs,
)
else:
all_logits = torch.matmul(hidden_states, weight.T)
@@ -193,32 +189,32 @@ class LogitsProcessor(nn.Module):
x > 0 for x in logits_metadata.top_logprobs_nums
)
if return_top_logprob:
prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs(
input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
all_logprobs, logits_metadata
)
else:
prefill_top_logprobs = decode_top_logprobs = None
input_top_logprobs = output_top_logprobs = None
last_logprobs = all_logprobs[last_index]
# Compute the logprobs and normalized logprobs for the prefill tokens.
# Note that we pad a zero at the end of each sequence for easy computation.
prefill_token_logprobs = all_logprobs[
input_token_logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
]
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
prefill_token_logprobs, logits_metadata
input_token_logprobs, logits_metadata
)
return LogitProcessorOutput(
next_token_logits=last_logits,
next_token_logprobs=last_logprobs,
normalized_prompt_logprobs=normalized_prompt_logprobs,
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_top_logprobs=decode_top_logprobs,
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_top_logprobs=output_top_logprobs,
)

View File

@@ -226,9 +226,9 @@ class CudaGraphRunner:
next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
)
# Extract logprobs
@@ -242,7 +242,7 @@ class CudaGraphRunner:
forward_mode=ForwardMode.DECODE,
top_logprobs_nums=batch.top_logprobs_nums,
)
output.decode_top_logprobs = LogitsProcessor.get_top_logprobs(
output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
output.next_token_logprobs, logits_metadata
)[1]

View File

@@ -124,10 +124,10 @@ class Req:
self.logprob_start_len = 0
self.top_logprobs_num = 0
self.normalized_prompt_logprob = None
self.prefill_token_logprobs = None
self.prefill_top_logprobs = None
self.decode_token_logprobs = []
self.decode_top_logprobs = []
self.input_token_logprobs = None
self.input_top_logprobs = None
self.output_token_logprobs = []
self.output_top_logprobs = []
# The tokens is prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
@@ -244,8 +244,8 @@ class Req:
k = k + 1
else:
break
self.decode_token_logprobs = self.decode_token_logprobs[:k]
self.decode_top_logprobs = self.decode_top_logprobs[:k]
self.output_token_logprobs = self.output_token_logprobs[:k]
self.output_top_logprobs = self.output_top_logprobs[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k

View File

@@ -455,7 +455,7 @@ class ModelTpServer:
torch.arange(len(next_token_ids), device=next_token_ids.device),
next_token_ids,
].tolist()
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
output.input_token_logprobs = output.input_token_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
@@ -481,24 +481,24 @@ class ModelTpServer:
if req.normalized_prompt_logprob is None:
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
if req.prefill_token_logprobs is None:
if req.input_token_logprobs is None:
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
req.prefill_token_logprobs = list(
req.input_token_logprobs = list(
zip(
output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
req.input_ids[-req.extend_input_len + 1 :],
)
)
if req.logprob_start_len == 0:
req.prefill_token_logprobs = [
req.input_token_logprobs = [
(None, req.input_ids[0])
] + req.prefill_token_logprobs
] + req.input_token_logprobs
if req.last_update_decode_tokens != 0:
req.decode_token_logprobs.extend(
req.output_token_logprobs.extend(
list(
zip(
output.prefill_token_logprobs[
output.input_token_logprobs[
pt
+ req.extend_input_len
- req.last_update_decode_tokens : pt
@@ -510,21 +510,21 @@ class ModelTpServer:
)
)
req.decode_token_logprobs.append(
req.output_token_logprobs.append(
(output.next_token_logprobs[i], next_token_ids[i])
)
if req.top_logprobs_num > 0:
if req.prefill_top_logprobs is None:
req.prefill_top_logprobs = output.prefill_top_logprobs[i]
if req.input_top_logprobs is None:
req.input_top_logprobs = output.input_top_logprobs[i]
if req.logprob_start_len == 0:
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
req.input_top_logprobs = [None] + req.input_top_logprobs
if req.last_update_decode_tokens != 0:
req.decode_top_logprobs.extend(
output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
req.output_top_logprobs.extend(
output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
)
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
req.output_top_logprobs.append(output.output_top_logprobs[i])
def cache_filled_batch(self, batch: Batch):
req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
@@ -589,11 +589,11 @@ class ModelTpServer:
req.check_finished()
if req.return_logprob:
req.decode_token_logprobs.append(
req.output_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
req.output_top_logprobs.append(output.output_top_logprobs[i])
self.handle_finished_requests(batch)
@@ -645,16 +645,16 @@ class ModelTpServer:
}
if req.return_logprob:
(
meta_info["prefill_token_logprobs"],
meta_info["decode_token_logprobs"],
meta_info["prefill_top_logprobs"],
meta_info["decode_top_logprobs"],
meta_info["input_token_logprobs"],
meta_info["output_token_logprobs"],
meta_info["input_top_logprobs"],
meta_info["output_top_logprobs"],
meta_info["normalized_prompt_logprob"],
) = (
req.prefill_token_logprobs,
req.decode_token_logprobs,
req.prefill_top_logprobs,
req.decode_top_logprobs,
req.input_token_logprobs,
req.output_token_logprobs,
req.input_top_logprobs,
req.output_top_logprobs,
req.normalized_prompt_logprob,
)
output_meta_info.append(meta_info)

View File

@@ -20,7 +20,7 @@ class GenerateReqInput:
# The image input. It can be a file name, a url, or base64 encoded string.
# See also python/sglang/srt/utils.py:load_image.
image_data: Optional[Union[List[str], str]] = None
# The sampling_params.
# The sampling_params. See descriptions below.
sampling_params: Union[List[Dict], Dict] = None
# The request id.
rid: Optional[Union[List[str], str]] = None
@@ -30,7 +30,7 @@ class GenerateReqInput:
logprob_start_len: Optional[Union[List[int], int]] = None
# The number of top logprobs to return.
top_logprobs_num: Optional[Union[List[int], int]] = None
# Whether to detokenize tokens in logprobs.
# Whether to detokenize tokens in text in the returned logprobs.
return_text_in_logprobs: bool = False
# Whether to stream output.
stream: bool = False

View File

@@ -448,23 +448,23 @@ class TokenizerManager:
return_text_in_logprobs: bool,
):
if return_logprob:
ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
)
ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"]["prefill_top_logprobs"] = (
ret["meta_info"]["input_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"],
ret["meta_info"]["input_top_logprobs"],
return_text_in_logprobs,
)
)
ret["meta_info"]["decode_top_logprobs"] = (
ret["meta_info"]["output_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
)
)
return ret

View File

@@ -54,9 +54,9 @@ class LlamaForClassification(nn.Module):
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
prefill_token_logprobs=torch.ones_like(input_ids),
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=torch.ones_like(input_ids),
input_top_logprobs=None,
output_top_logprobs=None,
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

View File

@@ -140,29 +140,29 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if request.logprobs:
# The first chunk and echo is enabled.
if not stream_buffer and request.echo:
prefill_token_logprobs = content["meta_info"][
"prefill_token_logprobs"
input_token_logprobs = content["meta_info"][
"input_token_logprobs"
]
prefill_top_logprobs = content["meta_info"][
"prefill_top_logprobs"
input_top_logprobs = content["meta_info"][
"input_top_logprobs"
]
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=content["meta_info"][
"decode_token_logprobs"
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=content["meta_info"][
"output_token_logprobs"
][n_prev_token:],
decode_top_logprobs=content["meta_info"][
"decode_top_logprobs"
output_top_logprobs=content["meta_info"][
"output_top_logprobs"
][n_prev_token:],
)
n_prev_token = len(
content["meta_info"]["decode_token_logprobs"]
content["meta_info"]["output_token_logprobs"]
)
else:
logprobs = None
@@ -218,17 +218,17 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
if request.logprobs:
if request.echo:
prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"]
prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"]
input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"]
input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"]
else:
prefill_token_logprobs = None
prefill_top_logprobs = None
input_token_logprobs = None
input_top_logprobs = None
logprobs = to_openai_style_logprobs(
prefill_token_logprobs=prefill_token_logprobs,
prefill_top_logprobs=prefill_top_logprobs,
decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"],
decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"],
input_token_logprobs=input_token_logprobs,
input_top_logprobs=input_top_logprobs,
output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
)
else:
logprobs = None
@@ -401,10 +401,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
def to_openai_style_logprobs(
prefill_token_logprobs=None,
decode_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=None,
input_token_logprobs=None,
output_token_logprobs=None,
input_top_logprobs=None,
output_top_logprobs=None,
):
ret_logprobs = LogProbs()
@@ -425,13 +425,13 @@ def to_openai_style_logprobs(
else:
ret_logprobs.top_logprobs.append(None)
if prefill_token_logprobs is not None:
append_token_logprobs(prefill_token_logprobs)
if decode_token_logprobs is not None:
append_token_logprobs(decode_token_logprobs)
if prefill_top_logprobs is not None:
append_top_logprobs(prefill_top_logprobs)
if decode_top_logprobs is not None:
append_top_logprobs(decode_top_logprobs)
if input_token_logprobs is not None:
append_token_logprobs(input_token_logprobs)
if output_token_logprobs is not None:
append_token_logprobs(output_token_logprobs)
if input_top_logprobs is not None:
append_top_logprobs(input_top_logprobs)
if output_top_logprobs is not None:
append_top_logprobs(output_top_logprobs)
return ret_logprobs