fix some typos (#6209)
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
This commit is contained in:
@@ -40,7 +40,7 @@ class LoRALayer(nn.Module):
|
||||
self.config: LoRAConfig = config
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
|
||||
# lora weights in cpu. The weights are loaded from checkpoint.
|
||||
# LoRA weights in cpu. The weights are loaded from checkpoint.
|
||||
self.weights: Dict[str, torch.Tensor] = {}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ class LoRAAdapter(nn.Module):
|
||||
|
||||
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
||||
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no LoRA attached to k_proj
|
||||
target_module = set()
|
||||
for weight_name in weight_names:
|
||||
if "k_proj" in weight_name:
|
||||
@@ -110,7 +110,7 @@ class LoRAAdapter(nn.Module):
|
||||
return
|
||||
|
||||
for weight_name in weight_names:
|
||||
# We assume every lora adaptor should contain lora modules for q_proj
|
||||
# We assume every LoRA adaptor should contain LoRA modules for q_proj
|
||||
if "q_proj" in weight_name:
|
||||
q_name = weight_name
|
||||
k_name = weight_name.replace("q_proj", "k_proj")
|
||||
@@ -118,7 +118,7 @@ class LoRAAdapter(nn.Module):
|
||||
kv_name = weight_name.replace("q_proj", "kv_proj")
|
||||
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||
|
||||
# If k_proj doesn't have lora, initialize it to zero
|
||||
# If k_proj doesn't have LoRA, initialize it to zero
|
||||
k_proj_weight = (
|
||||
weights[k_name]
|
||||
if "k_proj" in target_module
|
||||
|
||||
Reference in New Issue
Block a user