Improve LoRA Perf by Deprecating FlashInfer and Eliminating Redundant Tensor Ops (#8940)

This commit is contained in:
Lifu Huang
2025-08-10 01:04:45 -07:00
committed by GitHub
parent 6b847a9a05
commit f8a173bb50
10 changed files with 137 additions and 525 deletions

View File

@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
q_name = weight_name
k_name = weight_name.replace("q_proj", "k_proj")
v_name = weight_name.replace("q_proj", "v_proj")
kv_name = weight_name.replace("q_proj", "kv_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
# If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
if "k_proj" in target_module
else torch.zeros_like(weights[v_name])
)
if "lora_A" in weight_name:
weights[qkv_name] = torch.cat(
(
weights[q_name],
k_proj_weight,
weights[v_name],
),
0,
)
weights.pop(q_name)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
else:
weights[kv_name] = torch.stack(
[
k_proj_weight,
weights[v_name],
],
dim=0,
)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
weights[qkv_name] = torch.cat(
(
weights[q_name],
k_proj_weight,
weights[v_name],
),
0,
)
weights.pop(q_name)
if "k_proj" in target_module:
weights.pop(k_name)
weights.pop(v_name)
elif "qkv_proj" in weight_name:
# If qkv_proj is already stacked, we normalize it following the SGL convention.
qkv_name = weight_name
q_name = weight_name.replace("qkv_proj", "q_proj")
k_name = weight_name.replace("qkv_proj", "k_proj")
v_name = weight_name.replace("qkv_proj", "v_proj")
kv_name = weight_name.replace("qkv_proj", "kv_proj")
if "lora_A" in weight_name:
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
else:
head_size = (
self.base_hf_config.hidden_size
// self.base_hf_config.num_attention_heads
)
weights[q_name], k_proj_weight, v_proj_weight = torch.split(
weights[qkv_name],
[
head_size * self.base_hf_config.num_attention_heads,
head_size * self.base_hf_config.num_key_value_heads,
head_size * self.base_hf_config.num_key_value_heads,
],
dim=0,
)
weights[kv_name] = torch.stack(
[k_proj_weight, v_proj_weight],
dim=0,
)
# else: no-op as LoRA B weight is already stacked.
def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
if up_name not in weights:
weights[up_name] = torch.zeros_like(weights[weight_name])
# FIXME: Add gate-only support for flashinfer in future implementations
assert self.lora_backend.name == "triton", (
f"LoRA weight initialization currently only supported for 'triton' backend. "
f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
f"or consider implementing custom initialization logic for other backends."
)
if "lora_A" in weight_name:
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
)
else:
weights[gate_up_name] = torch.stack(
[weights[weight_name], weights[up_name]], dim=0
)
weights[gate_up_name] = torch.cat(
(weights[weight_name], weights[up_name]), 0
)
weights.pop(weight_name)
if up_name in weights:
weights.pop(up_name)
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
else:
output_dim = weights[gate_up_name].shape[0] // 2
weights[gate_up_name] = torch.stack(
[
weights[gate_up_name][:output_dim, :],
weights[gate_up_name][output_dim:, :],
],
dim=0,
)
# else: no-op as LoRA B weight is already stacked.