[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)

Co-authored-by: ShenAo1111 <1377693092@qq.com>
Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
aoshen524
2025-03-18 23:33:07 -04:00
committed by GitHub
parent 3196999f63
commit 588865f0e0
13 changed files with 528 additions and 103 deletions

View File

@@ -133,9 +133,20 @@ def get_weight_name(
target_name is name of a given module,
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
If there is a weight name in lora_weight_names that can match target_name, return this name
Else return None
Else raise ValueError.
"""
idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names:
if weight_name_pair[idx] in target_name:
return weight_name_pair[idx]
raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}"
)
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]