Refactor LoRA handling to support adapter tensors in fused format (#6585)
This commit is contained in:
@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
|
||||
for i in range(self.base_hf_config.num_hidden_layers):
|
||||
layer = self.layers[i]
|
||||
weight_names = [name for name, _ in layer.weights.items()]
|
||||
self.stack_qkv_proj(weight_names, layer.weights)
|
||||
self.stack_gate_up_proj(weight_names, layer.weights)
|
||||
|
||||
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
||||
self.normalize_qkv_proj(weight_names, layer.weights)
|
||||
self.normalize_gate_up_proj(weight_names, layer.weights)
|
||||
|
||||
def normalize_qkv_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
):
|
||||
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||
target_module = set()
|
||||
for weight_name in weight_names:
|
||||
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
|
||||
target_module.add("q_proj")
|
||||
if "v_proj" in weight_name:
|
||||
target_module.add("v_proj")
|
||||
if "qkv_proj" in weight_name:
|
||||
target_module.add("qkv_proj")
|
||||
if len(target_module) == 0:
|
||||
return
|
||||
|
||||
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
|
||||
if "k_proj" in target_module:
|
||||
weights.pop(k_name)
|
||||
weights.pop(v_name)
|
||||
elif "qkv_proj" in weight_name:
|
||||
# If qkv_proj is already stacked, we normalize it following the SGL convention.
|
||||
qkv_name = weight_name
|
||||
q_name = weight_name.replace("qkv_proj", "q_proj")
|
||||
k_name = weight_name.replace("qkv_proj", "k_proj")
|
||||
v_name = weight_name.replace("qkv_proj", "v_proj")
|
||||
kv_name = weight_name.replace("qkv_proj", "kv_proj")
|
||||
if "lora_A" in weight_name:
|
||||
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
|
||||
else:
|
||||
head_size = (
|
||||
self.base_hf_config.hidden_size
|
||||
// self.base_hf_config.num_attention_heads
|
||||
)
|
||||
weights[q_name], weights[kv_name] = torch.split(
|
||||
weights[qkv_name],
|
||||
[
|
||||
head_size * self.base_hf_config.num_attention_heads,
|
||||
head_size * self.base_hf_config.num_key_value_heads * 2,
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
def stack_gate_up_proj(
|
||||
def normalize_gate_up_proj(
|
||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||
):
|
||||
for weight_name in weight_names:
|
||||
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
|
||||
weights.pop(weight_name)
|
||||
if up_name in weights:
|
||||
weights.pop(up_name)
|
||||
elif "gate_up_proj" in weight_name:
|
||||
# If gate_up_proj is already stacked, we normalize it following the SGL convention
|
||||
gate_up_name = weight_name
|
||||
if "lora_A" in weight_name:
|
||||
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
|
||||
# else: "lora_B" is already stacked, no operations is needed.
|
||||
|
||||
@@ -32,7 +32,7 @@ from sglang.srt.lora.utils import (
|
||||
LoRAType,
|
||||
get_customized_names_from_hf_names,
|
||||
get_layer_id,
|
||||
get_stacked_name,
|
||||
get_normalized_lora_weight_names,
|
||||
get_weight_name,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
@@ -101,10 +101,13 @@ class LoRAManager:
|
||||
self.hf_target_names.update(self.configs[name].target_modules)
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = set(
|
||||
[get_stacked_name(module) for module in self.hf_target_names]
|
||||
)
|
||||
weights_A: List[str] = []
|
||||
weights_B: List[str] = []
|
||||
for module in self.hf_target_names:
|
||||
lora_A, lora_B = get_normalized_lora_weight_names(module)
|
||||
weights_A += lora_A
|
||||
weights_B += lora_B
|
||||
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
|
||||
|
||||
# load all weights to cpu
|
||||
self.loras: Dict[str, LoRAAdapter] = {}
|
||||
@@ -263,7 +266,18 @@ class LoRAManager:
|
||||
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
|
||||
i: [] for i in range(self.base_hf_config.num_hidden_layers)
|
||||
}
|
||||
|
||||
for module_name, module in self.base_model.named_modules():
|
||||
# TODO (lifuhuang): in the future, we should consider generalizing the
|
||||
# should_apply_lora function to support mapping by full module name instead
|
||||
# of just the last part (e.g., "qkv_proj") to support scenarios with multiple
|
||||
# attention stacks (e.g., multimodal models).
|
||||
# See: https://github.com/sgl-project/sglang/issues/6608
|
||||
if getattr(
|
||||
self.base_model, "should_apply_lora", None
|
||||
) and not self.base_model.should_apply_lora(module_name):
|
||||
continue
|
||||
|
||||
# The module should be converted if it is included in target_names
|
||||
if module_name.split(".")[-1] in customized_target_names:
|
||||
layer_id = get_layer_id(module_name)
|
||||
|
||||
@@ -91,18 +91,16 @@ class LoRAMemoryPool:
|
||||
|
||||
def init_buffers(
|
||||
self,
|
||||
lora_weight_names: Set[Tuple[str]],
|
||||
lora_weight_names: Tuple[Set[str]],
|
||||
base_model: torch.nn.Module,
|
||||
):
|
||||
|
||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
|
||||
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
|
||||
device = next(base_model.parameters()).device
|
||||
lora_module_A_names = set([name[0] for name in lora_weight_names])
|
||||
lora_module_B_names = set([name[1] for name in lora_weight_names])
|
||||
# Init A tensor, column_major=False
|
||||
for module_A in lora_module_A_names:
|
||||
for module_A in lora_weight_names[0]:
|
||||
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
|
||||
self.A_buffer[module_A] = [
|
||||
torch.empty(
|
||||
@@ -110,10 +108,10 @@ class LoRAMemoryPool:
|
||||
dtype=self.dtype,
|
||||
device=device,
|
||||
)
|
||||
for i in range(self.num_layer)
|
||||
for _ in range(self.num_layer)
|
||||
]
|
||||
# Init B tensor, column_major=True
|
||||
for module_B in lora_module_B_names:
|
||||
for module_B in lora_weight_names[1]:
|
||||
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
|
||||
self.B_buffer[module_B] = [
|
||||
torch.empty(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Set, Tuple
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -106,18 +106,22 @@ def get_hidden_dim(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_stacked_name(name: str) -> Tuple[str]:
|
||||
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
"""
|
||||
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
|
||||
Mapping a target module name to names of the normized LoRA weights.
|
||||
Returned tuple contains (name for Lora A, name for Lora B)
|
||||
"""
|
||||
params_mapping = {
|
||||
"q_proj": ("qkv_proj", "q_proj"),
|
||||
"k_proj": ("qkv_proj", "kv_proj"),
|
||||
"v_proj": ("qkv_proj", "kv_proj"),
|
||||
"gate_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
"up_proj": ("gate_up_proj", "gate_up_proj"),
|
||||
"q_proj": (["qkv_proj"], ["q_proj"]),
|
||||
"k_proj": (["qkv_proj"], ["kv_proj"]),
|
||||
"v_proj": (["qkv_proj"], ["kv_proj"]),
|
||||
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
}
|
||||
return params_mapping.get(name, (name, name))
|
||||
stacked = params_mapping.get(name, ([name], [name]))
|
||||
return stacked
|
||||
|
||||
|
||||
def get_stacked_multiply(module_name: str) -> int:
|
||||
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
|
||||
|
||||
|
||||
def get_weight_name(
|
||||
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
|
||||
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
target_name is name of a given module,
|
||||
@@ -142,9 +146,9 @@ def get_weight_name(
|
||||
Else raise ValueError.
|
||||
"""
|
||||
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
||||
for weight_name_pair in lora_weight_names:
|
||||
if weight_name_pair[idx] in target_name:
|
||||
return weight_name_pair[idx]
|
||||
for weight_name in lora_weight_names[idx]:
|
||||
if weight_name in target_name:
|
||||
return weight_name
|
||||
raise ValueError(
|
||||
f"Cannot find weight name for {target_name} in {lora_weight_names}"
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
lora_pattern = re.compile(
|
||||
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def should_apply_lora(self, module_name: str) -> Optional[str]:
|
||||
return self.lora_pattern.match(module_name)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
|
||||
@@ -1473,7 +1473,7 @@ class ServerArgs:
|
||||
self.max_loras_per_batch > 0
|
||||
# FIXME
|
||||
and (self.lora_paths is None or self.disable_radix_cache)
|
||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
||||
), "compatibility of lora and radix attention is in progress"
|
||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user