Optimize retract (#440)
This commit is contained in:
@@ -28,5 +28,11 @@ class GlobalConfig:
|
||||
# Request dependency time due to network delay
|
||||
self.request_dependency_time = 0.03
|
||||
|
||||
# New generation token ratio estimation
|
||||
self.base_new_token_ratio = 0.4
|
||||
self.base_min_new_token_ratio = 0.2
|
||||
self.new_token_ratio_decay = 0.0001
|
||||
self.new_token_ratio_recovery = 0.05
|
||||
|
||||
|
||||
global_config = GlobalConfig()
|
||||
|
||||
@@ -50,21 +50,22 @@ class LogitsProcessor(nn.Module):
|
||||
prefill_top_logprobs, decode_top_logprobs = [], []
|
||||
pt = 0
|
||||
# NOTE: the GPU-CPU overhead can be reduced
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
|
||||
for i in range(len(extend_seq_lens_cpu)):
|
||||
if extend_seq_lens_cpu[i] == 0:
|
||||
extend_seq_lens_cpu = input_metadata.extend_seq_lens.tolist()
|
||||
for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
|
||||
if extend_seq_len == 0:
|
||||
prefill_top_logprobs.append([])
|
||||
decode_top_logprobs.append([])
|
||||
continue
|
||||
k = input_metadata.top_logprobs_nums[i]
|
||||
t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
|
||||
t = all_logprobs[pt : pt + extend_seq_len].topk(k)
|
||||
vs_cpu = t.values.tolist()
|
||||
ps_cpu = t.indices.tolist()
|
||||
prefill_top_logprobs.append(
|
||||
[list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
|
||||
)
|
||||
decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
|
||||
pt += extend_seq_lens_cpu[i]
|
||||
pt += extend_seq_len
|
||||
|
||||
return prefill_top_logprobs, decode_top_logprobs
|
||||
|
||||
def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
|
||||
@@ -145,7 +146,7 @@ class LogitsProcessor(nn.Module):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def test():
|
||||
all_logprobs = torch.tensor(
|
||||
# s s s
|
||||
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
|
||||
@@ -173,3 +174,7 @@ if __name__ == "__main__":
|
||||
print("start", start)
|
||||
print("end", end)
|
||||
print("sum_logp", sum_logp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
||||
|
||||
@@ -51,11 +51,6 @@ class DetokenizerManager:
|
||||
# Trim stop str
|
||||
# TODO(lmzheng): handle the case where multiple stop strs are hit
|
||||
for i in range(len(output_strs)):
|
||||
if recv_obj.hit_stop_str[i] is not None:
|
||||
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
|
||||
if pos != -1:
|
||||
output_strs[i] = output_strs[i][:pos]
|
||||
|
||||
if len(output_tokens[i]) > 0:
|
||||
first_token = self.tokenizer.convert_ids_to_tokens(
|
||||
int(output_tokens[i][0])
|
||||
@@ -65,9 +60,12 @@ class DetokenizerManager:
|
||||
if first_token.startswith("▁"):
|
||||
output_strs[i] = " " + output_strs[i]
|
||||
|
||||
output_strs[i] = (
|
||||
recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
|
||||
)
|
||||
output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
|
||||
|
||||
if recv_obj.hit_stop_str[i] is not None:
|
||||
pos = output_strs[i].find(recv_obj.hit_stop_str[i])
|
||||
if pos != -1:
|
||||
output_strs[i] = output_strs[i][:pos]
|
||||
|
||||
self.send_to_tokenizer.send_pyobj(
|
||||
BatchStrOut(
|
||||
|
||||
@@ -106,8 +106,8 @@ class TokenizedGenerateReqInput:
|
||||
@dataclass
|
||||
class BatchTokenIDOut:
|
||||
rids: List[str]
|
||||
prev_output_strs : List[str]
|
||||
output_tokens: List[List[int]]
|
||||
output_and_jump_forward_strs: List[str]
|
||||
hit_stop_str: List[Optional[str]]
|
||||
skip_special_tokens: List[bool]
|
||||
spaces_between_special_tokens: List[bool]
|
||||
|
||||
@@ -36,15 +36,15 @@ class FinishReason(IntEnum):
|
||||
|
||||
|
||||
class Req:
|
||||
def __init__(self, rid, input_text, input_ids):
|
||||
def __init__(self, rid, origin_input_text, origin_input_ids):
|
||||
self.rid = rid
|
||||
self.input_text = input_text
|
||||
self.input_ids = input_ids
|
||||
self.origin_input_text = origin_input_text
|
||||
self.origin_input_ids = origin_input_ids
|
||||
self.origin_input_ids_unpadded = origin_input_ids # before image padding
|
||||
self.prev_output_str = ""
|
||||
self.prev_output_ids = []
|
||||
self.output_ids = []
|
||||
|
||||
# Since jump forward may retokenize the prompt with partial outputs,
|
||||
# we maintain the original prompt length to report the correct usage.
|
||||
self.prompt_tokens = len(input_ids)
|
||||
self.input_ids = None # input_ids = origin_input_ids + prev_output_ids
|
||||
|
||||
# The number of decoded tokens for token usage report. Note that
|
||||
# this does not include the jump forward tokens.
|
||||
@@ -76,15 +76,24 @@ class Req:
|
||||
self.top_logprobs_num = 0
|
||||
self.normalized_prompt_logprob = None
|
||||
self.prefill_token_logprobs = None
|
||||
self.decode_token_logprobs = None
|
||||
self.decode_token_logprobs = []
|
||||
self.prefill_top_logprobs = None
|
||||
self.decode_top_logprobs = None
|
||||
self.decode_top_logprobs = []
|
||||
# The tokens is prefilled but need to be considered as decode tokens
|
||||
# and should be updated for the decode logprobs
|
||||
self.last_update_decode_tokens = 0
|
||||
|
||||
# Constrained decoding
|
||||
self.regex_fsm = None
|
||||
self.regex_fsm_state = 0
|
||||
self.jump_forward_map = None
|
||||
self.output_and_jump_forward_str = ""
|
||||
|
||||
def partial_decode(self, ids):
|
||||
first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
|
||||
first_token = (
|
||||
first_token.decode() if isinstance(first_token, bytes) else first_token
|
||||
)
|
||||
return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
|
||||
|
||||
def max_new_tokens(self):
|
||||
return self.sampling_params.max_new_tokens
|
||||
@@ -93,7 +102,10 @@ class Req:
|
||||
if self.finished:
|
||||
return
|
||||
|
||||
if len(self.output_ids) >= self.sampling_params.max_new_tokens:
|
||||
if (
|
||||
len(self.prev_output_ids) + len(self.output_ids)
|
||||
>= self.sampling_params.max_new_tokens
|
||||
):
|
||||
self.finished = True
|
||||
self.finish_reason = FinishReason.LENGTH
|
||||
return
|
||||
@@ -112,60 +124,66 @@ class Req:
|
||||
)
|
||||
|
||||
for stop_str in self.sampling_params.stop_strs:
|
||||
if stop_str in tail_str:
|
||||
# FIXME: (minor) try incremental match in prev_output_str
|
||||
if stop_str in tail_str or stop_str in self.prev_output_str:
|
||||
self.finished = True
|
||||
self.finish_reason = FinishReason.STOP_STR
|
||||
self.hit_stop_str = stop_str
|
||||
return
|
||||
|
||||
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
|
||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||
# FIXME: This logic does not really solve the problem of determining whether
|
||||
# there should be a leading space.
|
||||
first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
|
||||
first_token = (
|
||||
first_token.decode() if isinstance(first_token, bytes) else first_token
|
||||
)
|
||||
if first_token.startswith("▁"):
|
||||
old_output_str = " " + old_output_str
|
||||
if self.input_text is None:
|
||||
# TODO(lmzheng): This can be wrong. Check with Liangsheng.
|
||||
self.input_text = self.tokenizer.decode(self.input_ids)
|
||||
new_input_string = (
|
||||
self.input_text
|
||||
+ self.output_and_jump_forward_str
|
||||
+ old_output_str
|
||||
cur_output_str = self.partial_decode(self.output_ids)
|
||||
|
||||
# TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
|
||||
if self.origin_input_text is None:
|
||||
# Recovering text can only use unpadded ids
|
||||
self.origin_input_text = self.tokenizer.decode(
|
||||
self.origin_input_ids_unpadded
|
||||
)
|
||||
|
||||
all_text = (
|
||||
self.origin_input_text
|
||||
+ self.prev_output_str
|
||||
+ cur_output_str
|
||||
+ jump_forward_str
|
||||
)
|
||||
new_input_ids = self.tokenizer.encode(new_input_string)
|
||||
if self.pixel_values is not None:
|
||||
# NOTE: This is a hack because the old input_ids contains the image padding
|
||||
jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
|
||||
else:
|
||||
jump_forward_tokens_len = (
|
||||
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
||||
)
|
||||
all_ids = self.tokenizer.encode(all_text)
|
||||
prompt_tokens = len(self.origin_input_ids_unpadded)
|
||||
self.origin_input_ids = all_ids[:prompt_tokens]
|
||||
self.origin_input_ids_unpadded = self.origin_input_ids
|
||||
# NOTE: the output ids may not strictly correspond to the output text
|
||||
old_prev_output_ids = self.prev_output_ids
|
||||
self.prev_output_ids = all_ids[prompt_tokens:]
|
||||
self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
|
||||
self.output_ids = []
|
||||
|
||||
self.regex_fsm_state = next_state
|
||||
|
||||
if self.return_logprob:
|
||||
# For fast-forward part's logprobs
|
||||
k = 0
|
||||
for i, old_id in enumerate(old_prev_output_ids):
|
||||
if old_id == self.prev_output_ids[i]:
|
||||
k = k + 1
|
||||
else:
|
||||
break
|
||||
self.decode_token_logprobs = self.decode_token_logprobs[:k]
|
||||
self.decode_top_logprobs = self.decode_top_logprobs[:k]
|
||||
self.logprob_start_len = prompt_tokens + k
|
||||
self.last_update_decode_tokens = len(self.prev_output_ids) - k
|
||||
|
||||
# print("=" * 100)
|
||||
# print(f"Catch jump forward:\n{jump_forward_str}")
|
||||
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
||||
# print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
|
||||
|
||||
self.input_ids = new_input_ids
|
||||
self.output_ids = []
|
||||
self.sampling_params.max_new_tokens = max(
|
||||
self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
|
||||
)
|
||||
self.regex_fsm_state = next_state
|
||||
self.output_and_jump_forward_str = (
|
||||
self.output_and_jump_forward_str + old_output_str + jump_forward_str
|
||||
)
|
||||
|
||||
# print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
|
||||
# print("*" * 100)
|
||||
|
||||
def __repr__(self):
|
||||
return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
|
||||
return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -336,6 +354,7 @@ class Batch:
|
||||
|
||||
def retract_decode(self):
|
||||
sorted_indices = [i for i in range(len(self.reqs))]
|
||||
# TODO(lsyin): improve the priority of retraction
|
||||
sorted_indices.sort(
|
||||
key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
|
||||
reverse=True,
|
||||
@@ -356,18 +375,27 @@ class Batch:
|
||||
][last_uncached_pos : seq_lens_cpu[idx]]
|
||||
self.token_to_kv_pool.dec_refs(token_indices)
|
||||
|
||||
# release the last node
|
||||
self.tree_cache.dec_lock_ref(req.last_node)
|
||||
|
||||
cur_output_str = req.partial_decode(req.output_ids)
|
||||
req.prev_output_str = req.prev_output_str + cur_output_str
|
||||
req.prev_output_ids.extend(req.output_ids)
|
||||
|
||||
req.prefix_indices = None
|
||||
req.last_node = None
|
||||
req.extend_input_len = 0
|
||||
req.output_ids = []
|
||||
req.regex_fsm_state = 0
|
||||
|
||||
# For incremental logprobs
|
||||
req.last_update_decode_tokens = 0
|
||||
req.logprob_start_len = 10**9
|
||||
|
||||
self.filter_batch(sorted_indices)
|
||||
|
||||
return retracted_reqs
|
||||
|
||||
def check_for_jump_forward(self):
|
||||
def check_for_jump_forward(self, model_runner):
|
||||
jump_forward_reqs = []
|
||||
filter_indices = [i for i in range(len(self.reqs))]
|
||||
|
||||
@@ -397,6 +425,18 @@ class Batch:
|
||||
# jump-forward
|
||||
req.jump_forward_and_retokenize(jump_forward_str, next_state)
|
||||
|
||||
# re-applying image padding
|
||||
if req.pixel_values is not None:
|
||||
(
|
||||
req.origin_input_ids,
|
||||
req.image_offset,
|
||||
) = model_runner.model.pad_input_ids(
|
||||
req.origin_input_ids_unpadded,
|
||||
req.pad_value,
|
||||
req.pixel_values.shape,
|
||||
req.image_size,
|
||||
)
|
||||
|
||||
jump_forward_reqs.append(req)
|
||||
filter_indices.remove(i)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import multiprocessing
|
||||
import time
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import rpyc
|
||||
import torch
|
||||
@@ -16,6 +16,7 @@ try:
|
||||
except ImportError:
|
||||
from vllm.logger import logger as vllm_default_logger
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.constrained.fsm_cache import FSMCache
|
||||
from sglang.srt.constrained.jump_forward import JumpForwardCache
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
@@ -106,7 +107,8 @@ class ModelRpcServer:
|
||||
set_random_seed(server_args.random_seed)
|
||||
|
||||
# Print info
|
||||
logger.info(f"[rank={self.tp_rank}] "
|
||||
logger.info(
|
||||
f"[rank={self.tp_rank}] "
|
||||
f"max_total_num_token={self.max_total_num_token}, "
|
||||
f"max_prefill_num_token={self.max_prefill_num_token}, "
|
||||
f"context_len={self.model_config.context_len}, "
|
||||
@@ -151,9 +153,20 @@ class ModelRpcServer:
|
||||
self.jump_forward_cache = JumpForwardCache()
|
||||
|
||||
# Init new token estimation
|
||||
self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
|
||||
self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
|
||||
self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
|
||||
assert (
|
||||
server_args.schedule_conservativeness >= 0
|
||||
), "Invalid schedule_conservativeness"
|
||||
self.new_token_ratio = min(
|
||||
global_config.base_new_token_ratio * server_args.schedule_conservativeness,
|
||||
1.0,
|
||||
)
|
||||
self.min_new_token_ratio = min(
|
||||
global_config.base_min_new_token_ratio
|
||||
* server_args.schedule_conservativeness,
|
||||
1.0,
|
||||
)
|
||||
self.new_token_ratio_decay = global_config.new_token_ratio_decay
|
||||
self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
|
||||
|
||||
def exposed_step(self, recv_reqs):
|
||||
if self.tp_size != 1:
|
||||
@@ -256,8 +269,13 @@ class ModelRpcServer:
|
||||
(recv_req.image_hash >> 64) % self.model_config.vocab_size,
|
||||
]
|
||||
req.image_size = recv_req.image_size
|
||||
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
|
||||
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
|
||||
req.origin_input_ids, req.image_offset = (
|
||||
self.model_runner.model.pad_input_ids(
|
||||
req.origin_input_ids_unpadded,
|
||||
req.pad_value,
|
||||
req.pixel_values.shape,
|
||||
req.image_size,
|
||||
)
|
||||
)
|
||||
req.sampling_params = recv_req.sampling_params
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
@@ -275,11 +293,11 @@ class ModelRpcServer:
|
||||
)
|
||||
|
||||
# Truncate prompts that are too long
|
||||
req.input_ids = req.input_ids[: self.model_config.context_len - 1]
|
||||
req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
|
||||
req.sampling_params.max_new_tokens = min(
|
||||
req.sampling_params.max_new_tokens,
|
||||
self.model_config.context_len - 1 - len(req.input_ids),
|
||||
self.max_total_num_token - 128 - len(req.input_ids),
|
||||
self.model_config.context_len - 1 - len(req.origin_input_ids),
|
||||
self.max_total_num_token - 128 - len(req.origin_input_ids),
|
||||
)
|
||||
self.forward_queue.append(req)
|
||||
|
||||
@@ -292,6 +310,10 @@ class ModelRpcServer:
|
||||
|
||||
# Compute matched prefix length
|
||||
for req in self.forward_queue:
|
||||
assert (
|
||||
len(req.output_ids) == 0
|
||||
), "The output ids should be empty when prefilling"
|
||||
req.input_ids = req.origin_input_ids + req.prev_output_ids
|
||||
prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
|
||||
if req.return_logprob:
|
||||
prefix_indices = prefix_indices[: req.logprob_start_len]
|
||||
@@ -319,7 +341,7 @@ class ModelRpcServer:
|
||||
)
|
||||
|
||||
for req in self.forward_queue:
|
||||
if req.return_logprob:
|
||||
if req.return_logprob and req.normalized_prompt_logprob is None:
|
||||
# Need at least two tokens to compute normalized logprob
|
||||
if req.extend_input_len < 2:
|
||||
delta = 2 - req.extend_input_len
|
||||
@@ -441,28 +463,53 @@ class ModelRpcServer:
|
||||
req.check_finished()
|
||||
|
||||
if req.return_logprob:
|
||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||
if req.normalized_prompt_logprob is None:
|
||||
req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
|
||||
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
req.prefill_token_logprobs = list(
|
||||
zip(
|
||||
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||
req.input_ids[-req.extend_input_len + 1 :],
|
||||
if req.prefill_token_logprobs is None:
|
||||
# If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
|
||||
req.prefill_token_logprobs = list(
|
||||
zip(
|
||||
prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
|
||||
req.input_ids[-req.extend_input_len + 1 :],
|
||||
)
|
||||
)
|
||||
)
|
||||
if req.logprob_start_len == 0:
|
||||
req.prefill_token_logprobs = [
|
||||
(None, req.input_ids[0])
|
||||
] + req.prefill_token_logprobs
|
||||
req.decode_token_logprobs = [
|
||||
if req.logprob_start_len == 0:
|
||||
req.prefill_token_logprobs = [
|
||||
(None, req.input_ids[0])
|
||||
] + req.prefill_token_logprobs
|
||||
|
||||
if req.last_update_decode_tokens != 0:
|
||||
req.decode_token_logprobs.extend(
|
||||
list(
|
||||
zip(
|
||||
prefill_token_logprobs[
|
||||
pt
|
||||
+ req.extend_input_len
|
||||
- req.last_update_decode_tokens : pt
|
||||
+ req.extend_input_len
|
||||
- 1
|
||||
],
|
||||
req.input_ids[-req.last_update_decode_tokens + 1 :],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
req.decode_token_logprobs.append(
|
||||
(last_token_logprobs[i], next_token_ids[i])
|
||||
]
|
||||
)
|
||||
|
||||
if req.top_logprobs_num > 0:
|
||||
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
||||
if req.logprob_start_len == 0:
|
||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||
req.decode_top_logprobs = [decode_top_logprobs[i]]
|
||||
if req.prefill_top_logprobs is None:
|
||||
req.prefill_top_logprobs = prefill_top_logprobs[i]
|
||||
if req.logprob_start_len == 0:
|
||||
req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
|
||||
|
||||
if req.last_update_decode_tokens != 0:
|
||||
req.decode_top_logprobs.extend(
|
||||
prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
|
||||
)
|
||||
req.decode_top_logprobs.append(decode_top_logprobs[i])
|
||||
|
||||
pt += req.extend_input_len
|
||||
|
||||
@@ -484,7 +531,7 @@ class ModelRpcServer:
|
||||
# check if decode out of memory
|
||||
if not batch.check_decode_mem():
|
||||
old_ratio = self.new_token_ratio
|
||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
|
||||
self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
|
||||
|
||||
retracted_reqs = batch.retract_decode()
|
||||
logger.info(
|
||||
@@ -495,26 +542,13 @@ class ModelRpcServer:
|
||||
self.forward_queue.extend(retracted_reqs)
|
||||
else:
|
||||
self.new_token_ratio = max(
|
||||
self.new_token_ratio - self.new_token_ratio_step[0],
|
||||
self.new_token_ratio - self.new_token_ratio_decay,
|
||||
self.min_new_token_ratio,
|
||||
)
|
||||
|
||||
if not self.disable_regex_jump_forward:
|
||||
# check for jump-forward
|
||||
jump_forward_reqs = batch.check_for_jump_forward()
|
||||
|
||||
# check for image jump-forward
|
||||
for req in jump_forward_reqs:
|
||||
if req.pixel_values is not None:
|
||||
(
|
||||
req.input_ids,
|
||||
req.image_offset,
|
||||
) = self.model_runner.model.pad_input_ids(
|
||||
req.input_ids,
|
||||
req.pad_value,
|
||||
req.pixel_values.shape,
|
||||
req.image_size,
|
||||
)
|
||||
jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
|
||||
|
||||
self.forward_queue.extend(jump_forward_reqs)
|
||||
if batch.is_empty():
|
||||
@@ -557,8 +591,8 @@ class ModelRpcServer:
|
||||
|
||||
def handle_finished_requests(self, batch: Batch):
|
||||
output_rids = []
|
||||
prev_output_strs = []
|
||||
output_tokens = []
|
||||
output_and_jump_forward_strs = []
|
||||
output_hit_stop_str = []
|
||||
output_skip_special_tokens = []
|
||||
output_spaces_between_special_tokens = []
|
||||
@@ -582,8 +616,8 @@ class ModelRpcServer:
|
||||
)
|
||||
):
|
||||
output_rids.append(req.rid)
|
||||
prev_output_strs.append(req.prev_output_str)
|
||||
output_tokens.append(req.output_ids)
|
||||
output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
|
||||
output_hit_stop_str.append(req.hit_stop_str)
|
||||
output_skip_special_tokens.append(
|
||||
req.sampling_params.skip_special_tokens
|
||||
@@ -593,10 +627,8 @@ class ModelRpcServer:
|
||||
)
|
||||
|
||||
meta_info = {
|
||||
"prompt_tokens": req.prompt_tokens,
|
||||
"completion_tokens": len(req.input_ids)
|
||||
+ len(req.output_ids)
|
||||
- req.prompt_tokens,
|
||||
"prompt_tokens": len(req.origin_input_ids),
|
||||
"completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
|
||||
"completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
|
||||
"finish_reason": FinishReason.to_str(req.finish_reason),
|
||||
"hit_stop_str": req.hit_stop_str,
|
||||
@@ -623,8 +655,8 @@ class ModelRpcServer:
|
||||
self.out_pyobjs.append(
|
||||
BatchTokenIDOut(
|
||||
output_rids,
|
||||
prev_output_strs,
|
||||
output_tokens,
|
||||
output_and_jump_forward_strs,
|
||||
output_hit_stop_str,
|
||||
output_skip_special_tokens,
|
||||
output_spaces_between_special_tokens,
|
||||
|
||||
Reference in New Issue
Block a user