Refactor LoRA handling to support adapter tensors in fused format (#6585)

This commit is contained in:
Lifu Huang
2025-05-26 21:51:54 -07:00
committed by GitHub
parent 1a8f5f6836
commit 477a101cbd
6 changed files with 86 additions and 31 deletions

View File

@@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module):
for i in range(self.base_hf_config.num_hidden_layers): for i in range(self.base_hf_config.num_hidden_layers):
layer = self.layers[i] layer = self.layers[i]
weight_names = [name for name, _ in layer.weights.items()] weight_names = [name for name, _ in layer.weights.items()]
self.stack_qkv_proj(weight_names, layer.weights) self.normalize_qkv_proj(weight_names, layer.weights)
self.stack_gate_up_proj(weight_names, layer.weights) self.normalize_gate_up_proj(weight_names, layer.weights)
def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]):
def normalize_qkv_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor]
):
# Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj
target_module = set() target_module = set()
for weight_name in weight_names: for weight_name in weight_names:
@@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module):
target_module.add("q_proj") target_module.add("q_proj")
if "v_proj" in weight_name: if "v_proj" in weight_name:
target_module.add("v_proj") target_module.add("v_proj")
if "qkv_proj" in weight_name:
target_module.add("qkv_proj")
if len(target_module) == 0: if len(target_module) == 0:
return return
@@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module):
if "k_proj" in target_module: if "k_proj" in target_module:
weights.pop(k_name) weights.pop(k_name)
weights.pop(v_name) weights.pop(v_name)
elif "qkv_proj" in weight_name:
# If qkv_proj is already stacked, we normalize it following the SGL convention.
qkv_name = weight_name
q_name = weight_name.replace("qkv_proj", "q_proj")
k_name = weight_name.replace("qkv_proj", "k_proj")
v_name = weight_name.replace("qkv_proj", "v_proj")
kv_name = weight_name.replace("qkv_proj", "kv_proj")
if "lora_A" in weight_name:
weights[qkv_name] = weights[qkv_name].repeat(3, 1)
else:
head_size = (
self.base_hf_config.hidden_size
// self.base_hf_config.num_attention_heads
)
weights[q_name], weights[kv_name] = torch.split(
weights[qkv_name],
[
head_size * self.base_hf_config.num_attention_heads,
head_size * self.base_hf_config.num_key_value_heads * 2,
],
dim=0,
)
def stack_gate_up_proj( def normalize_gate_up_proj(
self, weight_names: List[str], weights: Dict[str, torch.Tensor] self, weight_names: List[str], weights: Dict[str, torch.Tensor]
): ):
for weight_name in weight_names: for weight_name in weight_names:
@@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module):
weights.pop(weight_name) weights.pop(weight_name)
if up_name in weights: if up_name in weights:
weights.pop(up_name) weights.pop(up_name)
elif "gate_up_proj" in weight_name:
# If gate_up_proj is already stacked, we normalize it following the SGL convention
gate_up_name = weight_name
if "lora_A" in weight_name:
weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
# else: "lora_B" is already stacked, no operation is needed.

View File

@@ -32,7 +32,7 @@ from sglang.srt.lora.utils import (
LoRAType, LoRAType,
get_customized_names_from_hf_names, get_customized_names_from_hf_names,
get_layer_id, get_layer_id,
get_stacked_name, get_normalized_lora_weight_names,
get_weight_name, get_weight_name,
) )
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -101,10 +101,13 @@ class LoRAManager:
self.hf_target_names.update(self.configs[name].target_modules) self.hf_target_names.update(self.configs[name].target_modules)
# Target lora weight names for lora_a and lora_b modules respectively. # Target lora weight names for lora_a and lora_b modules respectively.
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")} weights_A: List[str] = []
self.lora_weight_names: Set[Tuple[str]] = set( weights_B: List[str] = []
[get_stacked_name(module) for module in self.hf_target_names] for module in self.hf_target_names:
) lora_A, lora_B = get_normalized_lora_weight_names(module)
weights_A += lora_A
weights_B += lora_B
self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B)
# load all weights to cpu # load all weights to cpu
self.loras: Dict[str, LoRAAdapter] = {} self.loras: Dict[str, LoRAAdapter] = {}
@@ -263,7 +266,18 @@ class LoRAManager:
self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = { self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
i: [] for i in range(self.base_hf_config.num_hidden_layers) i: [] for i in range(self.base_hf_config.num_hidden_layers)
} }
for module_name, module in self.base_model.named_modules(): for module_name, module in self.base_model.named_modules():
# TODO (lifuhuang): in the future, we should consider generalizing the
# should_apply_lora function to support mapping by full module name instead
# of just the last part (e.g., "qkv_proj") to support scenarios with multiple
# attention stacks (e.g., multimodal models).
# See: https://github.com/sgl-project/sglang/issues/6608
if getattr(
self.base_model, "should_apply_lora", None
) and not self.base_model.should_apply_lora(module_name):
continue
# The module should be converted if it is included in target_names # The module should be converted if it is included in target_names
if module_name.split(".")[-1] in customized_target_names: if module_name.split(".")[-1] in customized_target_names:
layer_id = get_layer_id(module_name) layer_id = get_layer_id(module_name)

View File

@@ -91,18 +91,16 @@ class LoRAMemoryPool:
def init_buffers( def init_buffers(
self, self,
lora_weight_names: Set[Tuple[str]], lora_weight_names: Tuple[Set[str]],
base_model: torch.nn.Module, base_model: torch.nn.Module,
): ):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")} # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Set[Tuple[str]] = lora_weight_names self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
device = next(base_model.parameters()).device device = next(base_model.parameters()).device
lora_module_A_names = set([name[0] for name in lora_weight_names])
lora_module_B_names = set([name[1] for name in lora_weight_names])
# Init A tensor, column_major=False # Init A tensor, column_major=False
for module_A in lora_module_A_names: for module_A in lora_weight_names[0]:
lora_A_shape = self.get_lora_A_shape(module_A, base_model) lora_A_shape = self.get_lora_A_shape(module_A, base_model)
self.A_buffer[module_A] = [ self.A_buffer[module_A] = [
torch.empty( torch.empty(
@@ -110,10 +108,10 @@ class LoRAMemoryPool:
dtype=self.dtype, dtype=self.dtype,
device=device, device=device,
) )
for i in range(self.num_layer) for _ in range(self.num_layer)
] ]
# Init B tensor, column_major=True # Init B tensor, column_major=True
for module_B in lora_module_B_names: for module_B in lora_weight_names[1]:
lora_B_shape = self.get_lora_B_shape(module_B, base_model) lora_B_shape = self.get_lora_B_shape(module_B, base_model)
self.B_buffer[module_B] = [ self.B_buffer[module_B] = [
torch.empty( torch.empty(

View File

@@ -1,7 +1,7 @@
import re import re
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional, Set, Tuple from typing import List, Optional, Set, Tuple
import torch import torch
@@ -106,18 +106,22 @@ def get_hidden_dim(
raise NotImplementedError() raise NotImplementedError()
def get_stacked_name(name: str) -> Tuple[str]: def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
""" """
Mapping a target module name to (stacked name for Lora A, stacked name for Lora B) Mapping a target module name to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
""" """
params_mapping = { params_mapping = {
"q_proj": ("qkv_proj", "q_proj"), "q_proj": (["qkv_proj"], ["q_proj"]),
"k_proj": ("qkv_proj", "kv_proj"), "k_proj": (["qkv_proj"], ["kv_proj"]),
"v_proj": ("qkv_proj", "kv_proj"), "v_proj": (["qkv_proj"], ["kv_proj"]),
"gate_proj": ("gate_up_proj", "gate_up_proj"), "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
"up_proj": ("gate_up_proj", "gate_up_proj"), "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
} }
return params_mapping.get(name, (name, name)) stacked = params_mapping.get(name, ([name], [name]))
return stacked
def get_stacked_multiply(module_name: str) -> int: def get_stacked_multiply(module_name: str) -> int:
@@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int:
def get_weight_name( def get_weight_name(
target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
) -> Optional[str]: ) -> Optional[str]:
""" """
target_name is name of a given module, target_name is name of a given module,
@@ -142,9 +146,9 @@ def get_weight_name(
Else raise ValueError. Else raise ValueError.
""" """
idx = 0 if lora_type == LoRAType.LORA_A else 1 idx = 0 if lora_type == LoRAType.LORA_A else 1
for weight_name_pair in lora_weight_names: for weight_name in lora_weight_names[idx]:
if weight_name_pair[idx] in target_name: if weight_name in target_name:
return weight_name_pair[idx] return weight_name
raise ValueError( raise ValueError(
f"Cannot find weight name for {target_name} in {lora_weight_names}" f"Cannot find weight name for {target_name} in {lora_weight_names}"
) )

View File

@@ -17,6 +17,7 @@
import logging import logging
import math import math
import re
from collections.abc import Iterable from collections.abc import Iterable
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
"gate_up_proj": ["gate_proj", "up_proj"], "gate_up_proj": ["gate_proj", "up_proj"],
} }
lora_pattern = re.compile(
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)
def __init__( def __init__(
self, self,
config: PretrainedConfig, config: PretrainedConfig,
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id]) pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
return pattern.pad_input_tokens(input_ids, mm_inputs) return pattern.pad_input_tokens(input_ids, mm_inputs)
def should_apply_lora(self, module_name: str) -> Optional[str]:
return self.lora_pattern.match(module_name)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
# (param_name, shard_name, shard_id) # (param_name, shard_name, shard_id)

View File

@@ -1473,7 +1473,7 @@ class ServerArgs:
self.max_loras_per_batch > 0 self.max_loras_per_batch > 0
# FIXME # FIXME
and (self.lora_paths is None or self.disable_radix_cache) and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress" ), "compatibility of lora and radix attention is in progress"
assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
assert self.gpu_id_step >= 1, "gpu_id_step must be positive" assert self.gpu_id_step >= 1, "gpu_id_step must be positive"