[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)
This commit is contained in:
@@ -84,7 +84,7 @@ def get_hidden_dim(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_normalized_lora_weight_names(
|
||||
def get_normalized_target_modules(
|
||||
target_modules: Iterable[str],
|
||||
) -> set[str]:
|
||||
"""
|
||||
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
|
||||
|
||||
result = set()
|
||||
for name in target_modules:
|
||||
weight_name = params_mapping.get(name, name)
|
||||
result.add(weight_name)
|
||||
normalized_name = params_mapping.get(name, name)
|
||||
result.add(normalized_name)
|
||||
return result
|
||||
|
||||
|
||||
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
|
||||
return stacked_rank[module_name] if module_name in stacked_rank else 1
|
||||
|
||||
|
||||
def get_weight_name(
|
||||
target_name: str, lora_weight_names: Tuple[Set[str]]
|
||||
) -> Optional[str]:
|
||||
def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
|
||||
"""
|
||||
Get the weight name in lora_weight_names that can match target_name.
|
||||
Get the target module name in target_modules that can match full_module_name.
|
||||
|
||||
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||
If there is a target module name in target_modules that can match full_module_name, return this name
|
||||
Else raise ValueError.
|
||||
"""
|
||||
for weight_name in lora_weight_names:
|
||||
if weight_name in target_name:
|
||||
return weight_name
|
||||
for target_module in target_modules:
|
||||
if target_module in full_module_name:
|
||||
return target_module
|
||||
raise ValueError(
|
||||
f"Cannot find weight name for {target_name} in {lora_weight_names}"
|
||||
f"Cannot find target module name for {full_module_name} in {target_modules}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user