skip vision_model for lora (#10530)
This commit is contained in:
@@ -415,6 +415,10 @@ class LoRAManager:
|
|||||||
replace_submodule(self.base_model, module_name, lora_module)
|
replace_submodule(self.base_model, module_name, lora_module)
|
||||||
return lora_module
|
return lora_module
|
||||||
|
|
||||||
|
def should_skip_lora_for_vision_model(self, module_name):
    """Return True if LoRA injection should be skipped for *module_name*.

    Vision-tower submodules (identified by the "vision_model.model"
    path fragment) are excluded from LoRA application.

    Args:
        module_name: Dotted module path of a submodule within the
            base model (e.g. "vision_model.model.layers.0.attn").

    Returns:
        bool: True when the module belongs to the vision model and
        should therefore not receive a LoRA adapter.
    """
    # TODO: support different vision models
    # Idiomatic substring test; equivalent to `.find(...) != -1`.
    return "vision_model.model" in module_name
|
||||||
|
|
||||||
def init_lora_modules(self):
|
def init_lora_modules(self):
|
||||||
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
# Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
|
||||||
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
|
||||||
@@ -432,6 +436,10 @@ class LoRAManager:
|
|||||||
) and not self.base_model.should_apply_lora(module_name):
|
) and not self.base_model.should_apply_lora(module_name):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Skip vision model
|
||||||
|
if self.should_skip_lora_for_vision_model(module_name):
|
||||||
|
continue
|
||||||
|
|
||||||
# The module should be converted if it is included in target_names
|
# The module should be converted if it is included in target_names
|
||||||
if module_name.split(".")[-1] in self.target_modules:
|
if module_name.split(".")[-1] in self.target_modules:
|
||||||
layer_id = get_layer_id(module_name)
|
layer_id = get_layer_id(module_name)
|
||||||
|
|||||||
Reference in New Issue
Block a user