[Feature] Support Tensor Parallelism and Weight Slicing for Lora (#4274)
Co-authored-by: ShenAo1111 <1377693092@qq.com> Co-authored-by: Baizhou Zhang <sobereddiezhang@gmail.com>
This commit is contained in:
@@ -133,9 +133,20 @@ def get_weight_name(
|
||||
target_name is name of a given module,
|
||||
lora_weight_names is a set of lora stacked name pairs (see get_stacked_name method above)
|
||||
If there is a weight name in lora_weight_names that can match target_name, return this name
|
||||
Else return None
|
||||
Else raise ValueError.
|
||||
"""
|
||||
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
||||
for weight_name_pair in lora_weight_names:
|
||||
if weight_name_pair[idx] in target_name:
|
||||
return weight_name_pair[idx]
|
||||
raise ValueError(
|
||||
f"Cannot find weight name for {target_name} in {lora_weight_names}"
|
||||
)
|
||||
|
||||
|
||||
# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
|
||||
VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
|
||||
COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
|
||||
MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
|
||||
QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
|
||||
ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
|
||||
|
||||
Reference in New Issue
Block a user