Refactor LoRA handling to support adapter tensors in fused format (#6585)

This commit is contained in:
Lifu Huang
2025-05-26 21:51:54 -07:00
committed by GitHub
parent 1a8f5f6836
commit 477a101cbd
6 changed files with 86 additions and 31 deletions

View File

@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Set, Tuple
from typing import List, Optional, Set, Tuple
import torch
@@ -106,18 +106,22 @@ def get_hidden_dim(
raise NotImplementedError()
def get_stacked_name(name: str) -> Tuple[str]:
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"""
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
Mapping a target module name to names of the normalized LoRA weights.
Returned tuple contains (names for LoRA A, names for LoRA B)
"""
params_mapping = {
"q_proj": ("qkv_proj", "q_proj"),
"k_proj": ("qkv_proj", "kv_proj"),
"v_proj": ("qkv_proj", "kv_proj"),
"gate_proj": ("gate_up_proj", "gate_up_proj"),
"up_proj": ("gate_up_proj", "gate_up_proj"),
"q_proj": (["qkv_proj"], ["q_proj"]),
"k_proj": (["qkv_proj"], ["kv_proj"]),
"v_proj": (["qkv_proj"], ["kv_proj"]),
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}
return params_mapping.get(name, (name, name))
stacked = params_mapping.get(name, ([name], [name]))
return stacked
def get_stacked_multiply(module_name: str) -> int:
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
def get_weight_name(
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
) -> Optional[str]:
"""
target_name is name of a given module,
@@ -142,9 +146,9 @@ def get_weight_name(
Else raise ValueError.
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names:
if weight_name_pair[idx] in target_name:
return weight_name_pair[idx]
for weight_name in lora_weight_names[idx]:
if weight_name in target_name:
return weight_name
raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}"
)