Return logprob for choices (#87)

This commit is contained in:
Lianmin Zheng
2024-01-23 05:07:30 -08:00
committed by GitHub
parent 9e037c822c
commit 9a16fea012
15 changed files with 161 additions and 112 deletions

View File

@@ -69,6 +69,8 @@ state = multi_turn_question.run(
for m in state.messages(): for m in state.messages():
print(m["role"], ":", m["content"]) print(m["role"], ":", m["content"])
print(state["answer_1"])
``` ```
### Using Local Models ### Using Local Models
@@ -99,6 +101,8 @@ state = multi_turn_question.run(
for m in state.messages(): for m in state.messages():
print(m["role"], ":", m["content"]) print(m["role"], ":", m["content"])
print(state["answer_1"])
``` ```
### More Examples ### More Examples

View File

@@ -9,8 +9,8 @@ class GenerateReqInput:
image_data: Optional[Union[List[str], str]] = None image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None return_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False stream: bool = False
``` ```

View File

@@ -0,0 +1,42 @@
"""
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
"""
import sglang as sgl
@sgl.function
def tool_use(s, question):
s += "To answer this question: " + question + ", "
s += "I need to use a " + sgl.gen("tool", choices=["calculator", "search engine"])
def main():
# Run one case
question = "What is 5 + 5?"
state = tool_use.run(question)
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print('-' * 50)
# Run a batch
questions = [
"What is 5 + 6?",
"Who is Michael Jordan?",
]
states = tool_use.run_batch([{"question": q} for q in questions])
for question, state in zip(questions, states):
print("questions:", question)
print("choice:", state["tool"])
meta_info = state.get_meta_info("tool")
print("logprobs of choice 1", meta_info["prompt_logprob"][0])
print("logprobs of choice 2", meta_info["prompt_logprob"][1])
print('-' * 50)
if __name__ == "__main__":
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))
main()

View File

@@ -209,7 +209,7 @@ class OpenAI(BaseBackend):
prompt_tokens.append(ret_token) prompt_tokens.append(ret_token)
decision = choices[np.argmax(scores)] decision = choices[np.argmax(scores)]
return decision, scores return decision, scores, scores
def openai_completion(client, is_chat=None, prompt=None, **kwargs): def openai_completion(client, is_chat=None, prompt=None, **kwargs):

View File

@@ -150,16 +150,20 @@ class RuntimeEndpoint(BaseBackend):
data = { data = {
"text": [s.text_ + c for c in choices], "text": [s.text_ + c for c in choices],
"sampling_params": {"max_new_tokens": 0}, "sampling_params": {"max_new_tokens": 0},
"return_normalized_logprob": True, "return_logprob": True,
"normalized_logprob_start_len": prompt_len, "logprob_start_len": max(prompt_len - 2, 0),
} }
self._add_images(s, data) self._add_images(s, data)
res = http_request(self.base_url + "/generate", json=data) res = http_request(self.base_url + "/generate", json=data)
assert res.status_code == 200 assert res.status_code == 200
logps = [r["meta_info"]["normalized_logprob"] for r in res.json()] obj = res.json()
normalized_prompt_logprob = [
r["meta_info"]["normalized_prompt_logprob"] for r in obj
]
prompt_logprob = [r["meta_info"]["prompt_logprob"] for r in obj]
decision = choices[np.argmax(logps)] decision = choices[np.argmax(normalized_prompt_logprob)]
return decision, logps return decision, normalized_prompt_logprob, prompt_logprob
def concatenate_and_append(self, src_rids: List[str], dst_rid: str): def concatenate_and_append(self, src_rids: List[str], dst_rid: str):
res = http_request( res = http_request(

View File

@@ -1,6 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from enum import Enum, auto from enum import Enum, auto
from typing import Callable, Dict, List, Tuple, Optional from typing import Callable, Dict, List, Optional, Tuple
class ChatTemplateStyle(Enum): class ChatTemplateStyle(Enum):
@@ -111,7 +111,7 @@ register_chat_template(
"assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"), "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
}, },
style=ChatTemplateStyle.PLAIN, style=ChatTemplateStyle.PLAIN,
stop_str=('<|im_end|>',) stop_str=("<|im_end|>",),
) )
) )

View File

