Support Faster JSON decoding for llava (#137)
When sending fast-forwarded reqs to model_rpc, re-calculate `pad_input_ids`
This commit is contained in:
@@ -31,6 +31,7 @@ class Req:
|
||||
self.pixel_values = None
|
||||
self.image_size = None
|
||||
self.image_offset = 0
|
||||
self.pad_value = None
|
||||
|
||||
self.sampling_params = None
|
||||
self.return_logprob = False
|
||||
@@ -58,7 +59,7 @@ class Req:
|
||||
def max_new_tokens(self):
|
||||
return self.sampling_params.max_new_tokens
|
||||
|
||||
def tokenize_fast_forward(self, fast_forward_str, next_state):
|
||||
def fast_forward_and_retokenize(self, fast_forward_str, next_state):
|
||||
old_output_str = self.tokenizer.decode(self.output_ids)
|
||||
# FIXME: This logic does not really solve the problem of determining whether
|
||||
# there should be a leading space.
|
||||
@@ -75,9 +76,14 @@ class Req:
|
||||
+ fast_forward_str
|
||||
)
|
||||
new_input_ids = self.tokenizer.encode(new_input_string)
|
||||
fast_forward_tokens_len = (
|
||||
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
||||
)
|
||||
if self.pixel_values is not None:
|
||||
# NOTE: This is a hack because the old input_ids contains the image padding
|
||||
fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
|
||||
else:
|
||||
fast_forward_tokens_len = (
|
||||
len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
|
||||
)
|
||||
|
||||
# print("=" * 100)
|
||||
# print(f"Catch fast forward:\n{fast_forward_str}")
|
||||
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
|
||||
@@ -351,7 +357,7 @@ class Batch:
|
||||
self.tree_cache.dec_ref_counter(req.last_node)
|
||||
|
||||
# fast forward
|
||||
req.tokenize_fast_forward(fast_forward_str, next_state)
|
||||
req.fast_forward_and_retokenize(fast_forward_str, next_state)
|
||||
|
||||
fast_forward_reqs.append(req)
|
||||
filter_indices.remove(i)
|
||||
|
||||
@@ -83,7 +83,9 @@ class ModelRpcServer(rpyc.Service):
|
||||
self.max_num_running_seq = self.max_total_num_token // 2
|
||||
self.max_prefill_num_token = max(
|
||||
self.model_config.context_len,
|
||||
self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token,
|
||||
self.max_total_num_token // 6
|
||||
if server_args.max_prefill_num_token is None
|
||||
else server_args.max_prefill_num_token,
|
||||
)
|
||||
self.int_token_logit_bias = torch.tensor(
|
||||
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
|
||||
@@ -233,7 +235,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
|
||||
req.pixel_values = recv_req.pixel_values
|
||||
if req.pixel_values is not None:
|
||||
pad_value = [
|
||||
req.pad_value = [
|
||||
(recv_req.image_hash) % self.model_config.vocab_size,
|
||||
(recv_req.image_hash >> 16) % self.model_config.vocab_size,
|
||||
(recv_req.image_hash >> 32) % self.model_config.vocab_size,
|
||||
@@ -241,7 +243,7 @@ class ModelRpcServer(rpyc.Service):
|
||||
]
|
||||
req.image_size = recv_req.image_size
|
||||
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
|
||||
req.input_ids, pad_value, req.pixel_values.shape, req.image_size
|
||||
req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
|
||||
)
|
||||
req.sampling_params = recv_req.sampling_params
|
||||
req.return_logprob = recv_req.return_logprob
|
||||
@@ -438,6 +440,20 @@ class ModelRpcServer(rpyc.Service):
|
||||
if not self.no_regex_fast_forward:
|
||||
# check for fast forward
|
||||
fast_forward_reqs = batch.check_for_fast_forward()
|
||||
|
||||
# check for image fast forward
|
||||
for req in fast_forward_reqs:
|
||||
if req.pixel_values is not None:
|
||||
(
|
||||
req.input_ids,
|
||||
req.image_offset,
|
||||
) = self.model_runner.model.pad_input_ids(
|
||||
req.input_ids,
|
||||
req.pad_value,
|
||||
req.pixel_values.shape,
|
||||
req.image_size,
|
||||
)
|
||||
|
||||
self.forward_queue.extend(fast_forward_reqs)
|
||||
if batch.is_empty():
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user