Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)
This commit is contained in:
@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
|
||||
q_name = weight_name
|
||||
k_name = weight_name.replace("q_proj", "k_proj")
|
||||
v_name = weight_name.replace("q_proj", "v_proj")
|
||||
kv_name = weight_name.replace("q_proj", "kv_proj")
|
||||
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||
|
||||
# If k_proj doesn't have lora, initialize it to zero
|
||||
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
|
||||
if "k_proj" in target_module
|
||||
else torch.zeros_like(weights[v_name])
|
||||
)
|
||||
if "lora_A" in weight_name:
|
||||
weights[qkv_name] = torch.cat(
|
||||
(
|
||||
weights[q_name],
|
||||
k_proj_weight,
|
||||
weights[v_name],
|
||||
),
|
||||
0,
|
||||
)
|
||||
weights.pop(q_name)
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
else:
|
||||
weights[kv_name] = torch.stack(
|
||||
[
|
||||
k_proj_weight,
|
||||
weights[v_name],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
weights[qkv_name] = torch.cat(
|
||||
(
|
||||
weights[q_name],
|
||||
k_proj_weight,
|
||||
weights[v_name],
|
||||
),
|
||||
0,
|
||||
)
|
||||
weights.pop(q_name)
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
elif "qkv_proj" in weight_name:
|
||||
# If qkv_proj is already stacked, we normalize it following the SGL convention.
|
||||
qkv_name = weight_name
|
||||
q_name = weight_name.replace("qkv_proj", "q_proj")
|
||||
k_name = weight_name.replace("qkv_proj", "k_proj")
|
||||
v_name = weight_name.replace("qkv_proj", "v_proj")
|
||||
kv_name = weight_name.replace("qkv_proj", "kv_proj")
|
||||
if "lora_A" in weight_name:
|
||||
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
|
||||
else:
|
||||
head_size = (
|
||||
self.base_hf_config.hidden_size
|
||||
// self.base_hf_config.num_attention_heads
|
||||
)
|
||||
weights[q_name], k_proj_weight, v_proj_weight = torch.split(
|
||||
weights[qkv_name],
|
||||
[
|
||||
head_size * self.base_hf_config.num_attention_heads,
|
||||
head_size * self.base_hf_config.num_key_value_heads,
|
||||
head_size * self.base_hf_config.num_key_value_heads,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
weights[kv_name] = torch.stack(
|
||||
[k_proj_weight, v_proj_weight],
|
||||
dim=0,
|
||||
)
|
||||
# else: no-op as LoRA B weight is already stacked.
|
||||
|
||||
def normalize_gate_up_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
|
||||
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
|
||||
if up_name not in weights:
|
||||
weights[up_name] = torch.zeros_like(weights[weight_name])
|
||||
# FIXME: Add gate-only support for flashinfer in future implementations
|
||||
assert self.lora_backend.name == "triton", (
|
||||
f"LoRA weight initialization currently only supported for 'triton' backend. "
|
||||
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
|
||||
f"or consider implementing custom initialization logic for other backends."
|
||||
)
|
||||
if "lora_A" in weight_name:
|
||||
weights[gate_up_name] = torch.cat(
|
||||
(weights[weight_name], weights[up_name]), 0
|
||||
)
|
||||
else:
|
||||
weights[gate_up_name] = torch.stack(
|
||||
[weights[weight_name], weights[up_name]], dim=0
|
||||
)
|
||||
weights[gate_up_name] = torch.cat(
|
||||
(weights[weight_name], weights[up_name]), 0
|
||||
)
|
||||
weights.pop(weight_name)
|
||||
if up_name in weights:
|
||||
weights.pop(up_name)
|
||||
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
|
||||
gate_up_name = weight_name
|
||||
if "lora_A" in weight_name:
|
||||
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
|
||||
else:
|
||||
output_dim = weights[gate_up_name].shape[0] // 2
|
||||
weights[gate_up_name] = torch.stack(
|
||||
[
|
||||
weights[gate_up_name][:output_dim, :],
|
||||
weights[gate_up_name][output_dim:, :],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
# else: no-op as LoRA B weight is already stacked.
|
||||
|
||||
Reference in New Issue
Block a user