@@ -80,7 +80,7 @@ def run_program_batch(
# Run all programs # Run all programs
if num_threads == "auto": if num_threads == "auto":
num_threads = max(64, multiprocessing.cpu_count() * 8) num_threads = max(96, multiprocessing.cpu_count() * 16)
num_threads = min(num_threads, len(batch_arguments)) num_threads = min(num_threads, len(batch_arguments))
if num_threads == 1: if num_threads == 1:
@@ -364,10 +364,16 @@ class StreamExecutor:
self.stream_var_event[name].set() self.stream_var_event[name].set()
def _execute_select(self, expr: SglSelect): def _execute_select(self, expr: SglSelect):
decision, scores = self.backend.select(self, expr.choices, expr.temperature) decision, normalized_prompt_logprob, prompt_logprob = self.backend.select(
self, expr.choices, expr.temperature
)
if expr.name is not None: if expr.name is not None:
name = expr.name name = expr.name
self.variables[name] = decision self.variables[name] = decision
self.meta_info[name] = {
"normalized_prompt_logprob": normalized_prompt_logprob,
"prompt_logprob": prompt_logprob,
}
self.variable_event[name].set() self.variable_event[name].set()
self.text_ += decision self.text_ += decision

View File

@@ -14,7 +14,7 @@ class LogitsProcessor(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
def forward(self, input_ids, hidden_states, weight, input_metadata): def forward(self, input_ids, hidden_states, weight, input_metadata):
if not input_metadata.return_normalized_logprob: if not input_metadata.return_logprob:
if input_metadata.forward_mode == ForwardMode.DECODE: if input_metadata.forward_mode == ForwardMode.DECODE:
last_hidden = hidden_states last_hidden = hidden_states
else: else:
@@ -33,7 +33,7 @@ class LogitsProcessor(nn.Module):
if self.tp_size > 1: if self.tp_size > 1:
last_logits = tensor_model_parallel_all_gather(last_logits) last_logits = tensor_model_parallel_all_gather(last_logits)
last_logits = last_logits[:, : self.config.vocab_size] last_logits = last_logits[:, : self.config.vocab_size]
return last_logits, None return last_logits, (None, None)
else: else:
assert input_metadata.forward_mode != ForwardMode.DECODE assert input_metadata.forward_mode != ForwardMode.DECODE
last_index = ( last_index = (
@@ -51,30 +51,23 @@ class LogitsProcessor(nn.Module):
logits = logits[:, : self.config.vocab_size] logits = logits[:, : self.config.vocab_size]
all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6) all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
normalized_logprobs = compute_normalized_logprobs(
all_logprobs,
input_ids,
input_metadata.extend_seq_lens,
input_metadata.extend_start_loc,
)
last_logits = logits[last_index]
return last_logits, normalized_logprobs
def compute_normalized_logprobs(all_logprobs, input_ids, seq_lens, start_loc):
logprobs = all_logprobs[ logprobs = all_logprobs[
torch.arange(all_logprobs.shape[0], device="cuda"), torch.arange(all_logprobs.shape[0], device="cuda"),
torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
] ]
logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32) logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
start = start_loc.clone() start = input_metadata.extend_start_loc.clone()
end = start + seq_lens - 2 end = start + input_metadata.extend_seq_lens - 2
start.clamp_(min=0, max=logprobs.shape[0] - 1) start.clamp_(min=0, max=logprobs.shape[0] - 1)
end.clamp_(min=0, max=logprobs.shape[0] - 1) end.clamp_(min=0, max=logprobs.shape[0] - 1)
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start] sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
return sum_logp / ((seq_lens - 1).clamp(min=1)) normalized_logprobs = sum_logp / (
(input_metadata.extend_seq_lens - 1).clamp(min=1)
)
last_logits = logits[last_index]
return last_logits, (logprobs, normalized_logprobs)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -11,8 +11,8 @@ class GenerateReqInput:
image_data: Optional[Union[List[str], str]] = None image_data: Optional[Union[List[str], str]] = None
sampling_params: Union[List[Dict], Dict] = None sampling_params: Union[List[Dict], Dict] = None
rid: Optional[Union[List[str], str]] = None rid: Optional[Union[List[str], str]] = None
return_normalized_logprob: Optional[Union[List[bool], bool]] = None return_logprob: Optional[Union[List[bool], bool]] = None
normalized_logprob_start_len: Optional[Union[List[int], int]] = None logprob_start_len: Optional[Union[List[int], int]] = None
stream: bool = False stream: bool = False
def post_init(self): def post_init(self):
@@ -23,10 +23,10 @@ class GenerateReqInput:
self.sampling_params = {} self.sampling_params = {}
if self.rid is None: if self.rid is None:
self.rid = uuid.uuid4().hex self.rid = uuid.uuid4().hex
if self.return_normalized_logprob is None: if self.return_logprob is None:
self.return_normalized_logprob = False self.return_logprob = False
if self.normalized_logprob_start_len is None: if self.logprob_start_len is None:
self.normalized_logprob_start_len = 0 self.logprob_start_len = 0
else: else:
num = len(self.text) num = len(self.text)
@@ -45,17 +45,15 @@ class GenerateReqInput:
else: else:
assert isinstance(self.rid, list) assert isinstance(self.rid, list)
if self.return_normalized_logprob is None: if self.return_logprob is None:
self.return_normalized_logprob = [False] * num self.return_logprob = [False] * num
elif not isinstance(self.return_normalized_logprob, list): elif not isinstance(self.return_logprob, list):
self.return_normalized_logprob = [self.return_normalized_logprob] * num self.return_logprob = [self.return_logprob] * num
if self.normalized_logprob_start_len is None: if self.logprob_start_len is None:
self.normalized_logprob_start_len = [0] * num self.logprob_start_len = [0] * num
elif not isinstance(self.normalized_logprob_start_len, list): elif not isinstance(self.logprob_start_len, list):
self.normalized_logprob_start_len = [ self.logprob_start_len = [self.logprob_start_len] * num
self.normalized_logprob_start_len
] * num
@dataclass @dataclass
@@ -65,8 +63,8 @@ class TokenizedGenerateReqInput:
pixel_values: List[float] pixel_values: List[float]
image_hash: int image_hash: int
sampling_params: SamplingParams sampling_params: SamplingParams
return_normalized_logprob: bool return_logprob: bool
normalized_logprob_start_len: int logprob_start_len: int
stream: bool stream: bool

View File

@@ -28,8 +28,8 @@ class Req:
self.pixel_values = None self.pixel_values = None
self.image_offset = 0 self.image_offset = 0
self.sampling_params = None self.sampling_params = None
self.return_normalized_logprob = False self.return_logprob = False
self.normalized_logprob_start_len = 0 self.logprob_start_len = 0
self.stream = False self.stream = False
self.tokenizer = None self.tokenizer = None
@@ -37,10 +37,11 @@ class Req:
self.finish_reason = None self.finish_reason = None
self.hit_stop_str = None self.hit_stop_str = None
self.adjust_input_len = 0 self.extend_input_len = 0
self.prefix_indices = [] self.prefix_indices = []
self.last_node = None self.last_node = None
self.logprob = None
self.normalized_logprob = None self.normalized_logprob = None
# for constrained decoding # for constrained decoding
@@ -99,7 +100,7 @@ class Batch:
out_cache_loc: torch.Tensor = None out_cache_loc: torch.Tensor = None
out_cache_cont_start: torch.Tensor = None out_cache_cont_start: torch.Tensor = None
out_cache_cont_end: torch.Tensor = None out_cache_cont_end: torch.Tensor = None
return_normalized_logprob: bool = False return_logprob: bool = False
# for multimodal # for multimodal
pixel_values: List[torch.Tensor] = None pixel_values: List[torch.Tensor] = None
@@ -119,14 +120,14 @@ class Batch:
@classmethod @classmethod
def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache): def init_new(cls, reqs, req_to_token_pool, token_to_kv_pool, tree_cache):
return_normalized_logprob = any(req.return_normalized_logprob for req in reqs) return_logprob = any(req.return_logprob for req in reqs)
return cls( return cls(
reqs=reqs, reqs=reqs,
req_to_token_pool=req_to_token_pool, req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool, token_to_kv_pool=token_to_kv_pool,
tree_cache=tree_cache, tree_cache=tree_cache,
return_normalized_logprob=return_normalized_logprob, return_logprob=return_logprob,
) )
def is_empty(self): def is_empty(self):
@@ -257,7 +258,7 @@ class Batch:
self.tree_cache.dec_ref_counter(req.last_node) self.tree_cache.dec_ref_counter(req.last_node)
req.prefix_indices = None req.prefix_indices = None
req.last_node = None req.last_node = None
req.adjust_input_len = 0 req.extend_input_len = 0
req.output_ids = [] req.output_ids = []
# TODO: apply more fine-grained retraction # TODO: apply more fine-grained retraction
@@ -310,9 +311,7 @@ class Batch:
self.prefix_lens = None self.prefix_lens = None
self.position_ids_offsets = self.position_ids_offsets[new_indices] self.position_ids_offsets = self.position_ids_offsets[new_indices]
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any( self.return_logprob = any(req.return_logprob for req in self.reqs)
req.return_normalized_logprob for req in self.reqs
)
for item in [ for item in [
"temperatures", "temperatures",
@@ -336,9 +335,7 @@ class Batch:
[self.position_ids_offsets, other.position_ids_offsets] [self.position_ids_offsets, other.position_ids_offsets]
) )
self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
self.return_normalized_logprob = any( self.return_logprob = any(req.return_logprob for req in self.reqs)
req.return_normalized_logprob for req in self.reqs
)
for item in [ for item in [
"temperatures", "temperatures",

View File

@@ -214,8 +214,8 @@ class ModelRpcServer(rpyc.Service):
req.input_ids, pad_value req.input_ids, pad_value
) )
req.sampling_params = recv_req.sampling_params req.sampling_params = recv_req.sampling_params
req.return_normalized_logprob = recv_req.return_normalized_logprob req.return_logprob = recv_req.return_logprob
req.normalized_logprob_start_len = recv_req.normalized_logprob_start_len req.logprob_start_len = recv_req.logprob_start_len
req.stream = recv_req.stream req.stream = recv_req.stream
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
@@ -240,9 +240,9 @@ class ModelRpcServer(rpyc.Service):
for req in self.forward_queue: for req in self.forward_queue:
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids) prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
if req.return_normalized_logprob: if req.return_logprob:
prefix_indices = prefix_indices[: req.normalized_logprob_start_len] prefix_indices = prefix_indices[: req.logprob_start_len]
req.adjust_input_len = len(req.input_ids) - len(prefix_indices) req.extend_input_len = len(req.input_ids) - len(prefix_indices)
req.prefix_indices = prefix_indices req.prefix_indices = prefix_indices
req.last_node = last_node req.last_node = last_node
@@ -267,32 +267,32 @@ class ModelRpcServer(rpyc.Service):
) )
for req in self.forward_queue: for req in self.forward_queue:
if req.return_normalized_logprob: if req.return_logprob:
# Need at least two tokens to compute normalized logprob # Need at least two tokens to compute normalized logprob
if req.adjust_input_len < 2: if req.extend_input_len < 2:
delta = 2 - req.adjust_input_len delta = 2 - req.extend_input_len
req.adjust_input_len += delta req.extend_input_len += delta
req.prefix_indices = req.prefix_indices[:-delta] req.prefix_indices = req.prefix_indices[:-delta]
if req.image_offset is not None: if req.image_offset is not None:
req.image_offset += delta req.image_offset += delta
if req.adjust_input_len == 0 and req.max_new_tokens() > 0: if req.extend_input_len == 0 and req.max_new_tokens() > 0:
# Need at least one token to compute logits # Need at least one token to compute logits
req.adjust_input_len = 1 req.extend_input_len = 1
req.prefix_indices = req.prefix_indices[:-1] req.prefix_indices = req.prefix_indices[:-1]
if req.image_offset is not None: if req.image_offset is not None:
req.image_offset += 1 req.image_offset += 1
if ( if (
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size < available_size
and req.adjust_input_len + new_batch_input_tokens and req.extend_input_len + new_batch_input_tokens
< self.max_prefill_num_token < self.max_prefill_num_token
): ):
delta = self.tree_cache.inc_ref_counter(req.last_node) delta = self.tree_cache.inc_ref_counter(req.last_node)
available_size += delta available_size += delta
if not ( if not (
req.adjust_input_len + req.max_new_tokens() + new_batch_total_tokens req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size < available_size
): ):
delta = self.tree_cache.dec_ref_counter(req.last_node) delta = self.tree_cache.dec_ref_counter(req.last_node)
@@ -301,9 +301,9 @@ class ModelRpcServer(rpyc.Service):
self.token_to_kv_pool.add_refs(req.prefix_indices) self.token_to_kv_pool.add_refs(req.prefix_indices)
can_run_list.append(req) can_run_list.append(req)
new_batch_total_tokens += ( new_batch_total_tokens += (
req.adjust_input_len + req.max_new_tokens() req.extend_input_len + req.max_new_tokens()
) )
new_batch_input_tokens += req.adjust_input_len new_batch_input_tokens += req.extend_input_len
if len(can_run_list) == 0: if len(can_run_list) == 0:
return None return None
@@ -339,27 +339,31 @@ class ModelRpcServer(rpyc.Service):
if batch.extend_num_tokens != 0: if batch.extend_num_tokens != 0:
# Forward # Forward
logits, normalized_logprobs = self.model_runner.forward( logits, (logprobs, normalized_logprobs) = self.model_runner.forward(
batch, ForwardMode.EXTEND, batch.return_normalized_logprob batch, ForwardMode.EXTEND, batch.return_logprob
) )
# print("extend logits", logits) # print("extend logits", logits)
if normalized_logprobs is not None: if logprobs is not None:
logprobs = logprobs.cpu().tolist()
normalized_logprobs = normalized_logprobs.cpu().tolist() normalized_logprobs = normalized_logprobs.cpu().tolist()
next_token_ids, next_token_probs = batch.sample(logits) next_token_ids, next_token_probs = batch.sample(logits)
next_token_ids = next_token_ids.cpu().tolist() next_token_ids = next_token_ids.cpu().tolist()
else: else:
next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
normalized_logprobs = None logprobs = normalized_logprobs = None
# Check finish condition # Check finish condition
reqs = batch.reqs reqs = batch.reqs
for i in range(len(reqs)): pt = 0
reqs[i].output_ids = [next_token_ids[i]] for i, req in enumerate(reqs):
reqs[i].check_finished() req.output_ids = [next_token_ids[i]]
req.check_finished()
if normalized_logprobs is not None: if logprobs is not None:
reqs[i].normalized_logprob = normalized_logprobs[i] req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
req.normalized_logprob = normalized_logprobs[i]
pt += req.extend_input_len
self.handle_finished_requests(batch) self.handle_finished_requests(batch)
@@ -427,8 +431,9 @@ class ModelRpcServer(rpyc.Service):
"prompt_tokens": len(req.input_ids), "prompt_tokens": len(req.input_ids),
"completion_tokens": len(req.output_ids), "completion_tokens": len(req.output_ids),
} }
if req.return_normalized_logprob: if req.return_logprob:
meta_info["normalized_logprob"] = req.normalized_logprob meta_info["prompt_logprob"] = req.logprob
meta_info["normalized_prompt_logprob"] = req.normalized_logprob
output_meta_info.append(meta_info) output_meta_info.append(meta_info)
output_finished.append(req.finished) output_finished.append(req.finished)

