Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)
This commit is contained in:
@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
|
||||
}
|
||||
obj = self._generate_http_request(s, data)
|
||||
|
||||
normalized_prompt_logprobs = [
|
||||
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||
]
|
||||
input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
|
||||
output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
|
||||
normalized_prompt_logprobs = [
|
||||
compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
|
||||
for r in obj
|
||||
]
|
||||
|
||||
# Remove extra token if no token healing occurred
|
||||
for i in range(len(input_token_logprobs)):
|
||||
@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
|
||||
def _assert_success(self, res):
|
||||
if res.status_code != 200:
|
||||
raise RuntimeError(res.json())
|
||||
|
||||
|
||||
def compute_normalized_prompt_logprobs(input_logprobs):
|
||||
values = [x[0] for x in input_logprobs if x[0]]
|
||||
return sum(values) / len(values)
|
||||
|
||||
@@ -50,8 +50,6 @@ class LogitsProcessorOutput:
|
||||
next_token_top_logprobs_idx: Optional[List] = None
|
||||
|
||||
## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
|
||||
# The normlaized logprobs of prompts. shape: [#seq]
|
||||
normalized_prompt_logprobs: torch.Tensor = None
|
||||
# The logprobs of input tokens. shape: [#token]
|
||||
input_token_logprobs: torch.Tensor = None
|
||||
# The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
|
||||
@@ -195,8 +193,6 @@ class LogitsProcessor(nn.Module):
|
||||
else:
|
||||
input_top_logprobs_val = input_top_logprobs_idx = None
|
||||
|
||||
# Compute the normalized logprobs for the requested tokens.
|
||||
# Note that we pad a zero at the end for easy batching.
|
||||
input_token_logprobs = input_logprobs[
|
||||
torch.arange(input_logprobs.shape[0], device="cuda"),
|
||||
torch.cat(
|
||||
@@ -206,14 +202,9 @@ class LogitsProcessor(nn.Module):
|
||||
]
|
||||
),
|
||||
]
|
||||
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||
input_token_logprobs,
|
||||
logits_metadata,
|
||||
)
|
||||
|
||||
return LogitsProcessorOutput(
|
||||
next_token_logits=last_logits,
|
||||
normalized_prompt_logprobs=normalized_prompt_logprobs,
|
||||
input_token_logprobs=input_token_logprobs,
|
||||
input_top_logprobs_val=input_top_logprobs_val,
|
||||
input_top_logprobs_idx=input_top_logprobs_idx,
|
||||
@@ -237,8 +228,6 @@ class LogitsProcessor(nn.Module):
|
||||
if self.do_tensor_parallel_all_gather:
|
||||
logits = tensor_model_parallel_all_gather(logits)
|
||||
|
||||
# Compute the normalized logprobs for the requested tokens.
|
||||
# Note that we pad a zero at the end for easy batching.
|
||||
logits = logits[:, : self.config.vocab_size].float()
|
||||
|
||||
if self.final_logit_softcapping:
|
||||
@@ -246,27 +235,6 @@ class LogitsProcessor(nn.Module):
|
||||
|
||||
return logits
|
||||
|
||||
@staticmethod
|
||||
def _get_normalized_prompt_logprobs(
|
||||
input_token_logprobs: torch.Tensor,
|
||||
logits_metadata: LogitsMetadata,
|
||||
):
|
||||
logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
|
||||
pruned_lens = torch.tensor(
|
||||
logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
|
||||
)
|
||||
|
||||
start = torch.zeros_like(pruned_lens)
|
||||
start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
|
||||
end = torch.clamp(
|
||||
start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
|
||||
)
|
||||
sum_logp = (
|
||||
logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
|
||||
)
|
||||
normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
|
||||
return normalized_prompt_logprobs
|
||||
|
||||
@staticmethod
|
||||
def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
|
||||
max_k = max(logits_metadata.top_logprobs_nums)
|
||||
|
||||
@@ -191,7 +191,6 @@ class DetokenizerManager:
|
||||
input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
|
||||
output_top_logprobs_val=recv_obj.output_top_logprobs_val,
|
||||
output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
|
||||
normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -340,7 +340,6 @@ class BatchTokenIDOut:
|
||||
input_top_logprobs_idx: List[List]
|
||||
output_top_logprobs_val: List[List]
|
||||
output_top_logprobs_idx: List[List]
|
||||
normalized_prompt_logprob: List[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -366,7 +365,6 @@ class BatchStrOut:
|
||||
input_top_logprobs_idx: List[List]
|
||||
output_top_logprobs_val: List[List]
|
||||
output_top_logprobs_idx: List[List]
|
||||
normalized_prompt_logprob: List[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -280,7 +280,6 @@ class Req:
|
||||
self.top_logprobs_num = top_logprobs_num
|
||||
|
||||
# Logprobs (return value)
|
||||
self.normalized_prompt_logprob = None
|
||||
self.input_token_logprobs_val = None
|
||||
self.input_token_logprobs_idx = None
|
||||
self.input_top_logprobs_val = None
|
||||
@@ -344,9 +343,6 @@ class Req:
|
||||
max_prefix_len = min(max_prefix_len, input_len - 1)
|
||||
|
||||
if self.return_logprob:
|
||||
if self.normalized_prompt_logprob is None:
|
||||
# Need at least two tokens to compute normalized logprob
|
||||
max_prefix_len = min(max_prefix_len, input_len - 2)
|
||||
max_prefix_len = min(max_prefix_len, self.logprob_start_len)
|
||||
|
||||
max_prefix_len = max(max_prefix_len, 0)
|
||||
|
||||
@@ -433,7 +433,6 @@ class PrefillAdder:
|
||||
or input_tokens <= self.rem_chunk_tokens
|
||||
or (
|
||||
req.return_logprob
|
||||
and req.normalized_prompt_logprob is None
|
||||
and req.logprob_start_len != len(req.origin_input_ids) - 1
|
||||
)
|
||||
):
|
||||
|
||||
@@ -1038,9 +1038,6 @@ class Scheduler:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
|
||||
# Check finish conditions
|
||||
logprob_pt = 0
|
||||
@@ -1188,9 +1185,6 @@ class Scheduler:
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
|
||||
|
||||
if req.normalized_prompt_logprob is None:
|
||||
req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
|
||||
|
||||
if req.input_token_logprobs_val is None:
|
||||
input_token_logprobs_val = output.input_token_logprobs[
|
||||
pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
|
||||
@@ -1288,15 +1282,12 @@ class Scheduler:
|
||||
input_top_logprobs_idx = []
|
||||
output_top_logprobs_val = []
|
||||
output_top_logprobs_idx = []
|
||||
normalized_prompt_logprob = []
|
||||
else:
|
||||
input_token_logprobs_val = input_token_logprobs_idx = (
|
||||
output_token_logprobs_val
|
||||
) = output_token_logprobs_idx = input_top_logprobs_val = (
|
||||
input_top_logprobs_idx
|
||||
) = output_top_logprobs_val = output_top_logprobs_idx = (
|
||||
normalized_prompt_logprob
|
||||
) = None
|
||||
) = output_top_logprobs_val = output_top_logprobs_idx = None
|
||||
|
||||
for req in reqs:
|
||||
if req is skip_req:
|
||||
@@ -1343,7 +1334,6 @@ class Scheduler:
|
||||
input_top_logprobs_idx.append(req.input_top_logprobs_idx)
|
||||
output_top_logprobs_val.append(req.output_top_logprobs_val)
|
||||
output_top_logprobs_idx.append(req.output_top_logprobs_idx)
|
||||
normalized_prompt_logprob.append(req.normalized_prompt_logprob)
|
||||
|
||||
# Send to detokenizer
|
||||
if rids:
|
||||
@@ -1370,7 +1360,6 @@ class Scheduler:
|
||||
input_top_logprobs_idx,
|
||||
output_top_logprobs_val,
|
||||
output_top_logprobs_idx,
|
||||
normalized_prompt_logprob,
|
||||
)
|
||||
)
|
||||
else: # embedding or reward model
|
||||
|
||||
@@ -796,9 +796,6 @@ class TokenizerManager:
|
||||
recv_obj.output_token_logprobs_idx[recv_obj_index],
|
||||
return_text_in_logprobs,
|
||||
)
|
||||
meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
|
||||
recv_obj_index
|
||||
]
|
||||
|
||||
if top_logprobs_num > 0:
|
||||
meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
|
||||
|
||||
@@ -151,11 +151,6 @@ class TpModelWorkerClient:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.to("cpu", non_blocking=True)
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.to(
|
||||
"cpu", non_blocking=True
|
||||
)
|
||||
)
|
||||
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
||||
copy_done.record()
|
||||
|
||||
@@ -174,9 +169,6 @@ class TpModelWorkerClient:
|
||||
logits_output.input_token_logprobs = (
|
||||
logits_output.input_token_logprobs.tolist()
|
||||
)
|
||||
logits_output.normalized_prompt_logprobs = (
|
||||
logits_output.normalized_prompt_logprobs.tolist()
|
||||
)
|
||||
next_token_ids = next_token_ids.tolist()
|
||||
return logits_output, next_token_ids
|
||||
|
||||
|
||||
@@ -535,7 +535,7 @@ def test_hellaswag_select():
|
||||
|
||||
# Compute accuracy
|
||||
accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.01
|
||||
assert np.abs(accuracy_gen - accuracy) < 0.05
|
||||
assert np.abs(latency_gen - latency) < 1
|
||||
|
||||
return accuracy, latency
|
||||
|
||||
Reference in New Issue
Block a user