Add supports_gradient_checkpointing
This commit is contained in:
@@ -37,6 +37,7 @@ class InternVLChatModel(PreTrainedModel):
|
||||
main_input_name = 'pixel_values'
|
||||
base_model_prefix = 'language_model'
|
||||
_supports_flash_attn_2 = True
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ['InternVisionModel', 'LlamaDecoderLayer', 'Qwen2DecoderLayer']
|
||||
|
||||
def __init__(self, config: InternVLChatConfig, vision_model=None, language_model=None, use_flash_attn=True):
|
||||
@@ -346,3 +347,13 @@ class InternVLChatModel(PreTrainedModel):
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def lm_head(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.language_model.get_input_embeddings()
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.language_model.get_output_embeddings()
|
||||
|
||||
Reference in New Issue
Block a user