Add supports_gradient_checkpointing
This commit is contained in:
@@ -138,6 +138,7 @@ def split_model(model_name):
|
|||||||
device_map['language_model.model.embed_tokens'] = 0
|
device_map['language_model.model.embed_tokens'] = 0
|
||||||
device_map['language_model.output'] = 0
|
device_map['language_model.output'] = 0
|
||||||
device_map['language_model.model.norm'] = 0
|
device_map['language_model.model.norm'] = 0
|
||||||
|
device_map['language_model.model.rotary_emb'] = 0
|
||||||
device_map['language_model.lm_head'] = 0
|
device_map['language_model.lm_head'] = 0
|
||||||
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
|
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user