[chore] Clean up redundant lora_weight_names concept to simplify code (#9131)

This commit is contained in:
Lifu Huang
2025-08-17 12:36:58 -07:00
committed by GitHub
parent ce3ca9b02f
commit 4b74c3fcca
4 changed files with 55 additions and 54 deletions

View File

@@ -84,7 +84,7 @@ def get_hidden_dim(
raise NotImplementedError()
def get_normalized_lora_weight_names(
def get_normalized_target_modules(
target_modules: Iterable[str],
) -> set[str]:
"""
@@ -100,8 +100,8 @@ def get_normalized_lora_weight_names(
result = set()
for name in target_modules:
weight_name = params_mapping.get(name, name)
result.add(weight_name)
normalized_name = params_mapping.get(name, name)
result.add(normalized_name)
return result
@@ -116,20 +116,18 @@ def get_stacked_multiply(module_name: str) -> int:
return stacked_rank[module_name] if module_name in stacked_rank else 1
def get_weight_name(
target_name: str, lora_weight_names: Tuple[Set[str]]
) -> Optional[str]:
def get_target_module_name(full_module_name: str, target_modules: Set[str]) -> str:
"""
Get the weight name in lora_weight_names that can match target_name.
Get the target module name in target_modules that can match full_module_name.
If there is a weight name in lora_weight_names that can match target_name, return this name
If there is a target module name in target_modules that can match full_module_name, return this name
Else raise ValueError.
"""
for weight_name in lora_weight_names:
if weight_name in target_name:
return weight_name
for target_module in target_modules:
if target_module in full_module_name:
return target_module
raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}"
f"Cannot find target module name for {full_module_name} in {target_modules}"
)