Return logprob for choices (#87)
This commit is contained in:
@@ -69,6 +69,8 @@ state = multi_turn_question.run(
|
|||||||
|
|
||||||
for m in state.messages():
|
for m in state.messages():
|
||||||
print(m["role"], ":", m["content"])
|
print(m["role"], ":", m["content"])
|
||||||
|
|
||||||
|
print(state["answer_1"])
|
||||||
```
|
```
|
||||||
|
|
||||||
### Using Local Models
|
### Using Local Models
|
||||||
@@ -99,6 +101,8 @@ state = multi_turn_question.run(
|
|||||||
|
|
||||||
for m in state.messages():
|
for m in state.messages():
|
||||||
print(m["role"], ":", m["content"])
|
print(m["role"], ":", m["content"])
|
||||||
|
|
||||||
|
print(state["answer_1"])
|
||||||
```
|
```
|
||||||
|
|
||||||
### More Examples
|
### More Examples
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ class GenerateReqInput:
|
|||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
42
examples/usage/choices_logprob.py
Normal file
42
examples/usage/choices_logprob.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
|
||||||
|
|
||||||
|
@sgl.function
|
||||||
|
def tool_use(s, question):
|
||||||
|
s += "To answer this question: " + question + ", "
|
||||||
|
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Run one case
|
||||||
|
question = "What is 5 + 5?"
|
||||||
|
state = tool_use.run(question)
|
||||||
|
print("questions:", question)
|
||||||
|
print("choice:", state["tool"])
|
||||||
|
meta_info = state.get_meta_info("tool")
|
||||||
|
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
|
||||||
|
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
|
||||||
|
print('-' * 50)
|
||||||
|
|
||||||
|
# Run a batch
|
||||||
|
questions = [
|
||||||
|
"What is 5 + 6?",
|
||||||
|
"Who is Michael Jordan?",
|
||||||
|
]
|
||||||
|
states = tool_use.run_batch([{"question": q} for q in questions])
|
||||||
|
for question, state in zip(questions, states):
|
||||||
|
print("questions:", question)
|
||||||
|
print("choice:", state["tool"])
|
||||||
|
meta_info = state.get_meta_info("tool")
|
||||||
|
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
|
||||||
|
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
|
||||||
|
print('-' * 50)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
|
||||||
|
main()
|
||||||
@@ -209,7 +209,7 @@ class OpenAI(BaseBackend):
|
|||||||
prompt_tokens.append(ret_token)
|
prompt_tokens.append(ret_token)
|
||||||
|
|
||||||
decision = choices[np.argmax(scores)]
|
decision = choices[np.argmax(scores)]
|
||||||
return decision, scores
|
return decision, scores, scores
|
||||||
|
|
||||||
|
|
||||||
def openai_completion(client, is_chat=None, prompt=None, **kwargs):
|
def openai_completion(client, is_chat=None, prompt=None, **kwargs):
|
||||||
|
|||||||
@@ -150,16 +150,20 @@ class RuntimeEndpoint(BaseBackend):
|
|||||||
data = {
|
data = {
|
||||||
"text": [s.text_ + c for c in choices],
|
"text": [s.text_ + c for c in choices],
|
||||||
"sampling_params": {"max_new_tokens": 0},
|
"sampling_params": {"max_new_tokens": 0},
|
||||||
"return_normalized_logprob": True,
|
"return_logprob": True,
|
||||||
"normalized_logprob_start_len": prompt_len,
|
"logprob_start_len": max(prompt_len - 2, 0),
|
||||||
}
|
}
|
||||||
self._add_images(s, data)
|
self._add_images(s, data)
|
||||||
res = http_request(self.base_url + "/generate", json=data)
|
res = http_request(self.base_url + "/generate", json=data)
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
logps = [r["meta_info"]["normalized_logprob"] for r in res.json()]
|
obj = res.json()
|
||||||
|
normalized_prompt_logprob = [
|
||||||
|
r["meta_info"]["normalized_prompt_logprob"] for r in obj
|
||||||
|
]
|
||||||
|
prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]
|
||||||
|
|
||||||
decision = choices[np.argmax(logps)]
|
decision = choices[np.argmax(normalized_prompt_logprob)]
|
||||||
return decision, logps
|
return decision, normalized_prompt_logprob, prompt_logprob
|
||||||
|
|
||||||
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
|
||||||
res = http_request(
|
res = http_request(
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum, auto
|
from enum import Enum, auto
|
||||||
from typing import Callable, Dict, List, Tuple, Optional
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplateStyle(Enum):
|
class ChatTemplateStyle(Enum):
|
||||||
@@ -111,7 +111,7 @@ register_chat_template(
|
|||||||
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
|
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
|
||||||
},
|
},
|
||||||
style=ChatTemplateStyle.PLAIN,
|
style=ChatTemplateStyle.PLAIN,
|
||||||
stop_str=('<|im_end|>',)
|
stop_str=("<|im_end|>",),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ def run_program_batch(
|
|||||||
|
|
||||||
# Run all programs
|
# Run all programs
|
||||||
if num_threads == "auto":
|
if num_threads == "auto":
|
||||||
num_threads = max(64, multiprocessing.cpu_count() * 8)
|
num_threads = max(96, multiprocessing.cpu_count() * 16)
|
||||||
num_threads = min(num_threads, len(batch_arguments))
|
num_threads = min(num_threads, len(batch_arguments))
|
||||||
|
|
||||||
if num_threads == 1:
|
if num_threads == 1:
|
||||||
@@ -364,10 +364,16 @@ class StreamExecutor:
|
|||||||
self.stream_var_event[name].set()
|
self.stream_var_event[name].set()
|
||||||
|
|
||||||
def _execute_select(self, expr: SglSelect):
|
def _execute_select(self, expr: SglSelect):
|
||||||
decision, scores = self.backend.select(self, expr.choices, expr.temperature)
|
decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
|
||||||
|
self, expr.choices, expr.temperature
|
||||||
|
)
|
||||||
if expr.name is not None:
|
if expr.name is not None:
|
||||||
name = expr.name
|
name = expr.name
|
||||||
self.variables[name] = decision
|
self.variables[name] = decision
|
||||||
|
self.meta_info[name] = {
|
||||||
|
"normalized_prompt_logprob": normalized_prompt_logprob,
|
||||||
|
"prompt_logprob": prompt_logprob,
|
||||||
|
}
|
||||||
self.variable_event[name].set()
|
self.variable_event[name].set()
|
||||||
self.text_ += decision
|
self.text_ += decision
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
def forward(self, input_ids, hidden_states, weight, input_metadata):
|
||||||
if not input_metadata.return_normalized_logprob:
|
if not input_metadata.return_logprob:
|
||||||
if input_metadata.forward_mode == ForwardMode.DECODE:
|
if input_metadata.forward_mode == ForwardMode.DECODE:
|
||||||
last_hidden = hidden_states
|
last_hidden = hidden_states
|
||||||
else:
|
else:
|
||||||
@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
|
|||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
last_logits = tensor_model_parallel_all_gather(last_logits)
|
last_logits = tensor_model_parallel_all_gather(last_logits)
|
||||||
last_logits = last_logits[:, : self.config.vocab_size]
|
last_logits = last_logits[:, : self.config.vocab_size]
|
||||||
return last_logits, None
|
return last_logits, (None, None)
|
||||||
else:
|
else:
|
||||||
assert input_metadata.forward_mode != ForwardMode.DECODE
|
assert input_metadata.forward_mode != ForwardMode.DECODE
|
||||||
last_index = (
|
last_index = (
|
||||||
@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
|
|||||||
logits = logits[:, : self.config.vocab_size]
|
logits = logits[:, : self.config.vocab_size]
|
||||||
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
|
||||||
|
|
||||||
normalized_logprobs = compute_normalized_logprobs(
|
|
||||||
all_logprobs,
|
|
||||||
input_ids,
|
|
||||||
input_metadata.extend_seq_lens,
|
|
||||||
input_metadata.extend_start_loc,
|
|
||||||
)
|
|
||||||
|
|
||||||
last_logits = logits[last_index]
|
|
||||||
return last_logits, normalized_logprobs
|
|
||||||
|
|
||||||
|
|
||||||
def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
|
|
||||||
logprobs = all_logprobs[
|
logprobs = all_logprobs[
|
||||||
torch.arange(all_logprobs.shape[0], device="cuda"),
|
torch.arange(all_logprobs.shape[0], device="cuda"),
|
||||||
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
|
||||||
]
|
]
|
||||||
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
|
||||||
|
|
||||||
start = start_loc.clone()
|
start = input_metadata.extend_start_loc.clone()
|
||||||
end = start + seq_lens - 2
|
end = start + input_metadata.extend_seq_lens - 2
|
||||||
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
start.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||||
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
end.clamp_(min=0, max=logprobs.shape[0] - 1)
|
||||||
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
|
||||||
return sum_logp / ((seq_lens - 1).clamp(min=1))
|
normalized_logprobs = sum_logp / (
|
||||||
|
(input_metadata.extend_seq_lens - 1).clamp(min=1)
|
||||||
|
)
|
||||||
|
|
||||||
|
last_logits = logits[last_index]
|
||||||
|
return last_logits, (logprobs, normalized_logprobs)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -11,8 +11,8 @@ class GenerateReqInput:
|
|||||||
image_data: Optional[Union[List[str], str]] = None
|
image_data: Optional[Union[List[str], str]] = None
|
||||||
sampling_params: Union[List[Dict], Dict] = None
|
sampling_params: Union[List[Dict], Dict] = None
|
||||||
rid: Optional[Union[List[str], str]] = None
|
rid: Optional[Union[List[str], str]] = None
|
||||||
return_normalized_logprob: Optional[Union[List[bool], bool]] = None
|
return_logprob: Optional[Union[List[bool], bool]] = None
|
||||||
normalized_logprob_start_len: Optional[Union[List[int], int]] = None
|
logprob_start_len: Optional[Union[List[int], int]] = None
|
||||||
stream: bool = False
|
stream: bool = False
|
||||||
|
|
||||||
def post_init(self):
|
def post_init(self):
|
||||||
@@ -23,10 +23,10 @@ class GenerateReqInput:
|
|||||||
self.sampling_params = {}
|
self.sampling_params = {}
|
||||||
if self.rid is None:
|
if self.rid is None:
|
||||||
self.rid = uuid.uuid4().hex
|
self.rid = uuid.uuid4().hex
|
||||||
if self.return_normalized_logprob is None:
|
if self.return_logprob is None:
|
||||||
self.return_normalized_logprob = False
|
self.return_logprob = False
|
||||||
if self.normalized_logprob_start_len is None:
|
if self.logprob_start_len is None:
|
||||||
self.normalized_logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
else:
|
else:
|
||||||
num = len(self.text)
|
num = len(self.text)
|
||||||
|
|
||||||
@@ -45,17 +45,15 @@ class GenerateReqInput:
|
|||||||
else:
|
else:
|
||||||
assert isinstance(self.rid, list)
|
assert isinstance(self.rid, list)
|
||||||
|
|
||||||
if self.return_normalized_logprob is None:
|
if self.return_logprob is None:
|
||||||
self.return_normalized_logprob = [False] * num
|
self.return_logprob = [False] * num
|
||||||
elif not isinstance(self.return_normalized_logprob, list):
|
elif not isinstance(self.return_logprob, list):
|
||||||
self.return_normalized_logprob = [self.return_normalized_logprob] * num
|
self.return_logprob = [self.return_logprob] * num
|
||||||
|
|
||||||
if self.normalized_logprob_start_len is None:
|
if self.logprob_start_len is None:
|
||||||
self.normalized_logprob_start_len = [0] * num
|
self.logprob_start_len = [0] * num
|
||||||
elif not isinstance(self.normalized_logprob_start_len, list):
|
elif not isinstance(self.logprob_start_len, list):
|
||||||
self.normalized_logprob_start_len = [
|
self.logprob_start_len = [self.logprob_start_len] * num
|
||||||
self.normalized_logprob_start_len
|
|
||||||
] * num
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
|
|||||||
pixel_values: List[float]
|
pixel_values: List[float]
|
||||||
image_hash: int
|
image_hash: int
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
return_normalized_logprob: bool
|
return_logprob: bool
|
||||||
normalized_logprob_start_len: int
|
logprob_start_len: int
|
||||||
stream: bool
|
stream: bool
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -28,8 +28,8 @@ class Req:
|
|||||||
self.pixel_values = None
|
self.pixel_values = None
|
||||||
self.image_offset = 0
|
self.image_offset = 0
|
||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.return_normalized_logprob = False
|
self.return_logprob = False
|
||||||
self.normalized_logprob_start_len = 0
|
self.logprob_start_len = 0
|
||||||
self.stream = False
|
self.stream = False
|
||||||
|
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
@@ -37,10 +37,11 @@ class Req:
|
|||||||
self.finish_reason = None
|
self.finish_reason = None
|
||||||
self.hit_stop_str = None
|
self.hit_stop_str = None
|
||||||
|
|
||||||
self.adjust_input_len = 0
|
self.extend_input_len = 0
|
||||||
self.prefix_indices = []
|
self.prefix_indices = []
|
||||||
self.last_node = None
|
self.last_node = None
|
||||||
|
|
||||||
|
self.logprob = None
|
||||||
self.normalized_logprob = None
|
self.normalized_logprob = None
|
||||||
|
|
||||||
# for constrained decoding
|
# for constrained decoding
|
||||||
@@ -99,7 +100,7 @@ class Batch:
|
|||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
out_cache_cont_start: torch.Tensor = None
|
out_cache_cont_start: torch.Tensor = None
|
||||||
out_cache_cont_end: torch.Tensor = None
|
out_cache_cont_end: torch.Tensor = None
|
||||||
return_normalized_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
|
||||||
# for multimodal
|
# for multimodal
|
||||||
pixel_values: List[torch.Tensor] = None
|
pixel_values: List[torch.Tensor] = None
|
||||||
@@ -119,14 +120,14 @@ class Batch:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
|
||||||
return_normalized_logprob = any(req.return_normalized_logprob for req in reqs)
|
return_logprob = any(req.return_logprob for req in reqs)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=req_to_token_pool,
|
req_to_token_pool=req_to_token_pool,
|
||||||
token_to_kv_pool=token_to_kv_pool,
|
token_to_kv_pool=token_to_kv_pool,
|
||||||
tree_cache=tree_cache,
|
tree_cache=tree_cache,
|
||||||
return_normalized_logprob=return_normalized_logprob,
|
return_logprob=return_logprob,
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_empty(self):
|
def is_empty(self):
|
||||||
@@ -257,7 +258,7 @@ class Batch:
|
|||||||
self.tree_cache.dec_ref_counter(req.last_node)
|
self.tree_cache.dec_ref_counter(req.last_node)
|
||||||
req.prefix_indices = None
|
req.prefix_indices = None
|
||||||
req.last_node = None
|
req.last_node = None
|
||||||
req.adjust_input_len = 0
|
req.extend_input_len = 0
|
||||||
req.output_ids = []
|
req.output_ids = []
|
||||||
# TODO: apply more fine-grained retraction
|
# TODO: apply more fine-grained retraction
|
||||||
|
|
||||||
@@ -310,9 +311,7 @@ class Batch:
|
|||||||
self.prefix_lens = None
|
self.prefix_lens = None
|
||||||
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
self.position_ids_offsets = self.position_ids_offsets[new_indices]
|
||||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||||
self.return_normalized_logprob = any(
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
req.return_normalized_logprob for req in self.reqs
|
|
||||||
)
|
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
"temperatures",
|
"temperatures",
|
||||||
@@ -336,9 +335,7 @@ class Batch:
|
|||||||
[self.position_ids_offsets, other.position_ids_offsets]
|
[self.position_ids_offsets, other.position_ids_offsets]
|
||||||
)
|
)
|
||||||
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
|
||||||
self.return_normalized_logprob = any(
|
self.return_logprob = any(req.return_logprob for req in self.reqs)
|
||||||
req.return_normalized_logprob for req in self.reqs
|
|
||||||
)
|
|
||||||
|
|
||||||
for item in [
|
for item in [
|
||||||
"temperatures",
|
"temperatures",
|
||||||
|
|||||||
@@ -214,8 +214,8 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
req.input_ids, pad_value
|
req.input_ids, pad_value
|
||||||
)
|
)
|
||||||
req.sampling_params = recv_req.sampling_params
|
req.sampling_params = recv_req.sampling_params
|
||||||
req.return_normalized_logprob = recv_req.return_normalized_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len
|
req.logprob_start_len = recv_req.logprob_start_len
|
||||||
req.stream = recv_req.stream
|
req.stream = recv_req.stream
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
@@ -240,9 +240,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
|
|
||||||
for req in self.forward_queue:
|
for req in self.forward_queue:
|
||||||
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
||||||
if req.return_normalized_logprob:
|
if req.return_logprob:
|
||||||
prefix_indices = prefix_indices[: req.normalized_logprob_start_len]
|
prefix_indices = prefix_indices[: req.logprob_start_len]
|
||||||
req.adjust_input_len = len(req.input_ids) - len(prefix_indices)
|
req.extend_input_len = len(req.input_ids) - len(prefix_indices)
|
||||||
req.prefix_indices = prefix_indices
|
req.prefix_indices = prefix_indices
|
||||||
req.last_node = last_node
|
req.last_node = last_node
|
||||||
|
|
||||||
@@ -267,32 +267,32 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for req in self.forward_queue:
|
for req in self.forward_queue:
|
||||||
if req.return_normalized_logprob:
|
if req.return_logprob:
|
||||||
# Need at least two tokens to compute normalized logprob
|
# Need at least two tokens to compute normalized logprob
|
||||||
if req.adjust_input_len < 2:
|
if req.extend_input_len < 2:
|
||||||
delta = 2 - req.adjust_input_len
|
delta = 2 - req.extend_input_len
|
||||||
req.adjust_input_len += delta
|
req.extend_input_len += delta
|
||||||
req.prefix_indices = req.prefix_indices[:-delta]
|
req.prefix_indices = req.prefix_indices[:-delta]
|
||||||
if req.image_offset is not None:
|
if req.image_offset is not None:
|
||||||
req.image_offset += delta
|
req.image_offset += delta
|
||||||
if req.adjust_input_len == 0 and req.max_new_tokens() > 0:
|
if req.extend_input_len == 0 and req.max_new_tokens() > 0:
|
||||||
# Need at least one token to compute logits
|
# Need at least one token to compute logits
|
||||||
req.adjust_input_len = 1
|
req.extend_input_len = 1
|
||||||
req.prefix_indices = req.prefix_indices[:-1]
|
req.prefix_indices = req.prefix_indices[:-1]
|
||||||
if req.image_offset is not None:
|
if req.image_offset is not None:
|
||||||
req.image_offset += 1
|
req.image_offset += 1
|
||||||
|
|
||||||
if (
|
if (
|
||||||
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
|
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
||||||
< available_size
|
< available_size
|
||||||
and req.adjust_input_len + new_batch_input_tokens
|
and req.extend_input_len + new_batch_input_tokens
|
||||||
< self.max_prefill_num_token
|
< self.max_prefill_num_token
|
||||||
):
|
):
|
||||||
delta = self.tree_cache.inc_ref_counter(req.last_node)
|
delta = self.tree_cache.inc_ref_counter(req.last_node)
|
||||||
available_size += delta
|
available_size += delta
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens
|
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
|
||||||
< available_size
|
< available_size
|
||||||
):
|
):
|
||||||
delta = self.tree_cache.dec_ref_counter(req.last_node)
|
delta = self.tree_cache.dec_ref_counter(req.last_node)
|
||||||
@@ -301,9 +301,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.token_to_kv_pool.add_refs(req.prefix_indices)
|
self.token_to_kv_pool.add_refs(req.prefix_indices)
|
||||||
can_run_list.append(req)
|
can_run_list.append(req)
|
||||||
new_batch_total_tokens += (
|
new_batch_total_tokens += (
|
||||||
req.adjust_input_len + req.max_new_tokens()
|
req.extend_input_len + req.max_new_tokens()
|
||||||
)
|
)
|
||||||
new_batch_input_tokens += req.adjust_input_len
|
new_batch_input_tokens += req.extend_input_len
|
||||||
|
|
||||||
if len(can_run_list) == 0:
|
if len(can_run_list) == 0:
|
||||||
return None
|
return None
|
||||||
@@ -339,27 +339,31 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
|
|
||||||
if batch.extend_num_tokens != 0:
|
if batch.extend_num_tokens != 0:
|
||||||
# Forward
|
# Forward
|
||||||
logits, normalized_logprobs = self.model_runner.forward(
|
logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
|
||||||
batch, ForwardMode.EXTEND, batch.return_normalized_logprob
|
batch, ForwardMode.EXTEND, batch.return_logprob
|
||||||
)
|
)
|
||||||
# print("extend logits", logits)
|
# print("extend logits", logits)
|
||||||
if normalized_logprobs is not None:
|
if logprobs is not None:
|
||||||
|
logprobs = logprobs.cpu().tolist()
|
||||||
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
normalized_logprobs = normalized_logprobs.cpu().tolist()
|
||||||
|
|
||||||
next_token_ids, next_token_probs = batch.sample(logits)
|
next_token_ids, next_token_probs = batch.sample(logits)
|
||||||
next_token_ids = next_token_ids.cpu().tolist()
|
next_token_ids = next_token_ids.cpu().tolist()
|
||||||
else:
|
else:
|
||||||
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
|
||||||
normalized_logprobs = None
|
logprobs = normalized_logprobs = None
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
reqs = batch.reqs
|
reqs = batch.reqs
|
||||||
for i in range(len(reqs)):
|
pt = 0
|
||||||
reqs[i].output_ids = [next_token_ids[i]]
|
for i, req in enumerate(reqs):
|
||||||
reqs[i].check_finished()
|
req.output_ids = [next_token_ids[i]]
|
||||||
|
req.check_finished()
|
||||||
|
|
||||||
if normalized_logprobs is not None:
|
if logprobs is not None:
|
||||||
reqs[i].normalized_logprob = normalized_logprobs[i]
|
req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
|
||||||
|
req.normalized_logprob = normalized_logprobs[i]
|
||||||
|
pt += req.extend_input_len
|
||||||
|
|
||||||
self.handle_finished_requests(batch)
|
self.handle_finished_requests(batch)
|
||||||
|
|
||||||
@@ -427,8 +431,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
"prompt_tokens": len(req.input_ids),
|
"prompt_tokens": len(req.input_ids),
|
||||||
"completion_tokens": len(req.output_ids),
|
"completion_tokens": len(req.output_ids),
|
||||||
}
|
}
|
||||||
if req.return_normalized_logprob:
|
if req.return_logprob:
|
||||||
meta_info["normalized_logprob"] = req.normalized_logprob
|
meta_info["prompt_logprob"] = req.logprob
|
||||||
|
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
|
||||||
output_meta_info.append(meta_info)
|
output_meta_info.append(meta_info)
|
||||||
output_finished.append(req.finished)
|
output_finished.append(req.finished)
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class InputMetadata:
|
|||||||
out_cache_cont_end: torch.Tensor = None
|
out_cache_cont_end: torch.Tensor = None
|
||||||
|
|
||||||
other_kv_index: torch.Tensor = None
|
other_kv_index: torch.Tensor = None
|
||||||
return_normalized_logprob: bool = False
|
return_logprob: bool = False
|
||||||
|
|
||||||
# for flashinfer
|
# for flashinfer
|
||||||
use_flashinfer: bool = False
|
use_flashinfer: bool = False
|
||||||
@@ -127,7 +127,7 @@ class InputMetadata:
|
|||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
out_cache_cont_start=None,
|
out_cache_cont_start=None,
|
||||||
out_cache_cont_end=None,
|
out_cache_cont_end=None,
|
||||||
return_normalized_logprob=False,
|
return_logprob=False,
|
||||||
):
|
):
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
@@ -175,7 +175,7 @@ class InputMetadata:
|
|||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
out_cache_cont_start=out_cache_cont_start,
|
out_cache_cont_start=out_cache_cont_start,
|
||||||
out_cache_cont_end=out_cache_cont_end,
|
out_cache_cont_end=out_cache_cont_end,
|
||||||
return_normalized_logprob=return_normalized_logprob,
|
return_logprob=return_logprob,
|
||||||
other_kv_index=other_kv_index,
|
other_kv_index=other_kv_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -337,7 +337,7 @@ class ModelRunner:
|
|||||||
prefix_lens,
|
prefix_lens,
|
||||||
position_ids_offsets,
|
position_ids_offsets,
|
||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
return_normalized_logprob,
|
return_logprob,
|
||||||
):
|
):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
@@ -348,7 +348,7 @@ class ModelRunner:
|
|||||||
prefix_lens=prefix_lens,
|
prefix_lens=prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
return_normalized_logprob=return_normalized_logprob,
|
return_logprob=return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
|
|
||||||
@@ -361,7 +361,7 @@ class ModelRunner:
|
|||||||
prefix_lens,
|
prefix_lens,
|
||||||
position_ids_offsets,
|
position_ids_offsets,
|
||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
return_normalized_logprob,
|
return_logprob,
|
||||||
):
|
):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
@@ -372,7 +372,7 @@ class ModelRunner:
|
|||||||
prefix_lens=prefix_lens,
|
prefix_lens=prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
return_normalized_logprob=return_normalized_logprob,
|
return_logprob=return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
return self.model.forward(input_ids, input_metadata.positions, input_metadata)
|
||||||
|
|
||||||
@@ -415,7 +415,7 @@ class ModelRunner:
|
|||||||
prefix_lens,
|
prefix_lens,
|
||||||
position_ids_offsets,
|
position_ids_offsets,
|
||||||
out_cache_loc,
|
out_cache_loc,
|
||||||
return_normalized_logprob,
|
return_logprob,
|
||||||
):
|
):
|
||||||
input_metadata = InputMetadata.create(
|
input_metadata = InputMetadata.create(
|
||||||
self,
|
self,
|
||||||
@@ -426,7 +426,7 @@ class ModelRunner:
|
|||||||
prefix_lens=prefix_lens,
|
prefix_lens=prefix_lens,
|
||||||
position_ids_offsets=position_ids_offsets,
|
position_ids_offsets=position_ids_offsets,
|
||||||
out_cache_loc=out_cache_loc,
|
out_cache_loc=out_cache_loc,
|
||||||
return_normalized_logprob=return_normalized_logprob,
|
return_logprob=return_logprob,
|
||||||
)
|
)
|
||||||
return self.model.forward(
|
return self.model.forward(
|
||||||
input_ids,
|
input_ids,
|
||||||
@@ -436,9 +436,7 @@ class ModelRunner:
|
|||||||
image_offsets,
|
image_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
|
||||||
self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
|
|
||||||
):
|
|
||||||
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"input_ids": batch.input_ids,
|
"input_ids": batch.input_ids,
|
||||||
@@ -450,7 +448,7 @@ class ModelRunner:
|
|||||||
"position_ids_offsets": batch.position_ids_offsets,
|
"position_ids_offsets": batch.position_ids_offsets,
|
||||||
"out_cache_loc": batch.out_cache_loc,
|
"out_cache_loc": batch.out_cache_loc,
|
||||||
}
|
}
|
||||||
kwargs["return_normalized_logprob"] = return_normalized_logprob
|
kwargs["return_logprob"] = return_logprob
|
||||||
return self.forward_extend_multi_modal(**kwargs)
|
return self.forward_extend_multi_modal(**kwargs)
|
||||||
else:
|
else:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
@@ -467,10 +465,10 @@ class ModelRunner:
|
|||||||
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
|
||||||
return self.forward_decode(**kwargs)
|
return self.forward_decode(**kwargs)
|
||||||
elif forward_mode == ForwardMode.EXTEND:
|
elif forward_mode == ForwardMode.EXTEND:
|
||||||
kwargs["return_normalized_logprob"] = return_normalized_logprob
|
kwargs["return_logprob"] = return_logprob
|
||||||
return self.forward_extend(**kwargs)
|
return self.forward_extend(**kwargs)
|
||||||
elif forward_mode == ForwardMode.PREFILL:
|
elif forward_mode == ForwardMode.PREFILL:
|
||||||
kwargs["return_normalized_logprob"] = return_normalized_logprob
|
kwargs["return_logprob"] = return_logprob
|
||||||
return self.forward_prefill(**kwargs)
|
return self.forward_prefill(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
raise ValueError(f"Invaid forward mode: {forward_mode}")
|
||||||
|
|||||||
@@ -132,8 +132,8 @@ class TokenizerManager:
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
image_hash=image_hash,
|
image_hash=image_hash,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
return_normalized_logprob=obj.return_normalized_logprob,
|
return_logprob=obj.return_logprob,
|
||||||
normalized_logprob_start_len=obj.normalized_logprob_start_len,
|
logprob_start_len=obj.logprob_start_len,
|
||||||
stream=obj.stream,
|
stream=obj.stream,
|
||||||
)
|
)
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
@@ -173,8 +173,8 @@ class TokenizerManager:
|
|||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
image_hash=image_hash,
|
image_hash=image_hash,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
return_normalized_logprob=obj.return_normalized_logprob[i],
|
return_logprob=obj.return_logprob[i],
|
||||||
normalized_logprob_start_len=obj.normalized_logprob_start_len[i],
|
logprob_start_len=obj.logprob_start_len[i],
|
||||||
stream=obj.stream,
|
stream=obj.stream,
|
||||||
)
|
)
|
||||||
self.send_to_router.send_pyobj(tokenized_obj)
|
self.send_to_router.send_pyobj(tokenized_obj)
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ if __name__ == "__main__":
|
|||||||
"temperature": 0,
|
"temperature": 0,
|
||||||
"max_new_tokens": 32,
|
"max_new_tokens": 32,
|
||||||
},
|
},
|
||||||
|
# "return_logprob": True,
|
||||||
|
# "logprob_start_len": 0,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
print(response.json())
|
print(response.json())
|
||||||
|
|||||||
Reference in New Issue
Block a user