diff --git a/python/sglang/srt/managers/router/infer_batch.py b/python/sglang/srt/managers/router/infer_batch.py index 88f6031f7..5a3cc0897 100644 --- a/python/sglang/srt/managers/router/infer_batch.py +++ b/python/sglang/srt/managers/router/infer_batch.py @@ -31,6 +31,7 @@ class Req: self.pixel_values = None self.image_size = None self.image_offset = 0 + self.pad_value = None self.sampling_params = None self.return_logprob = False @@ -58,7 +59,7 @@ class Req: def max_new_tokens(self): return self.sampling_params.max_new_tokens - def tokenize_fast_forward(self, fast_forward_str, next_state): + def fast_forward_and_retokenize(self, fast_forward_str, next_state): old_output_str = self.tokenizer.decode(self.output_ids) # FIXME: This logic does not really solve the problem of determining whether # there should be a leading space. @@ -75,9 +76,14 @@ class Req: + fast_forward_str ) new_input_ids = self.tokenizer.encode(new_input_string) - fast_forward_tokens_len = ( - len(new_input_ids) - len(self.input_ids) - len(self.output_ids) - ) + if self.pixel_values is not None: + # NOTE: This is a hack because the old input_ids contains the image padding + fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str)) + else: + fast_forward_tokens_len = ( + len(new_input_ids) - len(self.input_ids) - len(self.output_ids) + ) + # print("=" * 100) # print(f"Catch fast forward:\n{fast_forward_str}") # print(self.tokenizer.convert_ids_to_tokens(self.input_ids)) @@ -351,7 +357,7 @@ class Batch: self.tree_cache.dec_ref_counter(req.last_node) # fast forward - req.tokenize_fast_forward(fast_forward_str, next_state) + req.fast_forward_and_retokenize(fast_forward_str, next_state) fast_forward_reqs.append(req) filter_indices.remove(i) diff --git a/python/sglang/srt/managers/router/model_rpc.py b/python/sglang/srt/managers/router/model_rpc.py index 8b7adf944..49da99d96 100644 --- a/python/sglang/srt/managers/router/model_rpc.py +++ b/python/sglang/srt/managers/router/model_rpc.py @@ -83,7 +83,9 @@ class ModelRpcServer(rpyc.Service): self.max_num_running_seq = self.max_total_num_token // 2 self.max_prefill_num_token = max( self.model_config.context_len, - self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token, + self.max_total_num_token // 6 + if server_args.max_prefill_num_token is None + else server_args.max_prefill_num_token, ) self.int_token_logit_bias = torch.tensor( get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) @@ -233,7 +235,7 @@ class ModelRpcServer(rpyc.Service): req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.pixel_values = recv_req.pixel_values if req.pixel_values is not None: - pad_value = [ + req.pad_value = [ (recv_req.image_hash) % self.model_config.vocab_size, (recv_req.image_hash >> 16) % self.model_config.vocab_size, (recv_req.image_hash >> 32) % self.model_config.vocab_size, @@ -241,7 +243,7 @@ class ModelRpcServer(rpyc.Service): ] req.image_size = recv_req.image_size req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids( - req.input_ids, pad_value, req.pixel_values.shape, req.image_size + req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size ) req.sampling_params = recv_req.sampling_params req.return_logprob = recv_req.return_logprob @@ -438,6 +440,20 @@ class ModelRpcServer(rpyc.Service): if not self.no_regex_fast_forward: # check for fast forward fast_forward_reqs = batch.check_for_fast_forward() + + # check for image fast forward + for req in fast_forward_reqs: + if req.pixel_values is not None: + ( + req.input_ids, + req.image_offset, + ) = self.model_runner.model.pad_input_ids( + req.input_ids, + req.pad_value, + req.pixel_values.shape, + req.image_size, + ) + self.forward_queue.extend(fast_forward_reqs) if batch.is_empty(): return