Support faster JSON decoding for LLaVA (#137)
When sending fast-forwarded requests back to model_rpc, recalculate the padded input IDs by calling `pad_input_ids` again for requests that carry image inputs.
This commit is contained in:
@@ -31,6 +31,7 @@ class Req:
|
|||||||
self.pixel_values = None
|
self.pixel_values = None
|
||||||
self.image_size = None
|
self.image_size = None
|
||||||
self.image_offset = 0
|
self.image_offset = 0
|
||||||
|
self.pad_value = None
|
||||||
|
|
||||||
self.sampling_params = None
|
self.sampling_params = None
|
||||||
self.return_logprob = False
|
self.return_logprob = False
|
||||||
@@ -58,7 +59,7 @@ class Req:
|
|||||||
def max_new_tokens(self):
|
def max_new_tokens(self):
|
||||||
return self.sampling_params.max_new_tokens
|
return self.sampling_params.max_new_tokens
|
||||||
|
|
||||||
def tokenize_fast_forward(self, fast_forward_str, next_state):
|
def fast_forward_and_retokenize(self, fast_forward_str, next_state):
|
||||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||||
# FIXME: This logic does not really solve the problem of determining whether
|
# FIXME: This logic does not really solve the problem of determining whether
|
||||||
# there should be a leading space.
|
# there should be a leading space.
|
||||||
@@ -75,9 +76,14 @@ class Req:
|
|||||||
+ fast_forward_str
|
+ fast_forward_str
|
||||||
)
|
)
|
||||||
new_input_ids = self.tokenizer.encode(new_input_string)
|
new_input_ids = self.tokenizer.encode(new_input_string)
|
||||||
fast_forward_tokens_len = (
|
if self.pixel_values is not None:
|
||||||
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
# NOTE: This is a hack because the old input_ids contains the image padding
|
||||||
)
|
fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
|
||||||
|
else:
|
||||||
|
fast_forward_tokens_len = (
|
||||||
|
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
||||||
|
)
|
||||||
|
|
||||||
# print("=" * 100)
|
# print("=" * 100)
|
||||||
# print(f"Catch fast forward:\n{fast_forward_str}")
|
# print(f"Catch fast forward:\n{fast_forward_str}")
|
||||||
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
||||||
@@ -351,7 +357,7 @@ class Batch:
|
|||||||
self.tree_cache.dec_ref_counter(req.last_node)
|
self.tree_cache.dec_ref_counter(req.last_node)
|
||||||
|
|
||||||
# fast forward
|
# fast forward
|
||||||
req.tokenize_fast_forward(fast_forward_str, next_state)
|
req.fast_forward_and_retokenize(fast_forward_str, next_state)
|
||||||
|
|
||||||
fast_forward_reqs.append(req)
|
fast_forward_reqs.append(req)
|
||||||
filter_indices.remove(i)
|
filter_indices.remove(i)
|
||||||
|
|||||||
@@ -83,7 +83,9 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
self.max_num_running_seq = self.max_total_num_token // 2
|
self.max_num_running_seq = self.max_total_num_token // 2
|
||||||
self.max_prefill_num_token = max(
|
self.max_prefill_num_token = max(
|
||||||
self.model_config.context_len,
|
self.model_config.context_len,
|
||||||
self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
|
self.max_total_num_token // 6
|
||||||
|
if server_args.max_prefill_num_token is None
|
||||||
|
else server_args.max_prefill_num_token,
|
||||||
)
|
)
|
||||||
self.int_token_logit_bias = torch.tensor(
|
self.int_token_logit_bias = torch.tensor(
|
||||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||||
@@ -233,7 +235,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||||
req.pixel_values = recv_req.pixel_values
|
req.pixel_values = recv_req.pixel_values
|
||||||
if req.pixel_values is not None:
|
if req.pixel_values is not None:
|
||||||
pad_value = [
|
req.pad_value = [
|
||||||
(recv_req.image_hash) % self.model_config.vocab_size,
|
(recv_req.image_hash) % self.model_config.vocab_size,
|
||||||
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
|
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
|
||||||
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
|
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
|
||||||
@@ -241,7 +243,7 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
]
|
]
|
||||||
req.image_size = recv_req.image_size
|
req.image_size = recv_req.image_size
|
||||||
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
|
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
|
||||||
req.input_ids, pad_value, req.pixel_values.shape, req.image_size
|
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
|
||||||
)
|
)
|
||||||
req.sampling_params = recv_req.sampling_params
|
req.sampling_params = recv_req.sampling_params
|
||||||
req.return_logprob = recv_req.return_logprob
|
req.return_logprob = recv_req.return_logprob
|
||||||
@@ -438,6 +440,20 @@ class ModelRpcServer(rpyc.Service):
|
|||||||
if not self.no_regex_fast_forward:
|
if not self.no_regex_fast_forward:
|
||||||
# check for fast forward
|
# check for fast forward
|
||||||
fast_forward_reqs = batch.check_for_fast_forward()
|
fast_forward_reqs = batch.check_for_fast_forward()
|
||||||
|
|
||||||
|
# check for image fast forward
|
||||||
|
for req in fast_forward_reqs:
|
||||||
|
if req.pixel_values is not None:
|
||||||
|
(
|
||||||
|
req.input_ids,
|
||||||
|
req.image_offset,
|
||||||
|
) = self.model_runner.model.pad_input_ids(
|
||||||
|
req.input_ids,
|
||||||
|
req.pad_value,
|
||||||
|
req.pixel_values.shape,
|
||||||
|
req.image_size,
|
||||||
|
)
|
||||||
|
|
||||||
self.forward_queue.extend(fast_forward_reqs)
|
self.forward_queue.extend(fast_forward_reqs)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user