Logprobs Refactor (#331)
This commit is contained in:
@@ -14,10 +14,14 @@ class GenerateReqInput:
|
|||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
# The request id
|
# The request id
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
# Whether return logprobs of the prompts
|
# Whether to return logprobs
|
||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# The start location of the prompt for return_logprob
|
# The start location of the prompt for return_logprob
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
|
# The number of top logprobs to return
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
|
# Whether to detokenize tokens in logprobs
|
||||||
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ Usage:
|
|||||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
python choices_logprob.py
|
python choices_logprob.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
@@ -19,9 +20,9 @@ def main():
|
|||||||
print("questions:", question)
|
print("questions:", question)
|
||||||
print("choice:", state["tool"])
|
print("choice:", state["tool"])
|
||||||
meta_info = state.get_meta_info("tool")
|
meta_info = state.get_meta_info("tool")
|
||||||
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
|
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
|
||||||
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
|
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
|
||||||
print('-' * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
# Run a batch
|
# Run a batch
|
||||||
questions = [
|
questions = [
|
||||||
@@ -33,9 +34,9 @@ def main():
|
|||||||
print("questions:", question)
|
print("questions:", question)
|
||||||
print("choice:", state["tool"])
|
print("choice:", state["tool"])
|
||||||
meta_info = state.get_meta_info("tool")
|
meta_info = state.get_meta_info("tool")
|
||||||
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
|
print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0])
|
||||||
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
|
print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1])
|
||||||
print('-' * 50)
|
print("-" * 50)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
"sampling_params": {"max_new_tokens": 0},
|
"sampling_params": {"max_new_tokens": 0},
|
||||||
"return_logprob": True,
|
"return_logprob": True,
|
||||||
"logprob_start_len": max(prompt_len - 2, 0),
|
"logprob_start_len": max(prompt_len - 2, 0),
|
||||||
|
"return_text_in_logprobs": True,
|
||||||
}
|
}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
res = http_request(
|
res = http_request(
|
||||||
@@ -224,13 +225,19 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
)
|
)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
obj = res.json()
|
obj = res.json()
|
||||||
normalized_prompt_logprob = [
|
normalized_prompt_logprobs = [
|
||||||
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||||
]
|
]
|
||||||
prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]
|
decision = choices[np.argmax(normalized_prompt_logprobs)]
|
||||||
|
prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
|
||||||
|
decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
|
||||||
|
|
||||||
decision = choices[np.argmax(normalized_prompt_logprob)]
|
return (
|
||||||
return decision, normalized_prompt_logprob, prompt_logprob
|
decision,
|
||||||
|
normalized_prompt_logprobs,
|
||||||
|
prefill_token_logprobs,
|
||||||
|
decode_token_logprobs,
|
||||||
|
)
|
||||||
|
|
||||||
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
||||||
res = http_request(
|
res = http_request(
|
||||||
|
|||||||
@@ -454,15 +454,19 @@ class StreamExecutor:
|
|||||||
self.stream_var_event[name].set()
|
self.stream_var_event[name].set()
|
||||||
|
|
||||||
def _execute_select(self, expr: SglSelect):
|
def _execute_select(self, expr: SglSelect):
|
||||||
decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
|
(
|
||||||
self, expr.choices, expr.temperature
|
decision,
|
||||||
)
|
normalized_prompt_logprobs,
|
||||||
|
prefill_token_logprobs,
|
||||||
|
decode_token_logprobs,
|
||||||
|
) = self.backend.select(self, expr.choices, expr.temperature)
|
||||||
if expr.name is not None:
|
if expr.name is not None:
|
||||||
name = expr.name
|
name = expr.name
|
||||||
self.variables[name] = decision
|
self.variables[name] = decision
|
||||||
self.meta_info[name] = {
|
self.meta_info[name] = {
|
||||||
"normalized_prompt_logprob": normalized_prompt_logprob,
|
"normalized_prompt_logprobs": normalized_prompt_logprobs,
|
||||||
"prompt_logprob": prompt_logprob,
|
"prefill_token_logprobs": prefill_token_logprobs,
|
||||||
|
"decode_token_logprobs": decode_token_logprobs,
|
||||||
}
|
}
|
||||||
self.variable_event[name].set()
|
self.variable_event[name].set()
|
||||||
self.text_ += decision
|
self.text_ += decision
|
||||||
|
|||||||
@@ -13,76 +13,127 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
def _get_normalized_prompt_logprobs(
|
||||||
last_index = None
|
self, prefill_token_logprobs, input_metadata: InputMetadata
|
||||||
|
):
|
||||||
|
logprobs_cumsum = torch.cumsum(
|
||||||
|
prefill_token_logprobs, dim=0, dtype=torch.float32
|
||||||
|
)
|
||||||
|
|
||||||
# Compute the last index (the first decode token) of each request
|
start = input_metadata.extend_start_loc.clone()
|
||||||
# if we are in prefill or extend mode.
|
end = start + input_metadata.extend_seq_lens - 2
|
||||||
|
start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||||
|
end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
|
||||||
|
sum_logp = (
|
||||||
|
logprobs_cumsum[end]
|
||||||
|
- logprobs_cumsum[start]
|
||||||
|
+ prefill_token_logprobs[start]
|
||||||
|
)
|
||||||
|
normalized_prompt_logprobs = sum_logp / (
|
||||||
|
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return normalized_prompt_logprobs
|
||||||
|
|
||||||
|
def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
|
||||||
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
|
decode_top_logprobs = []
|
||||||
|
for i in range(all_logprobs.shape[0]):
|
||||||
|
k = input_metadata.top_logprobs_nums[i]
|
||||||
|
t = all_logprobs[i].topk(k)
|
||||||
|
v_cpu = t.values.cpu().tolist()
|
||||||
|
p_cpu = t.indices.cpu().tolist()
|
||||||
|
decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
|
||||||
|
return None, decode_top_logprobs
|
||||||
|
else:
|
||||||
|
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||||
|
pt = 0
|
||||||
|
# NOTE: the GPU-CPU overhead can be reduced
|
||||||
|
extend_seq_lens_cpu = input_metadata.extend_seq_lens
|
||||||
|
for i in range(len(input_metadata.extend_seq_lens)):
|
||||||
|
if extend_seq_lens_cpu[i] == 0:
|
||||||
|
continue
|
||||||
|
k = input_metadata.top_logprobs_nums[i]
|
||||||
|
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
||||||
|
vs_cpu = t.values.cpu().tolist()
|
||||||
|
ps_cpu = t.indices.cpu().tolist()
|
||||||
|
prefill_top_logprobs.append(
|
||||||
|
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||||
|
)
|
||||||
|
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||||
|
return prefill_top_logprobs, decode_top_logprobs
|
||||||
|
|
||||||
|
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||||
|
# Get last index for next token prediction, except for DECODE mode.
|
||||||
|
last_index = None
|
||||||
if input_metadata.forward_mode != ForwardMode.DECODE:
|
if input_metadata.forward_mode != ForwardMode.DECODE:
|
||||||
last_index = (
|
last_index = (
|
||||||
torch.cumsum(
|
torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
|
||||||
input_metadata.seq_lens - input_metadata.prefix_lens,
|
|
||||||
dim=0,
|
|
||||||
dtype=torch.long,
|
|
||||||
)
|
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
|
|
||||||
if not input_metadata.return_logprob:
|
# Get the last hidden states and last logits
|
||||||
# When logprob is not requested, only compute the last logits.
|
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
last_hidden = hidden_states
|
last_hidden = hidden_states
|
||||||
else:
|
else:
|
||||||
last_hidden = hidden_states[last_index]
|
last_hidden = hidden_states[last_index]
|
||||||
hidden_states = None
|
|
||||||
|
|
||||||
last_logits = torch.matmul(last_hidden, weight.T)
|
last_logits = torch.matmul(last_hidden, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
last_logits = last_logits[:, : self.config.vocab_size]
|
last_logits = last_logits[:, : self.config.vocab_size]
|
||||||
return last_logits, (None, None, None)
|
|
||||||
|
# Return only last_logits if logprob is not requested
|
||||||
|
if not input_metadata.return_logprob:
|
||||||
|
hidden_states = None
|
||||||
|
return last_logits, (None, None, None, None, None)
|
||||||
else:
|
else:
|
||||||
# When logprob is requested, compute the logits for all tokens.
|
# When logprob is requested, compute the logits for all tokens.
|
||||||
logits = torch.matmul(hidden_states, weight.T)
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
|
all_logits = last_logits
|
||||||
|
else:
|
||||||
|
all_logits = torch.matmul(hidden_states, weight.T)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
logits = tensor_model_parallel_all_gather(logits)
|
all_logits = tensor_model_parallel_all_gather(all_logits)
|
||||||
logits = logits[:, : self.config.vocab_size]
|
all_logits = all_logits[:, : self.config.vocab_size]
|
||||||
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
|
||||||
|
all_logprobs = torch.log(torch.softmax(all_logits.float(), dim=-1) + 1e-6)
|
||||||
|
|
||||||
|
prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
|
||||||
|
all_logprobs, input_metadata
|
||||||
|
)
|
||||||
|
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
last_logits = logits
|
|
||||||
last_logprobs = all_logprobs
|
last_logprobs = all_logprobs
|
||||||
prefill_logprobs = normalized_logprobs = None
|
return last_logits, (
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
decode_top_logprobs,
|
||||||
|
None,
|
||||||
|
last_logprobs,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Compute the logprobs for the last token of each request.
|
# Compute the logprobs for the last token of each request.
|
||||||
last_logits = logits[last_index]
|
|
||||||
last_logprobs = all_logprobs[last_index]
|
last_logprobs = all_logprobs[last_index]
|
||||||
|
|
||||||
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
# Compute the logprobs and normalized logprobs for the prefill tokens.
|
||||||
# Note that we pad a zero at the end of each sequence for easy computation.
|
# Note that we pad a zero at the end of each sequence for easy computation.
|
||||||
prefill_logprobs = all_logprobs[
|
prefill_token_logprobs = all_logprobs[
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||||
]
|
]
|
||||||
logprobs_cumsum = torch.cumsum(
|
|
||||||
prefill_logprobs, dim=0, dtype=torch.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
start = input_metadata.extend_start_loc.clone()
|
normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
|
||||||
end = start + input_metadata.extend_seq_lens - 2
|
prefill_token_logprobs, input_metadata
|
||||||
start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
|
||||||
end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
|
|
||||||
sum_logp = (
|
|
||||||
logprobs_cumsum[end]
|
|
||||||
- logprobs_cumsum[start]
|
|
||||||
+ prefill_logprobs[start]
|
|
||||||
)
|
)
|
||||||
normalized_logprobs = sum_logp / (
|
return last_logits, (
|
||||||
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
prefill_token_logprobs,
|
||||||
|
prefill_top_logprobs,
|
||||||
|
decode_top_logprobs,
|
||||||
|
normalized_prompt_logprobs,
|
||||||
|
last_logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
all_logprobs = torch.tensor(
|
all_logprobs = torch.tensor(
|
||||||
@@ -93,23 +144,22 @@ if __name__ == "__main__":
|
|||||||
)
|
)
|
||||||
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
|
seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
|
||||||
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
|
||||||
logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")
|
|
||||||
|
|
||||||
logprobs = all_logprobs[
|
token_logprobs = all_logprobs[
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||||
]
|
]
|
||||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
|
||||||
|
|
||||||
len_cumsum = torch.cumsum(seq_lens, dim=0)
|
len_cumsum = torch.cumsum(seq_lens, dim=0)
|
||||||
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
|
start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
|
||||||
end = start + seq_lens - 2
|
end = start + seq_lens - 2
|
||||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
|
||||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
|
||||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
|
||||||
|
|
||||||
# assert logprobs == [2, _, 2, 4, _]
|
# assert logprobs == [2, _, 2, 4, _]
|
||||||
print("logprobs", logprobs)
|
print("token logprobs", token_logprobs)
|
||||||
print("start", start)
|
print("start", start)
|
||||||
print("end", end)
|
print("end", end)
|
||||||
print("sum_logp", sum_logp)
|
print("sum_logp", sum_logp)
|
||||||
|
|||||||
@@ -19,10 +19,13 @@ class GenerateReqInput:
|
|||||||
return_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
# The start location of the prompt for return_logprob
|
# The start location of the prompt for return_logprob
|
||||||
logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
|
# The number of top logprobs to return
|
||||||
|
top_logprobs_num: Optional[Union[List[int], int]] = None
|
||||||
# Whether to detokenize tokens in logprobs
|
# Whether to detokenize tokens in logprobs
|
||||||
return_text_in_logprobs: bool = False
|
return_text_in_logprobs: bool = False
|
||||||
# Whether to stream output
|
# Whether to stream output
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
|
# TODO: make all parameters a Union[List[T], T] to allow for batched requests
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
is_single = isinstance(self.text, str)
|
is_single = isinstance(self.text, str)
|
||||||
@@ -36,6 +39,8 @@ class GenerateReqInput:
|
|||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
if self.logprob_start_len is None:
|
if self.logprob_start_len is None:
|
||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
|
if self.top_logprobs_num is None:
|
||||||
|
self.top_logprobs_num = 0
|
||||||
else:
|
else:
|
||||||
num = len(self.text)
|
num = len(self.text)
|
||||||
|
|
||||||
@@ -64,6 +69,11 @@ class GenerateReqInput:
|
|||||||
elif not isinstance(self.logprob_start_len, list):
|
elif not isinstance(self.logprob_start_len, list):
|
||||||
self.logprob_start_len = [self.logprob_start_len] * num
|
self.logprob_start_len = [self.logprob_start_len] * num
|
||||||
|
|
||||||
|
if self.top_logprobs_num is None:
|
||||||
|
self.top_logprobs_num = [0] * num
|
||||||
|
elif not isinstance(self.top_logprobs_num, list):
|
||||||
|
self.top_logprobs_num = [self.top_logprobs_num] * num
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TokenizedGenerateReqInput:
|
class TokenizedGenerateReqInput:
|
||||||
@@ -76,6 +86,7 @@ class TokenizedGenerateReqInput:
|
|||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
return_logprob: bool
|
return_logprob: bool
|
||||||
logprob_start_len: int
|
logprob_start_len: int
|
||||||
|
top_logprobs_num: int
|
||||||
stream: bool
|
stream: bool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class Req:
|
|||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
self.logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
|
self.top_logprobs_num = 0
|
||||||
self.stream = False
|
self.stream = False
|
||||||
|
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
@@ -54,9 +55,11 @@ class Req:
|
|||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
self.logprob = None
|
self.prefill_token_logprobs = None
|
||||||
self.token_logprob = None
|
self.decode_token_logprobs = None
|
||||||
self.normalized_logprob = None
|
self.normalized_prompt_logprob = None
|
||||||
|
self.prefill_top_logprobs = None
|
||||||
|
self.decode_top_logprobs = None
|
||||||
|
|
||||||
# For constrained decoding
|
# For constrained decoding
|
||||||
self.regex_fsm = None
|
self.regex_fsm = None
|
||||||
@@ -159,6 +162,9 @@ class Batch:
|
|||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
out_cache_cont_start: torch.Tensor = None
|
out_cache_cont_start: torch.Tensor = None
|
||||||
out_cache_cont_end: torch.Tensor = None
|
out_cache_cont_end: torch.Tensor = None
|
||||||
|
|
||||||
|
# for processing logprobs
|
||||||
|
top_logprobs_nums: List[int] = None
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
|
||||||
# for multimodal
|
# for multimodal
|
||||||
@@ -266,6 +272,7 @@ class Batch:
|
|||||||
self.position_ids_offsets = position_ids_offsets
|
self.position_ids_offsets = position_ids_offsets
|
||||||
self.extend_num_tokens = extend_num_tokens
|
self.extend_num_tokens = extend_num_tokens
|
||||||
self.out_cache_loc = out_cache_loc
|
self.out_cache_loc = out_cache_loc
|
||||||
|
self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
|
||||||
|
|
||||||
self.temperatures = torch.tensor(
|
self.temperatures = torch.tensor(
|
||||||
[r.sampling_params.temperature for r in reqs],
|
[r.sampling_params.temperature for r in reqs],
|
||||||
@@ -415,6 +422,7 @@ class Batch:
|
|||||||
self.prefix_lens = None
|
self.prefix_lens = None
|
||||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||||
|
self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
@@ -439,6 +447,7 @@ class Batch:
|
|||||||
[self.position_ids_offsets, other.position_ids_offsets]
|
[self.position_ids_offsets, other.position_ids_offsets]
|
||||||
)
|
)
|
||||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||||
|
self.top_logprobs_nums.extend(other.top_logprobs_nums)
|
||||||
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
|
|||||||
@@ -260,6 +260,7 @@ class ModelRpcServer:
|
|||||||
req.sampling_params = recv_req.sampling_params
|
req.sampling_params = recv_req.sampling_params
|
||||||
req.return_logprob = recv_req.return_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
req.logprob_start_len = recv_req.logprob_start_len
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
|
req.top_logprobs_num = recv_req.top_logprobs_num
|
||||||
req.stream = recv_req.stream
|
req.stream = recv_req.stream
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
@@ -400,28 +401,36 @@ class ModelRpcServer:
|
|||||||
self.model_config.vocab_size, self.int_token_logit_bias
|
self.model_config.vocab_size, self.int_token_logit_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
logprobs = None
|
prefill_token_logprobs = None
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
# Forward
|
||||||
logits, (
|
logits, (
|
||||||
prefill_logprobs,
|
prefill_token_logprobs,
|
||||||
normalized_logprobs,
|
prefill_top_logprobs,
|
||||||
|
decode_top_logprobs,
|
||||||
|
normalized_prompt_logprobs,
|
||||||
last_logprobs,
|
last_logprobs,
|
||||||
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
) = self.model_runner.forward(batch, ForwardMode.EXTEND)
|
||||||
if prefill_logprobs is not None:
|
if prefill_token_logprobs is not None:
|
||||||
logprobs = prefill_logprobs.cpu().tolist()
|
prefill_token_logprobs = prefill_token_logprobs.cpu().tolist()
|
||||||
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
normalized_prompt_logprobs = normalized_prompt_logprobs.cpu().tolist()
|
||||||
|
|
||||||
next_token_ids, _ = batch.sample(logits)
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
else:
|
else:
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
logits = logprobs = normalized_logprobs = last_logprobs = None
|
(
|
||||||
|
logits,
|
||||||
|
prefill_token_logprobs,
|
||||||
|
normalized_prompt_logprobs,
|
||||||
|
last_logprobs,
|
||||||
|
) = (None,) * 4
|
||||||
|
|
||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
|
last_token_logprobs = None
|
||||||
if last_logprobs is not None:
|
if last_logprobs is not None:
|
||||||
last_logprobs = (
|
last_token_logprobs = (
|
||||||
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -432,18 +441,26 @@ class ModelRpcServer:
|
|||||||
req.output_ids = [next_token_ids[i]]
|
req.output_ids = [next_token_ids[i]]
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if logprobs is not None:
|
if prefill_token_logprobs is not None:
|
||||||
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
|
# If logprob_start_len > 0, then the first logprob_start_len prompt tokens will be ignored.
|
||||||
req.normalized_logprob = normalized_logprobs[i]
|
req.prefill_token_logprobs = list(
|
||||||
|
zip(
|
||||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens
|
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||||
# will be ignored.
|
req.input_ids[-req.extend_input_len + 1 :],
|
||||||
prompt_token_len = len(req.logprob)
|
)
|
||||||
token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
|
)
|
||||||
token_logprobs = req.logprob + [last_logprobs[i]]
|
|
||||||
req.token_logprob = list(zip(token_ids, token_logprobs))
|
|
||||||
if req.logprob_start_len == 0:
|
if req.logprob_start_len == 0:
|
||||||
req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
|
req.prefill_token_logprobs = [
|
||||||
|
(None, req.input_ids[0])
|
||||||
|
] + req.prefill_token_logprobs
|
||||||
|
req.decode_token_logprobs = [
|
||||||
|
(last_token_logprobs[i], next_token_ids[i])
|
||||||
|
]
|
||||||
|
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
||||||
|
if req.logprob_start_len == 0:
|
||||||
|
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||||
|
req.decode_top_logprobs = [decode_top_logprobs[i]]
|
||||||
|
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||||
pt += req.extend_input_len
|
pt += req.extend_input_len
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
@@ -493,27 +510,29 @@ class ModelRpcServer:
|
|||||||
batch.prepare_for_decode()
|
batch.prepare_for_decode()
|
||||||
|
|
||||||
# Forward
|
# Forward
|
||||||
logits, (_, _, last_logprobs) = self.model_runner.forward(
|
logits, (_, _, decode_top_logprobs, _, last_logprobs) = (
|
||||||
batch, ForwardMode.DECODE
|
self.model_runner.forward(batch, ForwardMode.DECODE)
|
||||||
)
|
)
|
||||||
next_token_ids, _ = batch.sample(logits)
|
next_token_ids, _ = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
|
|
||||||
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
# Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
|
new_token_logprobs = None
|
||||||
if last_logprobs is not None:
|
if last_logprobs is not None:
|
||||||
last_logprobs = last_logprobs[
|
new_token_logprobs = last_logprobs[
|
||||||
torch.arange(len(reqs)), next_token_ids
|
torch.arange(len(reqs)), next_token_ids
|
||||||
].tolist()
|
].tolist()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(reqs, next_token_ids)):
|
||||||
req.completion_tokens_wo_jump_forward += 1
|
req.completion_tokens_wo_jump_forward += 1
|
||||||
req.output_ids.append(next_tok_id)
|
req.output_ids.append(next_token_id)
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if last_logprobs is not None:
|
if new_token_logprobs is not None:
|
||||||
req.token_logprob.append((next_tok_id, last_logprobs[i]))
|
req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
|
||||||
|
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
@@ -558,9 +577,19 @@ class ModelRpcServer:
|
|||||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||||
}
|
}
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
meta_info["prompt_logprob"] = req.logprob
|
(
|
||||||
meta_info["token_logprob"] = req.token_logprob
|
meta_info["prefill_token_logprobs"],
|
||||||
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
|
meta_info["decode_token_logprobs"],
|
||||||
|
meta_info["prefill_top_logprobs"],
|
||||||
|
meta_info["decode_top_logprobs"],
|
||||||
|
meta_info["normalized_prompt_logprob"],
|
||||||
|
) = (
|
||||||
|
req.prefill_token_logprobs,
|
||||||
|
req.decode_token_logprobs,
|
||||||
|
req.prefill_top_logprobs,
|
||||||
|
req.decode_top_logprobs,
|
||||||
|
req.normalized_prompt_logprob,
|
||||||
|
)
|
||||||
output_meta_info.append(meta_info)
|
output_meta_info.append(meta_info)
|
||||||
output_finished.append(req.finished)
|
output_finished.append(req.finished)
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import logging
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -81,6 +82,7 @@ class InputMetadata:
|
|||||||
out_cache_cont_end: torch.Tensor = None
|
out_cache_cont_end: torch.Tensor = None
|
||||||
|
|
||||||
other_kv_index: torch.Tensor = None
|
other_kv_index: torch.Tensor = None
|
||||||
|
top_logprobs_nums: List[int] = None
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
|
||||||
# for flashinfer
|
# for flashinfer
|
||||||
@@ -181,6 +183,7 @@ class InputMetadata:
|
|||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
out_cache_cont_start=None,
|
out_cache_cont_start=None,
|
||||||
out_cache_cont_end=None,
|
out_cache_cont_end=None,
|
||||||
|
top_logprobs_nums=None,
|
||||||
return_logprob=False,
|
return_logprob=False,
|
||||||
):
|
):
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
@@ -229,6 +232,7 @@ class InputMetadata:
|
|||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
out_cache_cont_start=out_cache_cont_start,
|
out_cache_cont_start=out_cache_cont_start,
|
||||||
out_cache_cont_end=out_cache_cont_end,
|
out_cache_cont_end=out_cache_cont_end,
|
||||||
|
top_logprobs_nums=top_logprobs_nums,
|
||||||
return_logprob=return_logprob,
|
return_logprob=return_logprob,
|
||||||
other_kv_index=other_kv_index,
|
other_kv_index=other_kv_index,
|
||||||
)
|
)
|
||||||
@@ -377,6 +381,7 @@ class ModelRunner:
|
|||||||
prefix_lens=batch.prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -394,6 +399,7 @@ class ModelRunner:
|
|||||||
prefix_lens=batch.prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -413,6 +419,7 @@ class ModelRunner:
|
|||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
out_cache_cont_start=batch.out_cache_cont_start,
|
out_cache_cont_start=batch.out_cache_cont_start,
|
||||||
out_cache_cont_end=batch.out_cache_cont_end,
|
out_cache_cont_end=batch.out_cache_cont_end,
|
||||||
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
@@ -430,6 +437,7 @@ class ModelRunner:
|
|||||||
prefix_lens=batch.prefix_lens,
|
prefix_lens=batch.prefix_lens,
|
||||||
position_ids_offsets=batch.position_ids_offsets,
|
position_ids_offsets=batch.position_ids_offsets,
|
||||||
out_cache_loc=batch.out_cache_loc,
|
out_cache_loc=batch.out_cache_loc,
|
||||||
|
top_logprobs_nums=batch.top_logprobs_nums,
|
||||||
return_logprob=batch.return_logprob,
|
return_logprob=batch.return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
|
|||||||
@@ -173,6 +173,7 @@ class TokenizerManager:
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
return_logprob=obj.return_logprob,
|
return_logprob=obj.return_logprob,
|
||||||
logprob_start_len=obj.logprob_start_len,
|
logprob_start_len=obj.logprob_start_len,
|
||||||
|
top_logprobs_num=obj.top_logprobs_num,
|
||||||
stream=obj.stream,
|
stream=obj.stream,
|
||||||
)
|
)
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
@@ -215,6 +216,7 @@ class TokenizerManager:
|
|||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
return_logprob=obj.return_logprob[i],
|
return_logprob=obj.return_logprob[i],
|
||||||
logprob_start_len=obj.logprob_start_len[i],
|
logprob_start_len=obj.logprob_start_len[i],
|
||||||
|
top_logprobs_num=obj.top_logprobs_num[i],
|
||||||
stream=obj.stream,
|
stream=obj.stream,
|
||||||
)
|
)
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|||||||
@@ -123,31 +123,97 @@ async def flush_cache():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def detokenize_logprob_tokens(token_logprobs, decode_to_text):
    """Expand ``(logprob, token_id)`` pairs into ``(logprob, token_id, text)`` triples.

    Args:
        token_logprobs: Iterable of ``(logprob, token_id)`` pairs.
        decode_to_text: When false, the text slot of every triple is ``None``
            and the detokenizer is not contacted.

    Returns:
        A list of ``(logprob, token_id, token_text)`` tuples, in input order.
    """
    if not decode_to_text:
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

    token_ids = [token_id for _, token_id in token_logprobs]
    if not token_ids:
        # Nothing to decode: skip the detokenizer round-trip entirely.
        return []

    token_texts = await tokenizer_manager.detokenize(DetokenizeReqInput(token_ids))
    return [
        (logprob, token_id, token_text)
        for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
async def detokenize_top_logprobs_tokens(top_logprobs, decode_to_text):
    """Detokenize every per-position top-logprob list in ``top_logprobs``.

    Args:
        top_logprobs: List with one entry per position; each entry is either
            ``None`` or a list of ``(logprob, token_id)`` pairs. May itself be
            ``None`` when no top logprobs are available.
        decode_to_text: Forwarded to ``detokenize_logprob_tokens``.

    Returns:
        The same list object, with every non-``None`` entry replaced by the
        result of ``detokenize_logprob_tokens``; ``None`` if the input is
        ``None``.
    """
    if top_logprobs is None:
        # Downstream code treats a missing top-logprobs list as None, so pass
        # it through instead of failing inside enumerate().
        return None

    for i, entry in enumerate(top_logprobs):
        if entry is not None:
            top_logprobs[i] = await detokenize_logprob_tokens(entry, decode_to_text)
    return top_logprobs
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_token_logprobs_results(obj: GenerateReqInput, ret):
    """Rewrite the logprob entries of a response in place.

    Each of the four logprob lists under ``meta_info`` is passed through the
    matching detokenize helper; token text is filled in only when
    ``obj.return_text_in_logprobs`` is set.

    Args:
        obj (GenerateReqInput): The request object.
        ret (Union[Dict, List[Dict]]): The response object.
    """
    # NOTE: One HTTP request may carry several prompts, so ``ret`` is either a
    # single response dict or a list of them.

    async def _convert_one(response, return_text):
        meta_info = response["meta_info"]
        converters = (
            ("prefill_token_logprobs", detokenize_logprob_tokens),
            ("decode_token_logprobs", detokenize_logprob_tokens),
            ("prefill_top_logprobs", detokenize_top_logprobs_tokens),
            ("decode_top_logprobs", detokenize_top_logprobs_tokens),
        )
        for key, convert in converters:
            meta_info[key] = await convert(meta_info[key], return_text)

    if not isinstance(obj.text, str):
        # Batched request: the return_logprob flag is per prompt.
        for index, response in enumerate(ret):
            if obj.return_logprob[index]:
                await _convert_one(response, obj.return_text_in_logprobs)
        return

    if obj.return_logprob:
        await _convert_one(ret, obj.return_text_in_logprobs)
|
||||||
|
|
||||||
|
|
||||||
async def stream_generator(obj: GenerateReqInput):
    """Yield each output of the request after post-processing its logprobs."""
    outputs = tokenizer_manager.generate_request(obj)
    async for chunk in outputs:
        # Rewrites the logprob entries of ``chunk`` in place before it is sent.
        await handle_token_logprobs_results(obj, chunk)
        yield chunk
|
||||||
|
|
||||||
|
|
||||||
async def make_openai_style_logprobs(
    prefill_token_logprobs=None,
    decode_token_logprobs=None,
    prefill_top_logprobs=None,
    decode_top_logprobs=None,
):
    """Collect detokenized logprob triples into a ``LogProbs`` object.

    Every argument is optional; a ``None`` argument contributes nothing.
    Prefill entries are appended before decode entries.

    Args:
        prefill_token_logprobs: ``(logprob, token_id, text)`` triples.
        decode_token_logprobs: ``(logprob, token_id, text)`` triples.
        prefill_top_logprobs: Per-position lists of triples, or ``None`` entries.
        decode_top_logprobs: Per-position lists of triples, or ``None`` entries.

    Returns:
        The populated ``LogProbs`` instance.
    """
    ret_logprobs = LogProbs()

    for token_logprobs in (prefill_token_logprobs, decode_token_logprobs):
        if token_logprobs is None:
            continue
        for logprob, _, token_text in token_logprobs:
            ret_logprobs.tokens.append(token_text)
            ret_logprobs.token_logprobs.append(logprob)
            # Text offsets are not supported yet.
            ret_logprobs.text_offset.append(-1)

    for top_logprobs in (prefill_top_logprobs, decode_top_logprobs):
        if top_logprobs is None:
            continue
        for candidates in top_logprobs:
            if candidates is None:
                ret_logprobs.top_logprobs.append(None)
            else:
                # Map token text -> logprob for each candidate at this position.
                ret_logprobs.top_logprobs.append(
                    {candidate[2]: candidate[0] for candidate in candidates}
                )

    return ret_logprobs
|
||||||
|
|
||||||
|
|
||||||
@@ -165,10 +231,7 @@ async def generate_request(obj: GenerateReqInput):
|
|||||||
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
return StreamingResponse(stream_results(), media_type="text/event-stream")
|
||||||
|
|
||||||
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
ret = await tokenizer_manager.generate_request(obj).__anext__()
|
||||||
if obj.return_logprob and obj.return_text_in_logprobs:
|
await handle_token_logprobs_results(obj, ret)
|
||||||
ret["meta_info"]["token_logprob"] = await detokenize_logprob_tokens(
|
|
||||||
ret["meta_info"]["token_logprob"]
|
|
||||||
)
|
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
@@ -192,7 +255,8 @@ async def v1_completions(raw_request: Request):
|
|||||||
"frequency_penalty": request.frequency_penalty,
|
"frequency_penalty": request.frequency_penalty,
|
||||||
"regex": request.regex,
|
"regex": request.regex,
|
||||||
},
|
},
|
||||||
return_logprob=request.logprobs is not None,
|
return_logprob=request.logprobs is not None and request.logprobs > 0,
|
||||||
|
top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
|
||||||
return_text_in_logprobs=True,
|
return_text_in_logprobs=True,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
)
|
)
|
||||||
@@ -212,15 +276,32 @@ async def v1_completions(raw_request: Request):
|
|||||||
if request.echo:
|
if request.echo:
|
||||||
# Prepend prompt in response text.
|
# Prepend prompt in response text.
|
||||||
text = request.prompt + text
|
text = request.prompt + text
|
||||||
else:
|
|
||||||
# Skip prompt tokens if echo is disabled.
|
|
||||||
n_prev_token = prompt_tokens
|
|
||||||
|
|
||||||
if request.logprobs is not None:
|
if request.logprobs:
|
||||||
|
# The first chunk and echo is enabled.
|
||||||
|
if not stream_buffer and request.echo:
|
||||||
|
prefill_token_logprobs = content["meta_info"][
|
||||||
|
"prefill_token_logprobs"
|
||||||
|
]
|
||||||
|
prefill_top_logprobs = content["meta_info"][
|
||||||
|
"prefill_top_logprobs"
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
prefill_token_logprobs = None
|
||||||
|
prefill_top_logprobs = None
|
||||||
|
|
||||||
logprobs = await make_openai_style_logprobs(
|
logprobs = await make_openai_style_logprobs(
|
||||||
content["meta_info"]["token_logprob"][n_prev_token:]
|
prefill_token_logprobs=prefill_token_logprobs,
|
||||||
|
prefill_top_logprobs=prefill_top_logprobs,
|
||||||
|
decode_token_logprobs=content["meta_info"][
|
||||||
|
"decode_token_logprobs"
|
||||||
|
][n_prev_token:],
|
||||||
|
decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
|
||||||
|
n_prev_token:
|
||||||
|
],
|
||||||
)
|
)
|
||||||
n_prev_token = len(content["meta_info"]["token_logprob"])
|
|
||||||
|
n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
@@ -255,20 +336,26 @@ async def v1_completions(raw_request: Request):
|
|||||||
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
prompt_tokens = ret["meta_info"]["prompt_tokens"]
|
||||||
completion_tokens = ret["meta_info"]["completion_tokens"]
|
completion_tokens = ret["meta_info"]["completion_tokens"]
|
||||||
text = ret["text"]
|
text = ret["text"]
|
||||||
token_logprob_pos = prompt_tokens
|
|
||||||
if request.echo:
|
if request.echo:
|
||||||
token_logprob_pos = 0
|
|
||||||
text = request.prompt + text
|
text = request.prompt + text
|
||||||
else:
|
|
||||||
token_logprob_pos = prompt_tokens
|
|
||||||
|
|
||||||
logprobs = (
|
if request.logprobs:
|
||||||
await make_openai_style_logprobs(
|
if request.echo:
|
||||||
ret["meta_info"]["token_logprob"][token_logprob_pos:]
|
prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
|
||||||
)
|
prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
|
||||||
if request.logprobs is not None
|
else:
|
||||||
else None
|
prefill_token_logprobs = None
|
||||||
|
prefill_top_logprobs = None
|
||||||
|
|
||||||
|
logprobs = await make_openai_style_logprobs(
|
||||||
|
prefill_token_logprobs=prefill_token_logprobs,
|
||||||
|
prefill_top_logprobs=prefill_top_logprobs,
|
||||||
|
decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
|
||||||
|
decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
choice_data = CompletionResponseChoice(
|
choice_data = CompletionResponseChoice(
|
||||||
index=0,
|
index=0,
|
||||||
text=text,
|
text=text,
|
||||||
|
|||||||
@@ -9,11 +9,12 @@ The capital of France is Paris.\nThe capital of the United States is Washington,
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def test_decode(url, return_logprob):
|
def test_decode(url, return_logprob, top_logprobs_num, return_text):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -23,10 +24,13 @@ def test_decode(url, return_logprob):
|
|||||||
"max_new_tokens": 32,
|
"max_new_tokens": 32,
|
||||||
},
|
},
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"return_text_in_logprobs": return_text,
|
||||||
"logprob_start_len": 0,
|
"logprob_start_len": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(response.json())
|
print(json.dumps(response.json()))
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -37,5 +41,8 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
url = f"{args.host}:{args.port}"
|
url = f"{args.host}:{args.port}"
|
||||||
|
|
||||||
test_decode(url, False)
|
test_decode(url, False, 0, False)
|
||||||
test_decode(url, True)
|
test_decode(url, True, 0, False)
|
||||||
|
test_decode(url, True, 0, True)
|
||||||
|
test_decode(url, True, 3, False)
|
||||||
|
test_decode(url, True, 3, True)
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import json
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
def test_decode_stream(url, return_logprob):
|
def test_decode_stream(url, return_logprob, top_logprobs_num):
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url + "/generate",
|
url + "/generate",
|
||||||
json={
|
json={
|
||||||
@@ -24,6 +24,8 @@ def test_decode_stream(url, return_logprob):
|
|||||||
},
|
},
|
||||||
"stream": True,
|
"stream": True,
|
||||||
"return_logprob": return_logprob,
|
"return_logprob": return_logprob,
|
||||||
|
"top_logprobs_num": top_logprobs_num,
|
||||||
|
"return_text_in_logprobs": True,
|
||||||
},
|
},
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
@@ -37,19 +39,20 @@ def test_decode_stream(url, return_logprob):
|
|||||||
data = json.loads(chunk[5:].strip("\n"))
|
data = json.loads(chunk[5:].strip("\n"))
|
||||||
|
|
||||||
if return_logprob:
|
if return_logprob:
|
||||||
assert data["meta_info"]["prompt_logprob"] is not None
|
assert data["meta_info"]["prefill_token_logprobs"] is not None
|
||||||
assert data["meta_info"]["token_logprob"] is not None
|
assert data["meta_info"]["decode_token_logprobs"] is not None
|
||||||
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
assert data["meta_info"]["normalized_prompt_logprob"] is not None
|
||||||
if prev == 0: # Skip prompt logprobs
|
for logprob, token_id, token_text in data["meta_info"][
|
||||||
prev = data["meta_info"]["prompt_tokens"]
|
"decode_token_logprobs"
|
||||||
for token_txt, _, logprob in data["meta_info"]["token_logprob"][prev:]:
|
][prev:]:
|
||||||
print(f"{token_txt}\t{logprob}", flush=True)
|
print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True)
|
||||||
prev = len(data["meta_info"]["token_logprob"])
|
prev = len(data["meta_info"]["decode_token_logprobs"])
|
||||||
else:
|
else:
|
||||||
output = data["text"].strip()
|
output = data["text"].strip()
|
||||||
print(output[prev:], end="", flush=True)
|
print(output[prev:], end="", flush=True)
|
||||||
prev = len(output)
|
prev = len(output)
|
||||||
print("")
|
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -60,5 +63,6 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
url = f"{args.host}:{args.port}"
|
url = f"{args.host}:{args.port}"
|
||||||
|
|
||||||
test_decode_stream(url, False)
|
test_decode_stream(url, False, 0)
|
||||||
test_decode_stream(url, True)
|
test_decode_stream(url, True, 0)
|
||||||
|
test_decode_stream(url, True, 3)
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ def test_completion(args, echo, logprobs):
|
|||||||
if echo:
|
if echo:
|
||||||
assert text.startswith("The capital of France is")
|
assert text.startswith("The capital of France is")
|
||||||
if logprobs:
|
if logprobs:
|
||||||
|
print(response.choices[0].logprobs.top_logprobs)
|
||||||
assert response.choices[0].logprobs
|
assert response.choices[0].logprobs
|
||||||
if echo:
|
if echo:
|
||||||
assert response.choices[0].logprobs.token_logprobs[0] == None
|
assert response.choices[0].logprobs.token_logprobs[0] == None
|
||||||
@@ -44,6 +45,7 @@ def test_completion(args, echo, logprobs):
|
|||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
assert response.usage.completion_tokens > 0
|
assert response.usage.completion_tokens > 0
|
||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
def test_completion_stream(args, echo, logprobs):
|
def test_completion_stream(args, echo, logprobs):
|
||||||
@@ -68,13 +70,14 @@ def test_completion_stream(args, echo, logprobs):
|
|||||||
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
|
f"{r.choices[0].text:12s}\t" f"{r.choices[0].logprobs.token_logprobs}",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
print(r.choices[0].logprobs.top_logprobs)
|
||||||
else:
|
else:
|
||||||
print(r.choices[0].text, end="", flush=True)
|
print(r.choices[0].text, end="", flush=True)
|
||||||
assert r.id
|
assert r.id
|
||||||
assert r.usage.prompt_tokens > 0
|
assert r.usage.prompt_tokens > 0
|
||||||
assert r.usage.completion_tokens > 0
|
assert r.usage.completion_tokens > 0
|
||||||
assert r.usage.total_tokens > 0
|
assert r.usage.total_tokens > 0
|
||||||
print()
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion(args):
|
def test_chat_completion(args):
|
||||||
@@ -94,6 +97,7 @@ def test_chat_completion(args):
|
|||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
assert response.usage.completion_tokens > 0
|
assert response.usage.completion_tokens > 0
|
||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_image(args):
|
def test_chat_completion_image(args):
|
||||||
@@ -124,6 +128,7 @@ def test_chat_completion_image(args):
|
|||||||
assert response.usage.prompt_tokens > 0
|
assert response.usage.prompt_tokens > 0
|
||||||
assert response.usage.completion_tokens > 0
|
assert response.usage.completion_tokens > 0
|
||||||
assert response.usage.total_tokens > 0
|
assert response.usage.total_tokens > 0
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_stream(args):
|
def test_chat_completion_stream(args):
|
||||||
@@ -149,7 +154,7 @@ def test_chat_completion_stream(args):
|
|||||||
if not data.content:
|
if not data.content:
|
||||||
continue
|
continue
|
||||||
print(data.content, end="", flush=True)
|
print(data.content, end="", flush=True)
|
||||||
print()
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
def test_regex(args):
|
def test_regex(args):
|
||||||
@@ -174,6 +179,7 @@ def test_regex(args):
|
|||||||
)
|
)
|
||||||
text = response.choices[0].message.content
|
text = response.choices[0].message.content
|
||||||
print(json.loads(text))
|
print(json.loads(text))
|
||||||
|
print("=" * 100)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -188,10 +194,14 @@ if __name__ == "__main__":
|
|||||||
test_completion(args, echo=True, logprobs=False)
|
test_completion(args, echo=True, logprobs=False)
|
||||||
test_completion(args, echo=False, logprobs=True)
|
test_completion(args, echo=False, logprobs=True)
|
||||||
test_completion(args, echo=True, logprobs=True)
|
test_completion(args, echo=True, logprobs=True)
|
||||||
|
test_completion(args, echo=False, logprobs=3)
|
||||||
|
test_completion(args, echo=True, logprobs=3)
|
||||||
test_completion_stream(args, echo=False, logprobs=False)
|
test_completion_stream(args, echo=False, logprobs=False)
|
||||||
test_completion_stream(args, echo=True, logprobs=False)
|
test_completion_stream(args, echo=True, logprobs=False)
|
||||||
test_completion_stream(args, echo=False, logprobs=True)
|
test_completion_stream(args, echo=False, logprobs=True)
|
||||||
test_completion_stream(args, echo=True, logprobs=True)
|
test_completion_stream(args, echo=True, logprobs=True)
|
||||||
|
test_completion_stream(args, echo=False, logprobs=3)
|
||||||
|
test_completion_stream(args, echo=True, logprobs=3)
|
||||||
test_chat_completion(args)
|
test_chat_completion(args)
|
||||||
test_chat_completion_stream(args)
|
test_chat_completion_stream(args)
|
||||||
test_regex(args)
|
test_regex(args)
|
||||||
|
|||||||
Reference in New Issue
Block a user