Refactor LoRA handling to support adapter tensors in fused format (#6585)

This commit is contained in:
Lifu Huang
2025-05-26 21:51:54 -07:00
committed by GitHub
parent 1a8f5f6836
commit 477a101cbd
6 changed files with 86 additions and 31 deletions

View File

@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()]
self.stack_qkv_proj(weight_names, layer.weights)
self.stack_gate_up_proj(weight_names, layer.weights)
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
self.normalize_qkv_proj(weight_names, layer.weights)
self.normalize_gate_up_proj(weight_names, layer.weights)
def normalize_qkv_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
# Collect the target q/k/v modules. This is necessary because an adapter might have no LoRA weights attached to k_proj
target_module = set()
for weight_name in weight_names:
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
target_module.add("q_proj")
if "v_proj" in weight_name:
target_module.add("v_proj")
if "qkv_proj" in weight_name:
target_module.add("qkv_proj")
if len(target_module) == 0:
return
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
elif "qkv_proj" in weight_name:
# If qkv_proj is already stacked, we normalize it following the SGL convention.
qkv_name = weight_name
q_name = weight_name.replace("qkv_proj", "q_proj")
k_name = weight_name.replace("qkv_proj", "k_proj")
v_name = weight_name.replace("qkv_proj", "v_proj")
kv_name = weight_name.replace("qkv_proj", "kv_proj")
if "lora_A" in weight_name:
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
else:
head_size = (
self.base_hf_config.hidden_size
// self.base_hf_config.num_attention_heads
)
weights[q_name], weights[kv_name] = torch.split(
weights[qkv_name],
[
head_size * self.base_hf_config.num_attention_heads,
head_size * self.base_hf_config.num_key_value_heads * 2,
],
dim=0,
)
def stack_gate_up_proj(
def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
for weight_name in weight_names:
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
weights.pop(weight_name)
if up_name in weights:
weights.pop(up_name)
elif "gate_up_proj" in weight_name:
# If gate_up_proj is already stacked, we normalize it following the SGL convention
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
# else: "lora_B" is already stacked, so no operation is needed.