Refactor LoRA handling to support adapter tensors in fused format (#6585)
This commit is contained in:
@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
|
|||||||
for i in range(self.base_hf_config.num_hidden_layers):
|
for i in range(self.base_hf_config.num_hidden_layers):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
weight_names = [name for name, _ in layer.weights.items()]
|
weight_names = [name for name, _ in layer.weights.items()]
|
||||||
self.stack_qkv_proj(weight_names, layer.weights)
|
self.normalize_qkv_proj(weight_names, layer.weights)
|
||||||
self.stack_gate_up_proj(weight_names, layer.weights)
|
self.normalize_gate_up_proj(weight_names, layer.weights)
|
||||||
|
|
||||||
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
|
|
||||||
|
|
||||||
|
def normalize_qkv_proj(
|
||||||
|
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||||
|
):
|
||||||
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
|
||||||
target_module = set()
|
target_module = set()
|
||||||
for weight_name in weight_names:
|
for weight_name in weight_names:
|
||||||
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
|
|||||||
target_module.add("q_proj")
|
target_module.add("q_proj")
|
||||||
if "v_proj" in weight_name:
|
if "v_proj" in weight_name:
|
||||||
target_module.add("v_proj")
|
target_module.add("v_proj")
|
||||||
|
if "qkv_proj" in weight_name:
|
||||||
|
target_module.add("qkv_proj")
|
||||||
if len(target_module) == 0:
|
if len(target_module) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
|
|||||||
if "k_proj" in target_module:
|
if "k_proj" in target_module:
|
||||||
weights.pop(k_name)
|
weights.pop(k_name)
|
||||||
weights.pop(v_name)
|
weights.pop(v_name)
|
||||||
|
elif "qkv_proj" in weight_name:
|
||||||
|
# If qkv_proj is already stacked, we normalize it following the SGL convention.
|
||||||
|
qkv_name = weight_name
|
||||||
|
q_name = weight_name.replace("qkv_proj", "q_proj")
|
||||||
|
k_name = weight_name.replace("qkv_proj", "k_proj")
|
||||||
|
v_name = weight_name.replace("qkv_proj", "v_proj")
|
||||||
|
kv_name = weight_name.replace("qkv_proj", "kv_proj")
|
||||||
|
if "lora_A" in weight_name:
|
||||||
|
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
|
||||||
|
else:
|
||||||
|
head_size = (
|
||||||
|
self.base_hf_config.hidden_size
|
||||||
|
// self.base_hf_config.num_attention_heads
|
||||||
|
)
|
||||||
|
weights[q_name], weights[kv_name] = torch.split(
|
||||||
|
weights[qkv_name],
|
||||||
|
[
|
||||||
|
head_size * self.base_hf_config.num_attention_heads,
|
||||||
|
head_size * self.base_hf_config.num_key_value_heads * 2,
|
||||||
|
],
|
||||||
|
dim=0,
|
||||||
|
)
|
||||||
|
|
||||||
def stack_gate_up_proj(
|
def normalize_gate_up_proj(
|
||||||
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
|
||||||
):
|
):
|
||||||
for weight_name in weight_names:
|
for weight_name in weight_names:
|
||||||
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
|
|||||||
weights.pop(weight_name)
|
weights.pop(weight_name)
|
||||||
if up_name in weights:
|
if up_name in weights:
|
||||||
weights.pop(up_name)
|
weights.pop(up_name)
|
||||||
|
elif "gate_up_proj" in weight_name:
|
||||||
|
# If gate_up_proj is already stacked, we normalize it following the SGL convention
|
||||||
|
gate_up_name = weight_name
|
||||||
|
if "lora_A" in weight_name:
|
||||||
|
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
|
||||||
|
# else: "lora_B" is already stacked, no operations is needed.
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from sglang.srt.lora.utils import (
|
|||||||
LoRAType,
|
LoRAType,
|
||||||
get_customized_names_from_hf_names,
|
get_customized_names_from_hf_names,
|
||||||
get_layer_id,
|
get_layer_id,
|
||||||
get_stacked_name,
|
get_normalized_lora_weight_names,
|
||||||
get_weight_name,
|
get_weight_name,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -101,10 +101,13 @@ class LoRAManager:
|
|||||||
self.hf_target_names.update(self.configs[name].target_modules)
|
self.hf_target_names.update(self.configs[name].target_modules)
|
||||||
|
|
||||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
|
weights_A: List[str] = []
|
||||||
self.lora_weight_names: Set[Tuple[str]] = set(
|
weights_B: List[str] = []
|
||||||
[get_stacked_name(module) for module in self.hf_target_names]
|
for module in self.hf_target_names:
|
||||||
)
|
lora_A, lora_B = get_normalized_lora_weight_names(module)
|
||||||
|
weights_A += lora_A
|
||||||
|
weights_B += lora_B
|
||||||
|
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
|
||||||
|
|
||||||
# load all weights to cpu
|
# load all weights to cpu
|
||||||
self.loras: Dict[str, LoRAAdapter] = {}
|
self.loras: Dict[str, LoRAAdapter] = {}
|
||||||
@@ -263,7 +266,18 @@ class LoRAManager:
|
|||||||
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
|
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
|
||||||
i: [] for i in range(self.base_hf_config.num_hidden_layers)
|
i: [] for i in range(self.base_hf_config.num_hidden_layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
for module_name, module in self.base_model.named_modules():
|
for module_name, module in self.base_model.named_modules():
|
||||||
|
# TODO (lifuhuang): in the future, we should consider generalizing the
|
||||||
|
# should_apply_lora function to support mapping by full module name instead
|
||||||
|
# of just the last part (e.g., "qkv_proj") to support scenarios with multiple
|
||||||
|
# attention stacks (e.g., multimodal models).
|
||||||
|
# See: https://github.com/sgl-project/sglang/issues/6608
|
||||||
|
if getattr(
|
||||||
|
self.base_model, "should_apply_lora", None
|
||||||
|
) and not self.base_model.should_apply_lora(module_name):
|
||||||
|
continue
|
||||||
|
|
||||||
# The module should be converted if it is included in target_names
|
# The module should be converted if it is included in target_names
|
||||||
if module_name.split(".")[-1] in customized_target_names:
|
if module_name.split(".")[-1] in customized_target_names:
|
||||||
layer_id = get_layer_id(module_name)
|
layer_id = get_layer_id(module_name)
|
||||||
|
|||||||
@@ -91,18 +91,16 @@ class LoRAMemoryPool:
|
|||||||
|
|
||||||
def init_buffers(
|
def init_buffers(
|
||||||
self,
|
self,
|
||||||
lora_weight_names: Set[Tuple[str]],
|
lora_weight_names: Tuple[Set[str]],
|
||||||
base_model: torch.nn.Module,
|
base_model: torch.nn.Module,
|
||||||
):
|
):
|
||||||
|
|
||||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||||
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names
|
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
|
||||||
device = next(base_model.parameters()).device
|
device = next(base_model.parameters()).device
|
||||||
lora_module_A_names = set([name[0] for name in lora_weight_names])
|
|
||||||
lora_module_B_names = set([name[1] for name in lora_weight_names])
|
|
||||||
# Init A tensor, column_major=False
|
# Init A tensor, column_major=False
|
||||||
for module_A in lora_module_A_names:
|
for module_A in lora_weight_names[0]:
|
||||||
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
|
lora_A_shape = self.get_lora_A_shape(module_A, base_model)
|
||||||
self.A_buffer[module_A] = [
|
self.A_buffer[module_A] = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
@@ -110,10 +108,10 @@ class LoRAMemoryPool:
|
|||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
for i in range(self.num_layer)
|
for _ in range(self.num_layer)
|
||||||
]
|
]
|
||||||
# Init B tensor, column_major=True
|
# Init B tensor, column_major=True
|
||||||
for module_B in lora_module_B_names:
|
for module_B in lora_weight_names[1]:
|
||||||
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
|
lora_B_shape = self.get_lora_B_shape(module_B, base_model)
|
||||||
self.B_buffer[module_B] = [
|
self.B_buffer[module_B] = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Optional, Set, Tuple
|
from typing import List, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -106,18 +106,22 @@ def get_hidden_dim(
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def get_stacked_name(name: str) -> Tuple[str]:
|
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||||
"""
|
"""
|
||||||
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B)
|
Mapping a target module name to names of the normized LoRA weights.
|
||||||
|
Returned tuple contains (name for Lora A, name for Lora B)
|
||||||
"""
|
"""
|
||||||
params_mapping = {
|
params_mapping = {
|
||||||
"q_proj": ("qkv_proj", "q_proj"),
|
"q_proj": (["qkv_proj"], ["q_proj"]),
|
||||||
"k_proj": ("qkv_proj", "kv_proj"),
|
"k_proj": (["qkv_proj"], ["kv_proj"]),
|
||||||
"v_proj": ("qkv_proj", "kv_proj"),
|
"v_proj": (["qkv_proj"], ["kv_proj"]),
|
||||||
"gate_proj": ("gate_up_proj", "gate_up_proj"),
|
"gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||||
"up_proj": ("gate_up_proj", "gate_up_proj"),
|
"up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||||
|
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||||
|
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||||
}
|
}
|
||||||
return params_mapping.get(name, (name, name))
|
stacked = params_mapping.get(name, ([name], [name]))
|
||||||
|
return stacked
|
||||||
|
|
||||||
|
|
||||||
def get_stacked_multiply(module_name: str) -> int:
|
def get_stacked_multiply(module_name: str) -> int:
|
||||||
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
|
|||||||
|
|
||||||
|
|
||||||
def get_weight_name(
|
def get_weight_name(
|
||||||
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType
|
target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
target_name is name of a given module,
|
target_name is name of a given module,
|
||||||
@@ -142,9 +146,9 @@ def get_weight_name(
|
|||||||
Else raise ValueError.
|
Else raise ValueError.
|
||||||
"""
|
"""
|
||||||
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
idx = 0 if lora_type == LoRAType.LORA_A else 1
|
||||||
for weight_name_pair in lora_weight_names:
|
for weight_name in lora_weight_names[idx]:
|
||||||
if weight_name_pair[idx] in target_name:
|
if weight_name in target_name:
|
||||||
return weight_name_pair[idx]
|
return weight_name
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot find weight name for {target_name} in {lora_weight_names}"
|
f"Cannot find weight name for {target_name} in {lora_weight_names}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import re
|
||||||
from collections.abc import Iterable
|
from collections.abc import Iterable
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
|
|||||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
lora_pattern = re.compile(
|
||||||
|
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: PretrainedConfig,
|
config: PretrainedConfig,
|
||||||
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
|
|||||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
|
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
|
||||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||||
|
|
||||||
|
def should_apply_lora(self, module_name: str) -> Optional[str]:
|
||||||
|
return self.lora_pattern.match(module_name)
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||||
stacked_params_mapping = [
|
stacked_params_mapping = [
|
||||||
# (param_name, shard_name, shard_id)
|
# (param_name, shard_name, shard_id)
|
||||||
|
|||||||
@@ -1473,7 +1473,7 @@ class ServerArgs:
|
|||||||
self.max_loras_per_batch > 0
|
self.max_loras_per_batch > 0
|
||||||
# FIXME
|
# FIXME
|
||||||
and (self.lora_paths is None or self.disable_radix_cache)
|
and (self.lora_paths is None or self.disable_radix_cache)
|
||||||
), "compatibility of lora and cuda graph and radix attention is in progress"
|
), "compatibility of lora and radix attention is in progress"
|
||||||
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
|
||||||
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user