From fd9ad817ec449592ec58b1cb7b57ac2e55d49b02 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Sat, 28 Sep 2024 23:28:55 -0700 Subject: [PATCH] Organize image inputs (#1531) --- python/sglang/srt/managers/io_struct.py | 10 +--- python/sglang/srt/managers/schedule_batch.py | 51 +++++++++++++----- .../sglang/srt/managers/tokenizer_manager.py | 37 ++++++------- python/sglang/srt/managers/tp_worker.py | 32 ++++------- .../srt/model_executor/forward_batch_info.py | 13 ++--- .../sglang/srt/model_executor/model_runner.py | 15 +----- python/sglang/srt/models/llava.py | 54 ++++++++++--------- python/sglang/srt/models/llavavid.py | 41 +++++++------- 8 files changed, 121 insertions(+), 132 deletions(-) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 0c7a57f46..c26a65f74 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -172,12 +172,8 @@ class TokenizedGenerateReqInput: input_text: str # The input token ids input_ids: List[int] - # The pixel values for input images - pixel_values: List[float] - # The hash values of input images - image_hashes: List[int] - # The image sizes - image_sizes: List[List[int]] + # The image input + image_inputs: dict # The sampling parameters sampling_params: SamplingParams # Whether to return the logprobs @@ -188,8 +184,6 @@ class TokenizedGenerateReqInput: top_logprobs_num: int # Whether to stream output stream: bool - # Modalities of the input images - modalites: Optional[List[str]] = None # LoRA related lora_path: Optional[str] = None # None means just use the base model diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c4c91c711..bb4785981 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -102,6 +102,39 @@ class FINISH_ABORT(BaseFinishReason): } +@dataclass +class ImageInputs: + pixel_values: torch.Tensor + image_hash: int + image_sizes: Optional[list] = None + image_offsets: Optional[list] = None + pad_values: Optional[list] = None + modalities: Optional[list] = None + + image_embeds: Optional[List[torch.Tensor]] = None + aspect_ratio_ids: Optional[List[torch.Tensor]] = None + aspect_ratio_mask: Optional[List[torch.Tensor]] = None + + @staticmethod + def from_dict(obj, vocab_size): + # Use image hash as fake token_ids, which is then used for prefix matching + ret = ImageInputs( + pixel_values=obj["pixel_values"], + image_hash=hash(tuple(obj["image_hashes"])), + ) + image_hash = ret.image_hash + ret.pad_values = [ + (image_hash) % vocab_size, + (image_hash >> 16) % vocab_size, + (image_hash >> 32) % vocab_size, + (image_hash >> 64) % vocab_size, + ] + ret.image_sizes = obj["image_sizes"] + # Only when pixel values is not None we have modalities + ret.modalities = obj["modalities"] + return ret + + class Req: """Store all inforamtion of a request.""" @@ -147,11 +180,7 @@ class Req: self.completion_tokens_wo_jump_forward = 0 # For vision inputs - self.pixel_values = None - self.image_sizes = None - self.image_offsets = None - self.pad_value = None - self.modalities = None + self.image_inputs: Optional[ImageInputs] = None # Prefix info self.prefix_indices = [] @@ -654,15 +683,9 @@ class ScheduleBatch: self.tree_cache.cache_finished_req(req, cur_all_ids) # re-applying image padding - if req.pixel_values is not None: - ( - req.origin_input_ids, - req.image_offsets, - ) = model_runner.model.pad_input_ids( - req.origin_input_ids_unpadded, - 
req.pad_value, - req.pixel_values, - req.image_sizes, + if req.image_inputs is not None: + req.origin_input_ids = model_runner.model.pad_input_ids( + req.origin_input_ids_unpadded, req.image_inputs ) jump_forward_reqs.append(req) diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index b93ceb3a6..e40096fe7 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -194,10 +194,9 @@ class TokenizerManager: ) if self.is_generation: - pixel_values, image_hashes, image_sizes = await self._get_pixel_values( - obj.image_data if not_use_index else obj.image_data[index] + image_inputs = await self._get_image_inputs( + obj, obj.image_data if not_use_index else obj.image_data[index] ) - modalities = obj.modalities return_logprob = ( obj.return_logprob if not_use_index else obj.return_logprob[index] ) @@ -248,10 +247,7 @@ class TokenizerManager: sampling_params = SamplingParams(**obj.sampling_params[0]) sampling_params.max_new_tokens = 0 - pixel_values, image_hashes, image_sizes = await self._get_pixel_values( - obj.image_data[0] - ) - modalities = obj.modalities + image_inputs = await self._get_image_inputs(obj, obj.image_data[0]) return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] @@ -262,15 +258,12 @@ class TokenizerManager: rid, input_text, input_ids, - pixel_values, - image_hashes, - image_sizes, + image_inputs, sampling_params, return_logprob, logprob_start_len, top_logprobs_num, obj.stream, - modalities, ( obj.lora_path[index] if isinstance(obj.lora_path, list) @@ -369,24 +362,20 @@ class TokenizerManager: sampling_params = self._get_sampling_params(obj.sampling_params[index]) if self.is_generation: - pixel_values, image_hashes, image_sizes = ( - await self._get_pixel_values(obj.image_data[index]) + image_inputs = await self._get_image_inputs( + obj, obj.image_data[index] ) - modalities = obj.modalities tokenized_obj = TokenizedGenerateReqInput( rid, input_text, input_ids, - pixel_values, - image_hashes, - image_sizes, + image_inputs, sampling_params, obj.return_logprob[index], obj.logprob_start_len[index], obj.top_logprobs_num[index], obj.stream, - modalities, ( obj.lora_path[index] if isinstance(obj.lora_path, list) @@ -697,10 +686,11 @@ class TokenizerManager: ) return top_logprobs - async def _get_pixel_values(self, image_data: List[Union[str, bytes]]): + async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]): if not image_data: - return None, None, None + return None + # TODO: move this into a processor for each vision architecture aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None) grid_pinpoints = ( self.hf_config.image_grid_pinpoints @@ -741,7 +731,12 @@ class TokenizerManager: else: raise ValueError(f"Invalid image data: {image_data}") - return pixel_values, image_hashes, image_sizes + return { + "pixel_values": pixel_values, + "image_hashes": image_hashes, + "image_sizes": image_sizes, + "modalities": obj.modalities, + } async def _process_single_image( self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index b96906700..02fb87158 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -49,6 +49,7 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder from 
sglang.srt.managers.schedule_batch import ( FINISH_ABORT, BaseFinishReason, + ImageInputs, Req, ScheduleBatch, ) @@ -340,29 +341,16 @@ class ModelTpServer: req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.tokenizer = self.tokenizer req.sampling_params = recv_req.sampling_params - req.pixel_values = recv_req.pixel_values - if req.pixel_values is not None: - # Use image hash as fake token_ids, which is then used - # for prefix matching - image_hash = hash(tuple(recv_req.image_hashes)) - req.pad_value = [ - (image_hash) % self.model_config.vocab_size, - (image_hash >> 16) % self.model_config.vocab_size, - (image_hash >> 32) % self.model_config.vocab_size, - (image_hash >> 64) % self.model_config.vocab_size, - ] - req.image_sizes = recv_req.image_sizes - ( - req.origin_input_ids, - req.image_offsets, - ) = self.model_runner.model.pad_input_ids( - req.origin_input_ids_unpadded, - req.pad_value, - req.pixel_values, - req.image_sizes, + + # Image inputs + if recv_req.image_inputs is not None: + req.image_inputs = ImageInputs.from_dict( + recv_req.image_inputs, self.model_config.vocab_size ) - # Only when pixel values is not None we have modalities - req.modalities = recv_req.modalites + req.origin_input_ids = self.model_runner.model.pad_input_ids( + req.origin_input_ids_unpadded, req.image_inputs + ) + req.return_logprob = recv_req.return_logprob req.top_logprobs_num = recv_req.top_logprobs_num req.stream = recv_req.stream diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 4e81abec1..8421774f1 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -25,7 +25,7 @@ import torch if TYPE_CHECKING: from sglang.srt.layers.attention_backend import AttentionBackend - from sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -84,17 +84,10 @@ class InputMetadata: extend_logprob_start_lens_cpu: List[int] = None # For multimodal - pixel_values: List[torch.Tensor] = None - image_sizes: List[List[List[int]]] = None - image_offsets: List[List[int]] = None - modalities: List[List[str]] = None + image_inputs: List[ImageInputs] = None def init_multimuldal_info(self, batch: ScheduleBatch): - reqs = batch.reqs - self.pixel_values = [r.pixel_values for r in reqs] - self.image_sizes = [r.image_sizes for r in reqs] - self.image_offsets = [r.image_offsets for r in reqs] - self.modalities = [r.modalities for r in reqs] + self.image_inputs = [r.image_inputs for r in batch.reqs] def compute_positions(self, batch: ScheduleBatch): if self.forward_mode.is_decode(): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index a88f06ed6..afebd4f88 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -498,23 +498,10 @@ class ModelRunner: get_embedding=True, ) - def forward_extend_multi_modal(self, batch: ScheduleBatch): - input_metadata = InputMetadata.from_schedule_batch(self, batch) - return self.model.forward( - batch.input_ids, - input_metadata.positions, - input_metadata, - input_metadata.pixel_values, - input_metadata.image_sizes, - input_metadata.image_offsets, - ) - def forward(self, batch: ScheduleBatch) 
-> Tuple[LogitsProcessorOutput]: assert batch.forward_mode is not None - if self.is_multimodal_model and batch.forward_mode.is_extend(): - return self.forward_extend_multi_modal(batch) - elif batch.forward_mode.is_decode(): + if batch.forward_mode.is_decode(): return self.forward_decode(batch) elif batch.forward_mode.is_extend(): return self.forward_extend(batch) diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index df62b39fc..1d8a3f40f 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -35,25 +35,22 @@ from vllm.config import CacheConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.schedule_batch import ImageInputs from sglang.srt.mm_utils import ( get_anyres_image_grid_shape, unpad_image, unpad_image_shape, ) -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama import LlamaForCausalLM from sglang.srt.models.mistral import MistralForCausalLM from sglang.srt.models.qwen2 import Qwen2ForCausalLM class LlavaBaseForCausalLM(nn.Module): - def pad_input_ids( - self, - input_ids: List[int], - pad_value: List[int], - pixel_values: List, - image_sizes: List[List[int]], - ): + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values + # hardcode for spatial_unpad + anyres image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad" offset_list = [] @@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module): new_w = int(new_w // times) new_image_feature_len += new_h * (new_w + 1) - pad_ids = pad_value * ( - (new_image_feature_len + len(pad_value)) // len(pad_value) + pad_ids = pad_values * ( + (new_image_feature_len + len(pad_values)) // len(pad_values) ) # print("calculated new_image_feature_len: ", new_image_feature_len) try: @@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module): + input_ids[offset + 1 :] ) offset_list.append(offset) - return input_ids, offset_list + + image_inputs.image_offsets = offset_list + return input_ids def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module): input_ids: torch.LongTensor, positions: torch.Tensor, input_metadata: InputMetadata, - pixel_values: Optional[List[Optional[np.array]]] = None, - image_sizes: Optional[List[List[int]]] = None, - image_offsets: Optional[List[int]] = None, ) -> torch.Tensor: + image_inputs = input_metadata.image_inputs + if input_metadata.forward_mode.is_extend(): bs = input_metadata.batch_size # Got List[List[str]] extend it to List[str] # The length of the List should be equal to batch size modalities_list = [] - for modalities in input_metadata.modalities: - if modalities is not None: - modalities_list.extend(modalities) + max_image_offset = [] + for im in image_inputs: + if im and im.modalities is not None: + modalities_list.extend(im.modalities) + if im and im.image_offsets is not None: + max_image_offset.append(max(im.image_offsets)) + else: + max_image_offset.append(-1) # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) - # Whether the requests need vision inputs - max_image_offset = np.array( - 
[max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] - ) start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() - need_vision = start_positions <= max_image_offset + need_vision = start_positions <= np.array(max_image_offset) if need_vision.any(): - pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] - image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]] + pixel_values = [ + image_inputs[i].pixel_values for i in range(bs) if need_vision[i] + ] + image_sizes = [ + image_inputs[i].image_sizes for i in range(bs) if need_vision[i] + ] + image_offsets = [ + image_inputs[i].image_offsets for i in range(bs) if need_vision[i] + ] ########## Encode Image ######## diff --git a/python/sglang/srt/models/llavavid.py b/python/sglang/srt/models/llavavid.py index 408a90f19..4613c208f 100644 --- a/python/sglang/srt/models/llavavid.py +++ b/python/sglang/srt/models/llavavid.py @@ -26,7 +26,8 @@ from vllm.config import CacheConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama import LlamaForCausalLM @@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module): torch.empty(config.text_config.hidden_size, dtype=torch.float16) ) - def pad_input_ids( - self, - input_ids: List[int], - pad_value: List[int], - pixel_values: List, - image_sizes: List[List[int]], - ): + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + pad_values = image_inputs.pad_values new_image_feature_len = self.image_feature_len - pad_ids = pad_value * ( - (new_image_feature_len + len(pad_value)) // len(pad_value) + pad_ids = pad_values * ( + (new_image_feature_len + len(pad_values)) // len(pad_values) ) offset = input_ids.index(self.config.image_token_index) # old_len + pad_len - 1, because we need to remove image_token_id @@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module): + pad_ids[:new_image_feature_len] + input_ids[offset + 1 :] ) - return new_input_ids, [offset] + image_inputs.image_offsets = [offset] + return new_input_ids def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor: image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -112,10 +109,8 @@ class LlavaVidForCausalLM(nn.Module): input_ids: torch.LongTensor, positions: torch.Tensor, input_metadata: InputMetadata, - pixel_values: Optional[List[Optional[np.array]]] = None, - image_sizes: Optional[List[List[int]]] = None, - image_offsets: Optional[List[int]] = None, ) -> torch.Tensor: + image_inputs = input_metadata.image_inputs if input_metadata.forward_mode.is_extend(): bs = input_metadata.batch_size @@ -123,14 +118,22 @@ class LlavaVidForCausalLM(nn.Module): input_embeds = self.language_model.model.embed_tokens(input_ids) # Whether the requests need vision inputs - max_image_offset = np.array( - [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)] - ) + max_image_offset = [] + for im in image_inputs: + if im and im.image_offsets: + max_image_offset.append(max(im.image_offsets)) + else: + max_image_offset.append(-1) start_positions = positions[input_metadata.extend_start_loc].cpu().numpy() - need_vision = start_positions <= max_image_offset + need_vision = start_positions <= 
np.array(max_image_offset) if need_vision.any(): - pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]] + pixel_values = [ + image_inputs[i].pixel_values for i in range(bs) if need_vision[i] + ] + image_offsets = [ + image_inputs[i].image_offsets for i in range(bs) if need_vision[i] + ] ########## Encode Image ########
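
For reference, a minimal, self-contained sketch of the scheme this patch introduces: per-request image data is grouped into a single ImageInputs object, and slices of the combined image hash are mapped into the vocab range so that the hash-derived pad ids act as a deterministic fingerprint for radix-cache prefix matching (identical image sets produce identical padded prefixes). VOCAB_SIZE, IMAGE_TOKEN, and all names below are illustrative assumptions, not sglang APIs; the real logic lives in ImageInputs.from_dict and the models' pad_input_ids above.

from dataclasses import dataclass
from typing import List, Optional

VOCAB_SIZE = 32000   # assumption: llama-style vocab size, not taken from the patch
IMAGE_TOKEN = 32001  # assumption: placeholder image-token id


@dataclass
class ImageInputsSketch:
    image_hash: int
    pad_values: Optional[List[int]] = None
    image_offsets: Optional[List[int]] = None

    @staticmethod
    def from_hashes(image_hashes: List[int]) -> "ImageInputsSketch":
        # Combine per-image hashes into one fingerprint, then slice it
        # into vocab-range ids (mirrors ImageInputs.from_dict above).
        combined = hash(tuple(image_hashes))
        return ImageInputsSketch(
            image_hash=combined,
            pad_values=[
                combined % VOCAB_SIZE,
                (combined >> 16) % VOCAB_SIZE,
                (combined >> 32) % VOCAB_SIZE,
                (combined >> 64) % VOCAB_SIZE,
            ],
        )


def pad_input_ids(
    input_ids: List[int], feature_len: int, im: ImageInputsSketch
) -> List[int]:
    # Replace the single image token with feature_len hash-derived ids and
    # record where the image region starts, in the shape of
    # LlavaVidForCausalLM.pad_input_ids above.
    pad_values = im.pad_values
    pad_ids = pad_values * ((feature_len + len(pad_values)) // len(pad_values))
    offset = input_ids.index(IMAGE_TOKEN)
    im.image_offsets = [offset]
    return input_ids[:offset] + pad_ids[:feature_len] + input_ids[offset + 1 :]


if __name__ == "__main__":
    im = ImageInputsSketch.from_hashes([0xDEADBEEF])
    padded = pad_input_ids([1, 15043, IMAGE_TOKEN, 29973], feature_len=6, im=im)
    # Identical image sets yield identical pad ids, so two such requests
    # share a radix-cache prefix through the image region.
    print(padded, im.image_offsets)

Using several hash slices instead of repeating one keeps an accidental match between unrelated image sets unlikely while still mapping every pad id into the vocab range.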