model(vlm): mistral 3.1 (#5099)
Co-authored-by: KivenChen <sleigh-queue-0y@icloud.com>
@@ -135,7 +135,6 @@ class LlavaBaseForCausalLM(nn.Module):
|
||||
"""
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
# NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden stated.
|
||||
|
||||
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
|
||||
if self.vision_feature_select_strategy in ["default", "patch"]:
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
@@ -146,7 +145,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
             )
         image_features = self.multi_modal_projector(selected_image_feature)
-
         return image_features

     @torch.no_grad()
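For reference, the logic touched by the two hunks above picks one layer of the vision tower's hidden states and, under the "default"/"patch" strategy, drops the leading CLS token before projecting. A minimal standalone sketch, assuming a CLIP-style tower whose hidden states have shape [batch, 1 + num_patches, dim]; select_image_feature is an illustrative helper, not part of this diff:

    import torch

    def select_image_feature(hidden_states, layer=-1, strategy="default"):
        feature = hidden_states[layer]  # one intermediate layer of the tower
        if strategy in ["default", "patch"]:
            feature = feature[:, 1:]  # drop the CLS token, keep patch tokens
        elif strategy != "full":
            raise ValueError(f"Unexpected select feature strategy: {strategy}")
        return feature

    # e.g. 2 layers of [batch=1, 1 CLS + 4 patches, dim=8]
    hidden_states = [torch.randn(1, 5, 8) for _ in range(2)]
    print(select_image_feature(hidden_states).shape)  # torch.Size([1, 4, 8])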
@@ -613,6 +611,10 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):

     MULTIMODAL_PROJECTOR_TYPE = LlavaMultiModalProjector

+    @property
+    def dtype(self):
+        return self.torch_dtype
+
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         if hasattr(self.vision_tower, "pad_input_ids"):
             return self.vision_tower.pad_input_ids(input_ids, image_inputs)
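The new dtype property exposes the configured torch_dtype (assigned in __init__, see the next hunk) so callers can cast inputs to match the model. A minimal sketch of the pattern; TinyVLM is an invented stand-in, not SGLang code:

    import torch
    import torch.nn as nn

    class TinyVLM(nn.Module):
        def __init__(self, torch_dtype=torch.bfloat16):
            super().__init__()
            self.torch_dtype = torch_dtype

        @property
        def dtype(self):
            # mirror the diff: report the configured dtype
            return self.torch_dtype

    model = TinyVLM()
    pixel_values = torch.rand(1, 3, 336, 336).to(model.dtype)  # cast to match
    print(pixel_values.dtype)  # torch.bfloat16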
@@ -672,11 +674,17 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         assert hasattr(config, "text_config")
         assert hasattr(config, "vision_config")
         self.config = config
-        self.text_config = config.text_config
-        self.vision_config = config.vision_config
+        self.text_config = self.config.text_config
+        self.vision_config = self.config.vision_config
+        self.torch_dtype = getattr(self.config, "torch_dtype")
+
+        if not getattr(self.text_config, "torch_dtype"):
+            self.text_config.torch_dtype = self.torch_dtype
+        if not getattr(self.vision_config, "torch_dtype"):
+            self.vision_config.torch_dtype = self.torch_dtype

         if not hasattr(self.config, "vocab_size"):
-            self.config.vocab_size = self.config.text_config.vocab_size
+            self.config.vocab_size = self.text_config.vocab_size
         if not hasattr(self.config, "image_aspect_ratio"):
             self.config.image_aspect_ratio = "anyres"
         if not hasattr(self.config, "image_grid_pinpoints"):
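The torch_dtype block added above makes the sub-configs inherit the top-level dtype only when they don't carry one of their own. A sketch of that rule with plain namespaces standing in for the HF config objects (SimpleNamespace is illustrative only):

    from types import SimpleNamespace

    config = SimpleNamespace(
        torch_dtype="bfloat16",
        text_config=SimpleNamespace(torch_dtype=None),
        vision_config=SimpleNamespace(torch_dtype="float16"),
    )

    for sub in (config.text_config, config.vision_config):
        if not getattr(sub, "torch_dtype"):  # falsy (None) -> inherit
            sub.torch_dtype = config.torch_dtype

    print(config.text_config.torch_dtype)    # bfloat16 (inherited)
    print(config.vision_config.torch_dtype)  # float16 (kept as-is)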
@@ -697,39 +705,39 @@ class LlavaForConditionalGeneration(LlavaBaseForCausalLM):
         if not hasattr(self.config, "projector_hidden_act"):
             self.config.projector_hidden_act = "gelu"

-        self.vision_feature_layer = getattr(config, "vision_feature_layer", -1)
+        self.vision_feature_layer = getattr(self.config, "vision_feature_layer", -1)
         self.vision_feature_select_strategy = getattr(
-            config, "vision_feature_select_strategy", "full"
+            self.config, "vision_feature_select_strategy", "full"
         )
-        self.image_size = self.config.vision_config.image_size
-        self.patch_size = self.config.vision_config.patch_size
+        self.image_size = self.vision_config.image_size
+        self.patch_size = self.vision_config.patch_size

-        self.mm_patch_merge_type = config.mm_patch_merge_type
-        self.image_aspect_ratio = config.image_aspect_ratio
-        self.image_grid_pinpoints = config.image_grid_pinpoints
+        self.mm_patch_merge_type = self.config.mm_patch_merge_type
+        self.image_aspect_ratio = self.config.image_aspect_ratio
+        self.image_grid_pinpoints = self.config.image_grid_pinpoints

         self.image_feature_len = int((self.image_size // self.patch_size) ** 2)

         self.multi_modal_projector = self.MULTIMODAL_PROJECTOR_TYPE(config)

         language_model_cls = self._get_sgl_model_cls(
-            config.text_config, AutoModelForCausalLM
+            self.text_config, AutoModelForCausalLM
         )
-        vision_model_cls = self._get_sgl_model_cls(config.vision_config, AutoModel)
+        vision_model_cls = self._get_sgl_model_cls(self.vision_config, AutoModel)
         self.language_model = language_model_cls(
-            config.text_config,
+            self.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
         self.vision_tower = vision_model_cls(
-            config.vision_config,
+            self.vision_config,
             quant_config=quant_config,
             prefix=add_prefix("vision_tower", prefix),
         )

-        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+        if "unpad" in getattr(self.config, "mm_patch_merge_type", ""):
             self.language_model.model.image_newline = nn.Parameter(
-                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+                torch.empty(self.text_config.hidden_size, dtype=self.torch_dtype)
             )

     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
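Two details worth noting in the last hunk: image_newline is now allocated in self.torch_dtype instead of hard-coded torch.float16, and image_feature_len counts the patch tokens reserved per image. A worked example of the latter, assuming the common CLIP ViT-L/14 geometry (image_size=336, patch_size=14); the real values come from config.vision_config:

    image_size, patch_size = 336, 14
    image_feature_len = int((image_size // patch_size) ** 2)
    print(image_feature_len)  # 576 patch features per image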