Feature/make PEFT adapter module format compatibile (#11080)
This commit is contained in:
@@ -98,6 +98,7 @@ def get_normalized_target_modules(
|
|||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
"""
|
"""
|
||||||
Mapping a list of target module name to names of the normalized LoRA weights.
|
Mapping a list of target module name to names of the normalized LoRA weights.
|
||||||
|
Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
|
||||||
"""
|
"""
|
||||||
params_mapping = {
|
params_mapping = {
|
||||||
"q_proj": "qkv_proj",
|
"q_proj": "qkv_proj",
|
||||||
@@ -109,7 +110,8 @@ def get_normalized_target_modules(
|
|||||||
|
|
||||||
result = set()
|
result = set()
|
||||||
for name in target_modules:
|
for name in target_modules:
|
||||||
normalized_name = params_mapping.get(name, name)
|
base_name = name.split(".")[-1]
|
||||||
|
normalized_name = params_mapping.get(base_name, base_name)
|
||||||
result.add(normalized_name)
|
result.add(normalized_name)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user