Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -106,9 +106,11 @@ def get_hidden_dim(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
def get_normalized_lora_weight_names(
|
||||
target_modules: Iterable[str],
|
||||
) -> Tuple[set[str], set[str]]:
|
||||
"""
|
||||
Mapping a target module name to names of the normalized LoRA weights.
|
||||
Mapping a list of target module name to names of the normalized LoRA weights.
|
||||
Returned tuple contains (name for Lora A, name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
}
|
||||
stacked = params_mapping.get(name, ([name], [name]))
|
||||
return stacked
|
||||
|
||||
result = (set(), set())
|
||||
for name in target_modules:
|
||||
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
|
||||
result[0].update(lora_a)
|
||||
result[1].update(lora_b)
|
||||
return result
|
||||
|
||||
|
||||
def get_stacked_multiply(module_name: str) -> int:
|
||||
|
||||
Reference in New Issue
Block a user