Remove normalized_prompt_logprobs from the engine to make code easier to maintain (#2902)
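For downstream users, a hedged before/after sketch of the client-side change implied by this commit (field names are taken from the hunks below; the helper name here is illustrative and the exact response schema may vary by server version):

    # Before: the engine returned the value directly.
    #     normalized = response["meta_info"]["normalized_prompt_logprob"]
    # After: derive it from the per-token prompt logprobs that are still returned
    # when return_logprob is set.
    def normalized_prompt_logprob_from_meta_info(meta_info: dict) -> float:
        # Each input_token_logprobs entry is a tuple whose first element is the
        # logprob (None for the first prompt token), mirroring the helper added below.
        values = [x[0] for x in meta_info["input_token_logprobs"] if x[0]]
        return sum(values) / len(values)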
@@ -251,11 +251,12 @@ class RuntimeEndpoint(BaseBackend):
         }
         obj = self._generate_http_request(s, data)
 
-        normalized_prompt_logprobs = [
-            r["meta_info"]["normalized_prompt_logprob"] for r in obj
-        ]
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]
+        normalized_prompt_logprobs = [
+            compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"])
+            for r in obj
+        ]
 
         # Remove extra token if no token healing occurred
         for i in range(len(input_token_logprobs)):
@@ -319,3 +320,8 @@ class RuntimeEndpoint(BaseBackend):
     def _assert_success(self, res):
         if res.status_code != 200:
             raise RuntimeError(res.json())
+
+
+def compute_normalized_prompt_logprobs(input_logprobs):
+    values = [x[0] for x in input_logprobs if x[0]]
+    return sum(values) / len(values)
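A worked example of the compute_normalized_prompt_logprobs helper added above (the numbers and token ids are made up; only the first tuple element is used, and the first prompt token's logprob is None, so it is skipped):

    example_input_token_logprobs = [
        (None, 101),   # first prompt token: no logprob
        (-0.5, 2009),
        (-1.5, 3011),
    ]
    # (-0.5 + -1.5) / 2 == -1.0
    assert compute_normalized_prompt_logprobs(example_input_token_logprobs) == -1.0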
@@ -50,8 +50,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -195,8 +193,6 @@ class LogitsProcessor(nn.Module):
             else:
                 input_top_logprobs_val = input_top_logprobs_idx = None
 
-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
             input_token_logprobs = input_logprobs[
                 torch.arange(input_logprobs.shape[0], device="cuda"),
                 torch.cat(
@@ -206,14 +202,9 @@ class LogitsProcessor(nn.Module):
                     ]
                 ),
             ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )
 
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
                 input_token_logprobs=input_token_logprobs,
                 input_top_logprobs_val=input_top_logprobs_val,
                 input_top_logprobs_idx=input_top_logprobs_idx,
@@ -237,8 +228,6 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)
 
-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
@@ -246,27 +235,6 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
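For reference, a minimal unbatched sketch (not part of the diff) of what the removed _get_normalized_prompt_logprobs computed: a per-sequence mean of the prompt-token logprobs, excluding the final padded position mentioned in the removed comments. It assumes input_token_logprobs is the concatenation of every sequence's pruned logprobs and ignores the clamping edge cases for very short prompts:

    import torch

    def normalized_prompt_logprobs_reference(
        input_token_logprobs: torch.Tensor, pruned_lens: list
    ) -> torch.Tensor:
        out, start = [], 0
        for plen in pruned_lens:
            # Average the first (plen - 1) logprobs of this sequence; the last
            # position is the zero padding added for easy batching.
            chunk = input_token_logprobs[start : start + plen - 1]
            out.append(chunk.sum() / max(plen - 1, 1))
            start += plen
        return torch.stack(out)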
@@ -191,7 +191,6 @@ class DetokenizerManager:
                     input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
                     output_top_logprobs_val=recv_obj.output_top_logprobs_val,
                     output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
-                    normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
                 )
             )
 
@@ -340,7 +340,6 @@ class BatchTokenIDOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -366,7 +365,6 @@ class BatchStrOut:
     input_top_logprobs_idx: List[List]
     output_top_logprobs_val: List[List]
     output_top_logprobs_idx: List[List]
-    normalized_prompt_logprob: List[float]
 
 
 @dataclass
@@ -280,7 +280,6 @@ class Req:
         self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
-        self.normalized_prompt_logprob = None
         self.input_token_logprobs_val = None
         self.input_token_logprobs_idx = None
         self.input_top_logprobs_val = None
@@ -344,9 +343,6 @@ class Req:
         max_prefix_len = min(max_prefix_len, input_len - 1)
 
         if self.return_logprob:
-            if self.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                max_prefix_len = min(max_prefix_len, input_len - 2)
             max_prefix_len = min(max_prefix_len, self.logprob_start_len)
 
         max_prefix_len = max(max_prefix_len, 0)
@@ -433,7 +433,6 @@ class PrefillAdder:
             or input_tokens <= self.rem_chunk_tokens
             or (
                 req.return_logprob
-                and req.normalized_prompt_logprob is None
                 and req.logprob_start_len != len(req.origin_input_ids) - 1
             )
         ):
@@ -1038,9 +1038,6 @@ class Scheduler:
                     logits_output.input_token_logprobs = (
                         logits_output.input_token_logprobs.tolist()
                    )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
-                    )
 
             # Check finish conditions
             logprob_pt = 0
@@ -1188,9 +1185,6 @@ class Scheduler:
         # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
         num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
 
-        if req.normalized_prompt_logprob is None:
-            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
-
         if req.input_token_logprobs_val is None:
             input_token_logprobs_val = output.input_token_logprobs[
                 pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
@@ -1288,15 +1282,12 @@ class Scheduler:
                 input_top_logprobs_idx = []
                 output_top_logprobs_val = []
                 output_top_logprobs_idx = []
-                normalized_prompt_logprob = []
             else:
                 input_token_logprobs_val = input_token_logprobs_idx = (
                     output_token_logprobs_val
                 ) = output_token_logprobs_idx = input_top_logprobs_val = (
                     input_top_logprobs_idx
-                ) = output_top_logprobs_val = output_top_logprobs_idx = (
-                    normalized_prompt_logprob
-                ) = None
+                ) = output_top_logprobs_val = output_top_logprobs_idx = None
 
             for req in reqs:
                 if req is skip_req:
@@ -1343,7 +1334,6 @@ class Scheduler:
                     input_top_logprobs_idx.append(req.input_top_logprobs_idx)
                     output_top_logprobs_val.append(req.output_top_logprobs_val)
                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)
-                    normalized_prompt_logprob.append(req.normalized_prompt_logprob)
 
             # Send to detokenizer
             if rids:
@@ -1370,7 +1360,6 @@ class Scheduler:
                         input_top_logprobs_idx,
                         output_top_logprobs_val,
                         output_top_logprobs_idx,
-                        normalized_prompt_logprob,
                     )
                 )
         else: # embedding or reward model
@@ -796,9 +796,6 @@ class TokenizerManager:
                 recv_obj.output_token_logprobs_idx[recv_obj_index],
                 return_text_in_logprobs,
             )
-            meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
-                recv_obj_index
-            ]
 
             if top_logprobs_num > 0:
                 meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
@@ -151,11 +151,6 @@ class TpModelWorkerClient:
            logits_output.input_token_logprobs = (
                logits_output.input_token_logprobs.to("cpu", non_blocking=True)
            )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.to(
-                    "cpu", non_blocking=True
-                )
-            )
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_done.record()
 
@@ -174,9 +169,6 @@ class TpModelWorkerClient:
            logits_output.input_token_logprobs = (
                logits_output.input_token_logprobs.tolist()
            )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.tolist()
-            )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
 
@@ -535,7 +535,7 @@ def test_hellaswag_select():
 
     # Compute accuracy
     accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels))
-    assert np.abs(accuracy_gen - accuracy) < 0.01
+    assert np.abs(accuracy_gen - accuracy) < 0.05
     assert np.abs(latency_gen - latency) < 1
 
     return accuracy, latency
@@ -1,85 +0,0 @@
-"""
-Usage:
-python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache
-
-python3 test_httpserver_classify.py
-"""
-
-import argparse
-
-import numpy as np
-import requests
-
-
-def get_logits_deprecated(url: str, prompt: str):
-    response = requests.post(
-        url + "/generate",
-        json={
-            "text": prompt,
-            "sampling_params": {
-                "max_new_tokens": 0,
-            },
-            "return_logprob": True,
-        },
-    )
-    return response.json()["meta_info"]["normalized_prompt_logprob"]
-
-
-def get_logits_batch_deprecated(url: str, prompts: list[str]):
-    response = requests.post(
-        url + "/generate",
-        json={
-            "text": prompts,
-            "sampling_params": {
-                "max_new_tokens": 0,
-            },
-            "return_logprob": True,
-        },
-    )
-    ret = response.json()
-    logits = np.array(
-        list(
-            ret[i]["meta_info"]["normalized_prompt_logprob"]
-            for i in range(len(prompts))
-        )
-    )
-    return logits
-
-
-def get_logits(url: str, prompt: str):
-    response = requests.post(
-        url + "/classify",
-        json={"text": prompt},
-    )
-    return response.json()["embedding"]
-
-
-def get_logits_batch(url: str, prompts: list[str]):
-    response = requests.post(
-        url + "/classify",
-        json={"text": prompts},
-    )
-    return np.array([x["embedding"] for x in response.json()])
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host", type=str, default="http://127.0.0.1")
-    parser.add_argument("--port", type=int, default=30000)
-    args = parser.parse_args()
-
-    url = f"{args.host}:{args.port}"
-
-    # A single request
-    prompt = "This is a test prompt.<|eot_id|>"
-    logits = get_logits(url, prompt)
-    print(f"{logits=}")
-
-    # A batch of requests
-    prompts = [
-        "This is a test prompt.<|eot_id|>",
-        "This is another test prompt.<|eot_id|>",
-        "This is a long long long long test prompt.<|eot_id|>",
-    ]
-    logits = get_logits_batch(url, prompts)
-    print(f"{logits=}")
@@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num):
         if return_logprob:
             assert data["meta_info"]["input_token_logprobs"] is not None
             assert data["meta_info"]["output_token_logprobs"] is not None
-            assert data["meta_info"]["normalized_prompt_logprob"] is not None
             for logprob, token_id, token_text in data["meta_info"][
                 "output_token_logprobs"
             ][prev:]: