From b2bedcd77982e8d4b7250d21c023a84c1426152f Mon Sep 17 00:00:00 2001 From: ai-modelscope Date: Wed, 26 Mar 2025 13:52:52 +0800 Subject: [PATCH] Add supports_gradient_checkpointing --- configuration_internvl_chat.py | 2 ++ modeling_intern_vit.py | 1 + modeling_internvl_chat.py | 11 +++++++++++ 3 files changed, 14 insertions(+) diff --git a/configuration_internvl_chat.py b/configuration_internvl_chat.py index 9360429..62d8427 100644 --- a/configuration_internvl_chat.py +++ b/configuration_internvl_chat.py @@ -64,6 +64,8 @@ class InternVLChatConfig(PretrainedConfig): self.ps_version = ps_version # pixel shuffle version self.min_dynamic_patch = min_dynamic_patch self.max_dynamic_patch = max_dynamic_patch + # By default, we use tie_word_embeddings=False for models of all sizes. + self.tie_word_embeddings = self.llm_config.tie_word_embeddings logger.info(f'vision_select_layer: {self.select_layer}') logger.info(f'ps_version: {self.ps_version}') diff --git a/modeling_intern_vit.py b/modeling_intern_vit.py index 1c5c043..233b029 100644 --- a/modeling_intern_vit.py +++ b/modeling_intern_vit.py @@ -364,6 +364,7 @@ class InternVisionEncoder(nn.Module): class InternVisionModel(PreTrainedModel): main_input_name = 'pixel_values' _supports_flash_attn_2 = True + supports_gradient_checkpointing = True config_class = InternVisionConfig _no_split_modules = ['InternVisionEncoderLayer'] diff --git a/modeling_internvl_chat.py b/modeling_internvl_chat.py index 55b5871..359db00 100644 --- a/modeling_internvl_chat.py +++ b/modeling_internvl_chat.py @@ -38,6 +38,7 @@ class InternVLChatModel(PreTrainedModel): main_input_name = 'pixel_values' base_model_prefix = 'language_model' _supports_flash_attn_2 = True + supports_gradient_checkpointing = True _no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Phi3DecoderLayer'] def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True): @@ -347,3 +348,13 @@ class InternVLChatModel(PreTrainedModel): ) return outputs + + @property + def lm_head(self): + return self.language_model.get_output_embeddings() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def get_output_embeddings(self): + return self.language_model.get_output_embeddings()