Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)

This commit is contained in:
Lifu Huang
2025-07-13 18:36:01 -07:00
committed by GitHub
parent b5dd5e8741
commit e2ed9d049a
10 changed files with 840 additions and 227 deletions

View File

@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
@@ -106,9 +106,11 @@ def get_hidden_dim(
raise NotImplementedError()
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
def get_normalized_lora_weight_names(
target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
"""
Mapping a target module name to names of the normalized LoRA weights.
Mapping a list of target module name to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
"""
params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}
stacked = params_mapping.get(name, ([name], [name]))
return stacked
result = (set(), set())
for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
result[0].update(lora_a)
result[1].update(lora_b)
return result
def get_stacked_multiply(module_name: str) -> int: