Refactor LoRA handling to support adapter tensors in fused format (#6585)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -106,18 +106,22 @@ def get_hidden_dim(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_stacked_name(name: str) -> Tuple[str]:
|
||||
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
|
||||
Mapping a target module name to names of the normized LoRA weights.
|
||||
Returned tuple contains (name for Lora A, name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
"q_proj": ("qkv_proj", "q_proj"),
|
||||
"k_proj": ("qkv_proj", "kv_proj"),
|
||||
"v_proj": ("qkv_proj", "kv_proj"),
|
||||
"gate_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
"up_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
"q_proj": (["qkv_proj"], ["q_proj"]),
|
||||
"k_proj": (["qkv_proj"], ["kv_proj"]),
|
||||
"v_proj": (["qkv_proj"], ["kv_proj"]),
|
||||
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
}
|
||||
return params_mapping.get(name, (name, name))
|
||||
stacked = params_mapping.get(name, ([name], [name]))
|
||||
return stacked
|
||||
|
||||
|
||||
def get_stacked_multiply(module_name: str) -> int:
|
||||
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
|
||||
|
||||
|
||||
def get_weight_name(
|
||||
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
|
||||
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
target_name is name of a given module,
|
||||
@@ -142,9 +146,9 @@ def get_weight_name(
|
||||
Else raise ValueError.
|
||||
"""
|
||||
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
||||
for weight_name_pair in lora_weight_names:
|
||||
if weight_name_pair[idx] in target_name:
|
||||
return weight_name_pair[idx]
|
||||
for weight_name in lora_weight_names[idx]:
|
||||
if weight_name in target_name:
|
||||
return weight_name
|
||||
raise ValueError(
|
||||
f"Cannot find weight name for {target_name} in {lora_weight_names}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user