View File

@@ -45,7 +45,7 @@ class InputMetadata:
out_cache_cont_end: torch.Tensor = None out_cache_cont_end: torch.Tensor = None
other_kv_index: torch.Tensor = None other_kv_index: torch.Tensor = None
return_normalized_logprob: bool = False return_logprob: bool = False
# for flashinfer # for flashinfer
use_flashinfer: bool = False use_flashinfer: bool = False
@@ -127,7 +127,7 @@ class InputMetadata:
out_cache_loc, out_cache_loc,
out_cache_cont_start=None, out_cache_cont_start=None,
out_cache_cont_end=None, out_cache_cont_end=None,
return_normalized_logprob=False, return_logprob=False,
): ):
batch_size = len(req_pool_indices) batch_size = len(req_pool_indices)
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda") start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -175,7 +175,7 @@ class InputMetadata:
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
out_cache_cont_start=out_cache_cont_start, out_cache_cont_start=out_cache_cont_start,
out_cache_cont_end=out_cache_cont_end, out_cache_cont_end=out_cache_cont_end,
return_normalized_logprob=return_normalized_logprob, return_logprob=return_logprob,
other_kv_index=other_kv_index, other_kv_index=other_kv_index,
) )
@@ -337,7 +337,7 @@ class ModelRunner:
prefix_lens, prefix_lens,
position_ids_offsets, position_ids_offsets,
out_cache_loc, out_cache_loc,
return_normalized_logprob, return_logprob,
): ):
input_metadata = InputMetadata.create( input_metadata = InputMetadata.create(
self, self,
@@ -348,7 +348,7 @@ class ModelRunner:
prefix_lens=prefix_lens, prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets, position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob, return_logprob=return_logprob,
) )
return self.model.forward(input_ids, input_metadata.positions, input_metadata) return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@@ -361,7 +361,7 @@ class ModelRunner:
prefix_lens, prefix_lens,
position_ids_offsets, position_ids_offsets,
out_cache_loc, out_cache_loc,
return_normalized_logprob, return_logprob,
): ):
input_metadata = InputMetadata.create( input_metadata = InputMetadata.create(
self, self,
@@ -372,7 +372,7 @@ class ModelRunner:
prefix_lens=prefix_lens, prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets, position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob, return_logprob=return_logprob,
) )
return self.model.forward(input_ids, input_metadata.positions, input_metadata) return self.model.forward(input_ids, input_metadata.positions, input_metadata)
@@ -415,7 +415,7 @@ class ModelRunner:
prefix_lens, prefix_lens,
position_ids_offsets, position_ids_offsets,
out_cache_loc, out_cache_loc,
return_normalized_logprob, return_logprob,
): ):
input_metadata = InputMetadata.create( input_metadata = InputMetadata.create(
self, self,
@@ -426,7 +426,7 @@ class ModelRunner:
prefix_lens=prefix_lens, prefix_lens=prefix_lens,
position_ids_offsets=position_ids_offsets, position_ids_offsets=position_ids_offsets,
out_cache_loc=out_cache_loc, out_cache_loc=out_cache_loc,
return_normalized_logprob=return_normalized_logprob, return_logprob=return_logprob,
) )
return self.model.forward( return self.model.forward(
input_ids, input_ids,
@@ -436,9 +436,7 @@ class ModelRunner:
image_offsets, image_offsets,
) )
def forward( def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
self, batch: Batch, forward_mode: ForwardMode, return_normalized_logprob=False
):
if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
kwargs = { kwargs = {
"input_ids": batch.input_ids, "input_ids": batch.input_ids,
@@ -450,7 +448,7 @@ class ModelRunner:
"position_ids_offsets": batch.position_ids_offsets, "position_ids_offsets": batch.position_ids_offsets,
"out_cache_loc": batch.out_cache_loc, "out_cache_loc": batch.out_cache_loc,
} }
kwargs["return_normalized_logprob"] = return_normalized_logprob kwargs["return_logprob"] = return_logprob
return self.forward_extend_multi_modal(**kwargs) return self.forward_extend_multi_modal(**kwargs)
else: else:
kwargs = { kwargs = {
@@ -467,10 +465,10 @@ class ModelRunner:
kwargs["out_cache_cont_end"] = batch.out_cache_cont_end kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
return self.forward_decode(**kwargs) return self.forward_decode(**kwargs)
elif forward_mode == ForwardMode.EXTEND: elif forward_mode == ForwardMode.EXTEND:
kwargs["return_normalized_logprob"] = return_normalized_logprob kwargs["return_logprob"] = return_logprob
return self.forward_extend(**kwargs) return self.forward_extend(**kwargs)
elif forward_mode == ForwardMode.PREFILL: elif forward_mode == ForwardMode.PREFILL:
kwargs["return_normalized_logprob"] = return_normalized_logprob kwargs["return_logprob"] = return_logprob
return self.forward_prefill(**kwargs) return self.forward_prefill(**kwargs)
else: else:
raise ValueError(f"Invaid forward mode: {forward_mode}") raise ValueError(f"Invaid forward mode: {forward_mode}")

View File

@@ -132,8 +132,8 @@ class TokenizerManager:
pixel_values=pixel_values, pixel_values=pixel_values,
image_hash=image_hash, image_hash=image_hash,
sampling_params=sampling_params, sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob, return_logprob=obj.return_logprob,
normalized_logprob_start_len=obj.normalized_logprob_start_len, logprob_start_len=obj.logprob_start_len,
stream=obj.stream, stream=obj.stream,
) )
self.send_to_router.send_pyobj(tokenized_obj) self.send_to_router.send_pyobj(tokenized_obj)
@@ -173,8 +173,8 @@ class TokenizerManager:
pixel_values=pixel_values, pixel_values=pixel_values,
image_hash=image_hash, image_hash=image_hash,
sampling_params=sampling_params, sampling_params=sampling_params,
return_normalized_logprob=obj.return_normalized_logprob[i], return_logprob=obj.return_logprob[i],
normalized_logprob_start_len=obj.normalized_logprob_start_len[i], logprob_start_len=obj.logprob_start_len[i],
stream=obj.stream, stream=obj.stream,
) )
self.send_to_router.send_pyobj(tokenized_obj) self.send_to_router.send_pyobj(tokenized_obj)

View File

@@ -26,6 +26,8 @@ if __name__ == "__main__":
"temperature": 0, "temperature": 0,
"max_new_tokens": 32, "max_new_tokens": 32,
}, },
# "return_logprob": True,
# "logprob_start_len": 0,
}, },
) )
print(response.json()) print(response.json())