diff --git a/examples/usage/llava/http_llama3_llava_test.py b/examples/usage/llava/http_llama3_llava_test.py
index 2f95c6542..113adbc8d 100644
--- a/examples/usage/llava/http_llama3_llava_test.py
+++ b/examples/usage/llava/http_llama3_llava_test.py
@@ -22,11 +22,7 @@ import aiohttp
 import requests
 
 from llava.conversation import (
-    default_conversation,
-    conv_templates,
-    SeparatorStyle,
     conv_llava_llama_3,
-    conv_qwen,
 )
 
 
@@ -43,7 +39,8 @@ async def test_concurrent(args):
 
     prompt = "\nPlease generate caption towards this image."
     conv_template = copy.deepcopy(conv_llava_llama_3)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     response = []
     for i in range(1):
@@ -74,7 +71,8 @@ def test_streaming(args):
     url = f"{args.host}:{args.port}"
     prompt = "\nPlease generate caption towards this image."
     conv_template = copy.deepcopy(conv_llava_llama_3)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     pload = {
         "text": prompt_with_template,
diff --git a/examples/usage/llava/http_qwen_llava_test.py b/examples/usage/llava/http_qwen_llava_test.py
index 1495c0dfb..9ba206415 100644
--- a/examples/usage/llava/http_qwen_llava_test.py
+++ b/examples/usage/llava/http_qwen_llava_test.py
@@ -22,11 +22,7 @@ import aiohttp
 import requests
 
 from llava.conversation import (
-    default_conversation,
-    conv_templates,
-    SeparatorStyle,
-    conv_llava_llama_3,
-    conv_qwen,
+    conv_qwen
 )
 
 
@@ -43,7 +39,8 @@ async def test_concurrent(args):
 
     prompt = "\nPlease generate caption towards this image."
     conv_template = copy.deepcopy(conv_qwen)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     response = []
    for i in range(1):
@@ -74,7 +71,8 @@ def test_streaming(args):
     url = f"{args.host}:{args.port}"
     prompt = "\nPlease generate caption towards this image."
     conv_template = copy.deepcopy(conv_qwen)
-    conv_template.append_message(role="user", message=prompt)
+    conv_template.append_message(role=conv_template.roles[0], message=prompt)
+    conv_template.append_message(role=conv_template.roles[1], message=None)
     prompt_with_template = conv_template.get_prompt()
     pload = {
         "text": prompt_with_template,
@@ -113,5 +111,5 @@ if __name__ == "__main__":
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
     args = parser.parse_args()
-    # asyncio.run(test_concurrent(args))
+    asyncio.run(test_concurrent(args))
     test_streaming(args)
diff --git a/python/sglang/srt/managers/router/model_runner.py b/python/sglang/srt/managers/router/model_runner.py
index 34b160789..faa7ab927 100644
--- a/python/sglang/srt/managers/router/model_runner.py
+++ b/python/sglang/srt/managers/router/model_runner.py
@@ -421,7 +421,12 @@ def import_model_classes():
         if not ispkg:
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
+                entry = module.EntryClass
+                if isinstance(entry, list):  # To support multiple model classes in one module
+                    for cls in entry:
+                        model_arch_name_to_cls[cls.__name__] = cls
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
     return model_arch_name_to_cls
 
 
diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py
index 5d4726b54..4755939b7 100644
--- a/python/sglang/srt/models/llava.py
+++ b/python/sglang/srt/models/llava.py
@@ -5,7 +5,7 @@ from typing import List, Iterable, Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from transformers import CLIPVisionModel, LlavaConfig
+from transformers import CLIPVisionModel, CLIPVisionConfig, LlavaConfig, Qwen2Config, MistralConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -18,6 +18,8 @@ from sglang.srt.mm_utils import (
     unpad_image_shape,
 )
 from sglang.srt.models.llama2 import LlamaForCausalLM
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.mistral import MistralForCausalLM
 
 
 class LlavaLlamaForCausalLM(nn.Module):
@@ -287,8 +289,101 @@ class LlavaLlamaForCausalLM(nn.Module):
         return self.image_size // self.patch_size
 
 
-first_call = True
+class LlavaQwenForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = Qwen2Config(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 151646
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        if self.mm_patch_merge_type.startswith("spatial"):
+            height = width = self.num_patches_per_side
+            if pt_shape[0] > 1:
+                if self.image_aspect_ratio == "anyres":
+                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+                        image_size,
+                        self.image_grid_pinpoints,
+                        self.vision_tower.config.image_size,
+                    )
+                if "unpad" in self.mm_patch_merge_type:
+                    h = num_patch_height * height
+                    w = num_patch_width * width
+                    new_h, new_w = unpad_image_shape(h, w, image_size)
+                    new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+
+class LlavaMistralForCausalLM(LlavaLlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config=quant_config)
+        self.config = config
+        self.vision_tower = None
+        if getattr(self.config, "vision_config", None) is None:
+            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
+
+        if getattr(self.config, "text_config", None) is None:
+            self.config.text_config = MistralConfig(self.config._name_or_path)
+
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+
+        if getattr(self.config, "projector_hidden_act", None) is None:
+            self.config.projector_hidden_act = "gelu"
+
+        if getattr(self.config, "image_token_index", None) is None:
+            self.config.image_token_index = 32000
+
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+
+first_call = True
 
 
 def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
     batch_size = pixel_values.shape[0]
@@ -319,4 +414,8 @@ def monkey_path_clip_vision_embed_forward():
     )
 
 
-EntryClass = LlavaLlamaForCausalLM
+EntryClass = [
+    LlavaLlamaForCausalLM,
+    LlavaQwenForCausalLM,
+    LlavaMistralForCausalLM
+]
diff --git a/python/sglang/srt/models/llava_mistral.py b/python/sglang/srt/models/llava_mistral.py
deleted file mode 100644
index 10531f84f..000000000
--- a/python/sglang/srt/models/llava_mistral.py
+++ /dev/null
@@ -1,335 +0,0 @@
-"""Inference-only LLaVa model compatible with HuggingFace weights."""
-
-from typing import List, Iterable, Optional, Tuple
-
-import numpy as np
-import torch
-from torch import nn
-from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, MistralConfig
-from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
-from sglang.srt.models.mistral import MistralForCausalLM
-
-
-class LlavaMistralForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        if getattr(self.config, "vision_config", None) is None:
-            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
-        if getattr(self.config, "text_config", None) is None:
-            self.config.text_config = MistralConfig(self.config._name_or_path)
-
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-
-        if getattr(self.config, "projector_hidden_act", None) is None:
-            self.config.projector_hidden_act = "gelu"
-
-        if getattr(self.config, "image_token_index", None) is None:
-            self.config.image_token_index = 32000
-
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = MistralForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
-            height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
-                    )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
-        )
-        offset = input_ids.index(self.config.image_token_index)
-        # old_len + pad_len - 1, because we need to remove image_token_id
-        new_input_ids = (
-            input_ids[:offset]
-            + pad_ids[:new_image_feature_len]
-            + input_ids[offset + 1 :]
-        )
-        return new_input_ids, offset
-
-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
-        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
-        if self.vision_feature_select_strategy in ["default", "patch"]:
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif self.vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-        else:
-            raise ValueError(
-                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
-            )
-        image_features = self.multi_modal_projector(selected_image_feature)
-
-        return image_features
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        positions: torch.Tensor,
-        input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
-    ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
-            bs = input_metadata.batch_size
-
-            # Embed text input
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
-            )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
-
-            if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
-
-                ########## Encode Image ########
-
-                if pixel_values[0].ndim == 4:
-                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
-                    np.concatenate(pixel_values, axis=0)
-                    # ndim=4
-                    concat_images = torch.tensor(
-                        np.concatenate(pixel_values, axis=0),
-                        device=self.vision_tower.device,
-                    )
-                    image_features = self.encode_images(concat_images)
-                    split_sizes = [image.shape[0] for image in pixel_values]
-                    image_features = torch.split(image_features, split_sizes, dim=0)
-                    # hd image_features: BS, num_patch, 576, 4096
-                else:
-                    # normal pixel: BS, C=3, H=336, W=336
-                    pixel_values = torch.tensor(
-                        np.array(pixel_values), device=self.vision_tower.device
-                    )
-                    image_features = self.encode_images(pixel_values)
-                    # image_features: BS, 576, 4096
-
-                if self.mm_patch_merge_type.startswith("spatial"):
-                    new_image_features = []
-                    for image_idx, image_feature in enumerate(image_features):
-                        if image_feature.shape[0] > 1:
-                            base_image_feature = image_feature[0]
-                            image_feature = image_feature[1:]
-                            height = width = self.num_patches_per_side
-                            assert height * width == base_image_feature.shape[0]
-                            if self.image_aspect_ratio == "anyres":
-                                (
-                                    num_patch_width,
-                                    num_patch_height,
-                                ) = get_anyres_image_grid_shape(
-                                    image_sizes[image_idx],
-                                    self.image_grid_pinpoints,
-                                    self.vision_tower.config.image_size,
-                                )
-                                image_feature = image_feature.view(
-                                    num_patch_height, num_patch_width, height, width, -1
-                                )
-                            else:
-                                raise NotImplementedError()
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = image_feature.permute(
-                                    4, 0, 2, 1, 3
-                                ).contiguous()
-                                image_feature = image_feature.flatten(1, 2).flatten(
-                                    2, 3
-                                )
-                                image_feature = unpad_image(
-                                    image_feature, image_sizes[image_idx]
-                                )
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[
-                                            :, None, None
-                                        ].expand(*image_feature.shape[:-1], 1),
-                                    ),
-                                    dim=-1,
-                                )
-                                image_feature = image_feature.flatten(1, 2).transpose(
-                                    0, 1
-                                )
-                            else:
-                                image_feature = image_feature.permute(
-                                    0, 2, 1, 3, 4
-                                ).contiguous()
-                                image_feature = image_feature.flatten(0, 3)
-                            image_feature = torch.cat(
-                                (base_image_feature, image_feature), dim=0
-                            )
-                        else:
-                            image_feature = image_feature[0]
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[None],
-                                    ),
-                                    dim=0,
-                                )
-                        new_image_features.append(image_feature)
-                    image_features = new_image_features
-
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-                pt = 0
-                for i in range(bs):
-                    if not need_vision[i]:
-                        continue
-
-                    start_idx = extend_start_loc_cpu[i]
-                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
-                    dim = input_embeds.shape[1]
-                    assert (
-                        pad_dim == dim
-                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                    # Fill in the placeholder for the image
-                    try:
-                        input_embeds[
-                            start_idx
-                            + image_offsets[i] : start_idx
-                            + image_offsets[i]
-                            + pad_len
-                        ] = image_features[pt]
-                    except RuntimeError as e:
-                        print(f"RuntimeError in llava image encoding: {e}")
-                        print(input_embeds.shape)
-                        print(start_idx, image_offsets[i])
-                    pt += 1
-
-            return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
-            )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.language_model(input_ids, positions, input_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
-        vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
-        self.vision_tower.eval()
-
-        self.vision_feature_layer = self.config.mm_vision_select_layer
-        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
-        self.image_size = self.vision_tower.config.image_size
-        self.patch_size = self.vision_tower.config.patch_size
-
-        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
-
-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
-            pass
-        elif self.vision_feature_select_strategy == "cls_patch":
-            self.image_feature_len += 1
-        else:
-            raise ValueError(f"Unexpected select feature: {self.select_feature}")
-
-        # load mm_projector
-        projector_weights = {
-            "model.mm_projector.0": "multi_modal_projector.linear_1",
-            "model.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
-        }
-        params_dict = dict(self.named_parameters())
-        weights = list(weights)
-        for name, loaded_weight in weights:
-            # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
-                for weight_name, param_name in projector_weights.items():
-                    if weight_name in name:
-                        name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-        # load language model
-        self.language_model.load_weights(weights)
-
-        monkey_path_clip_vision_embed_forward()
-
-    @property
-    def num_patches_per_side(self):
-        return self.image_size // self.patch_size
-
-
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
-EntryClass = LlavaMistralForCausalLM
diff --git a/python/sglang/srt/models/llava_qwen.py b/python/sglang/srt/models/llava_qwen.py
deleted file mode 100644
index b524ef037..000000000
--- a/python/sglang/srt/models/llava_qwen.py
+++ /dev/null
@@ -1,335 +0,0 @@
-"""Inference-only LLaVa model compatible with HuggingFace weights."""
-
-from typing import List, Iterable, Optional, Tuple
-
-import numpy as np
-import torch
-from torch import nn
-from transformers import CLIPVisionConfig, CLIPVisionModel, LlavaConfig, Qwen2Config
-from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
-from sglang.srt.managers.router.infer_batch import ForwardMode
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
-from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-
-
-class LlavaQwenForCausalLM(nn.Module):
-    def __init__(
-        self,
-        config: LlavaConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        super().__init__()
-        self.config = config
-        self.vision_tower = None
-        if getattr(self.config, "vision_config", None) is None:
-            self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
-
-        if getattr(self.config, "text_config", None) is None:
-            self.config.text_config = Qwen2Config(self.config._name_or_path)
-
-        self.config.vision_config.hidden_size = config.mm_hidden_size
-        self.config.text_config.hidden_size = config.hidden_size
-
-        if getattr(self.config, "projector_hidden_act", None) is None:
-            self.config.projector_hidden_act = "gelu"
-
-        if getattr(self.config, "image_token_index", None) is None:
-            self.config.image_token_index = 151646
-
-        self.multi_modal_projector = LlavaMultiModalProjector(config)
-        self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
-            self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
-            )
-
-    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
-        new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        if self.mm_patch_merge_type.startswith("spatial"):
-            height = width = self.num_patches_per_side
-            if pt_shape[0] > 1:
-                if self.image_aspect_ratio == "anyres":
-                    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                        image_size,
-                        self.image_grid_pinpoints,
-                        self.vision_tower.config.image_size,
-                    )
-                if "unpad" in self.mm_patch_merge_type:
-                    h = num_patch_height * height
-                    w = num_patch_width * width
-                    new_h, new_w = unpad_image_shape(h, w, image_size)
-                    new_image_feature_len += new_h * (new_w + 1)
-
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
-        )
-        offset = input_ids.index(self.config.image_token_index)
-        # old_len + pad_len - 1, because we need to remove image_token_id
-        new_input_ids = (
-            input_ids[:offset]
-            + pad_ids[:new_image_feature_len]
-            + input_ids[offset + 1 :]
-        )
-        return new_input_ids, offset
-
-    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
-        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
-        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
-
-        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
-        if self.vision_feature_select_strategy in ["default", "patch"]:
-            selected_image_feature = selected_image_feature[:, 1:]
-        elif self.vision_feature_select_strategy == "full":
-            selected_image_feature = selected_image_feature
-        else:
-            raise ValueError(
-                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
-            )
-        image_features = self.multi_modal_projector(selected_image_feature)
-
-        return image_features
-
-    def forward(
-        self,
-        input_ids: torch.LongTensor,
-        positions: torch.Tensor,
-        input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
-    ) -> torch.Tensor:
-        if input_metadata.forward_mode == ForwardMode.EXTEND:
-            bs = input_metadata.batch_size
-
-            # Embed text input
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
-            # Embed vision input
-            need_vision = (
-                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
-                .cpu()
-                .numpy()
-            )
-            # FIXME: We need to substract the length of the system prompt
-            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
-            need_vision = need_vision & has_pixel
-
-            if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
-
-                ########## Encode Image ########
-
-                if pixel_values[0].ndim == 4:
-                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
-                    np.concatenate(pixel_values, axis=0)
-                    # ndim=4
-                    concat_images = torch.tensor(
-                        np.concatenate(pixel_values, axis=0),
-                        device=self.vision_tower.device,
-                    )
-                    image_features = self.encode_images(concat_images)
-                    split_sizes = [image.shape[0] for image in pixel_values]
-                    image_features = torch.split(image_features, split_sizes, dim=0)
-                    # hd image_features: BS, num_patch, 576, 4096
-                else:
-                    # normal pixel: BS, C=3, H=336, W=336
-                    pixel_values = torch.tensor(
-                        np.array(pixel_values), device=self.vision_tower.device
-                    )
-                    image_features = self.encode_images(pixel_values)
-                    # image_features: BS, 576, 4096
-
-                if self.mm_patch_merge_type.startswith("spatial"):
-                    new_image_features = []
-                    for image_idx, image_feature in enumerate(image_features):
-                        if image_feature.shape[0] > 1:
-                            base_image_feature = image_feature[0]
-                            image_feature = image_feature[1:]
-                            height = width = self.num_patches_per_side
-                            assert height * width == base_image_feature.shape[0]
-                            if self.image_aspect_ratio == "anyres":
-                                (
-                                    num_patch_width,
-                                    num_patch_height,
-                                ) = get_anyres_image_grid_shape(
-                                    image_sizes[image_idx],
-                                    self.image_grid_pinpoints,
-                                    self.vision_tower.config.image_size,
-                                )
-                                image_feature = image_feature.view(
-                                    num_patch_height, num_patch_width, height, width, -1
-                                )
-                            else:
-                                raise NotImplementedError()
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = image_feature.permute(
-                                    4, 0, 2, 1, 3
-                                ).contiguous()
-                                image_feature = image_feature.flatten(1, 2).flatten(
-                                    2, 3
-                                )
-                                image_feature = unpad_image(
-                                    image_feature, image_sizes[image_idx]
-                                )
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[
-                                            :, None, None
-                                        ].expand(*image_feature.shape[:-1], 1),
-                                    ),
-                                    dim=-1,
-                                )
-                                image_feature = image_feature.flatten(1, 2).transpose(
-                                    0, 1
-                                )
-                            else:
-                                image_feature = image_feature.permute(
-                                    0, 2, 1, 3, 4
-                                ).contiguous()
-                                image_feature = image_feature.flatten(0, 3)
-                            image_feature = torch.cat(
-                                (base_image_feature, image_feature), dim=0
-                            )
-                        else:
-                            image_feature = image_feature[0]
-                            if "unpad" in self.mm_patch_merge_type:
-                                image_feature = torch.cat(
-                                    (
-                                        image_feature,
-                                        self.language_model.model.image_newline[None],
-                                    ),
-                                    dim=0,
-                                )
-                        new_image_features.append(image_feature)
-                    image_features = new_image_features
-
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
-                pt = 0
-                for i in range(bs):
-                    if not need_vision[i]:
-                        continue
-
-                    start_idx = extend_start_loc_cpu[i]
-                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
-                    dim = input_embeds.shape[1]
-                    assert (
-                        pad_dim == dim
-                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
-                    # Fill in the placeholder for the image
-                    try:
-                        input_embeds[
-                            start_idx
-                            + image_offsets[i] : start_idx
-                            + image_offsets[i]
-                            + pad_len
-                        ] = image_features[pt]
-                    except RuntimeError as e:
-                        print(f"RuntimeError in llava image encoding: {e}")
-                        print(input_embeds.shape)
-                        print(start_idx, image_offsets[i])
-                    pt += 1
-
-            return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
-            )
-        elif input_metadata.forward_mode == ForwardMode.DECODE:
-            return self.language_model(input_ids, positions, input_metadata)
-
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        # load clip vision model by cfg['mm_vision_tower']:
-        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
-        vision_path = self.config.mm_vision_tower
-        self.vision_tower = CLIPVisionModel.from_pretrained(
-            vision_path, torch_dtype=torch.float16
-        ).cuda()
-        self.vision_tower.eval()
-
-        self.vision_feature_layer = self.config.mm_vision_select_layer
-        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
-        self.image_size = self.vision_tower.config.image_size
-        self.patch_size = self.vision_tower.config.patch_size
-
-        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
-        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
-        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
-
-        self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
-        if self.vision_feature_select_strategy == "patch":
-            pass
-        elif self.vision_feature_select_strategy == "cls_patch":
-            self.image_feature_len += 1
-        else:
-            raise ValueError(f"Unexpected select feature: {self.select_feature}")
-
-        # load mm_projector
-        projector_weights = {
-            "model.mm_projector.0": "multi_modal_projector.linear_1",
-            "model.mm_projector.2": "multi_modal_projector.linear_2",
-            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
-        }
-        params_dict = dict(self.named_parameters())
-        weights = list(weights)
-        for name, loaded_weight in weights:
-            # FIXME: why projector weights read two times?
-            if "projector" in name or "vision_tower" in name:
-                for weight_name, param_name in projector_weights.items():
-                    if weight_name in name:
-                        name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
-
-        # load language model
-        self.language_model.load_weights(weights)
-
-        monkey_path_clip_vision_embed_forward()
-
-    @property
-    def num_patches_per_side(self):
-        return self.image_size // self.patch_size
-
-
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
-EntryClass = LlavaQwenForCausalLM