Fix incorrect LoRA weight loading for fused gate_up_proj (#6734)

This commit is contained in:
Lifu Huang
2025-05-31 13:41:44 -07:00
committed by GitHub
parent 888cb175a6
commit 094fbdacd5
4 changed files with 29 additions and 14 deletions

View File

@@ -209,4 +209,12 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
# else: "lora_B" is already stacked, no operation is needed.
else:
output_dim = weights[gate_up_name].shape[0] // 2
weights[gate_up_name] = torch.stack(
[
weights[gate_up_name][:output_dim, :],
weights[gate_up_name][output_dim:, :],
],
dim=0,
)