Refactor LoRA handling to support adapter tensors in fused format (#6585)
This commit is contained in:
@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
|
||||
for i in range(self.base_hf_config.num_hidden_layers):
|
||||
layer = self.layers[i]
|
||||
weight_names = [name for name, _ in layer.weights.items()]
|
||||
self.stack_qkv_proj(weight_names, layer.weights)
|
||||
self.stack_gate_up_proj(weight_names, layer.weights)
|
||||
|
||||
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
||||
self.normalize_qkv_proj(weight_names, layer.weights)
|
||||
self.normalize_gate_up_proj(weight_names, layer.weights)
|
||||
|
||||
def normalize_qkv_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
):
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||
target_module = set()
|
||||
for weight_name in weight_names:
|
||||
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
|
||||
target_module.add("q_proj")
|
||||
if "v_proj" in weight_name:
|
||||
target_module.add("v_proj")
|
||||
if "qkv_proj" in weight_name:
|
||||
target_module.add("qkv_proj")
|
||||
if len(target_module) == 0:
|
||||
return
|
||||
|
||||
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
elif "qkv_proj" in weight_name:
|
||||
# If qkv_proj is already stacked, we normalize it following the SGL convention.
|
||||
qkv_name = weight_name
|
||||
q_name = weight_name.replace("qkv_proj", "q_proj")
|
||||
k_name = weight_name.replace("qkv_proj", "k_proj")
|
||||
v_name = weight_name.replace("qkv_proj", "v_proj")
|
||||
kv_name = weight_name.replace("qkv_proj", "kv_proj")
|
||||
if "lora_A" in weight_name:
|
||||
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
|
||||
else:
|
||||
head_size = (
|
||||
self.base_hf_config.hidden_size
|
||||
// self.base_hf_config.num_attention_heads
|
||||
)
|
||||
weights[q_name], weights[kv_name] = torch.split(
|
||||
weights[qkv_name],
|
||||
[
|
||||
head_size * self.base_hf_config.num_attention_heads,
|
||||
head_size * self.base_hf_config.num_key_value_heads * 2,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
def stack_gate_up_proj(
|
||||
def normalize_gate_up_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
):
|
||||
for weight_name in weight_names:
|
||||
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
|
||||
weights.pop(weight_name)
|
||||
if up_name in weights:
|
||||
weights.pop(up_name)
|
||||
elif "gate_up_proj" in weight_name:
|
||||
# If gate_up_proj is already stacked, we normalize it following the SGL convention
|
||||
gate_up_name = weight_name
|
||||
if "lora_A" in weight_name:
|
||||
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
|
||||
# else: "lora_B" is already stacked, no operations is needed.
|
||||
|
||||
Reference in New Issue
Block a user