From 477a101cbda0ecab08ab3a820cd7df16038ec214 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Mon, 26 May 2025 21:51:54 -0700 Subject: [PATCH] Refactor LoRA handling to support adapter tensors in fused format (#6585) --- python/sglang/srt/lora/lora.py | 41 ++++++++++++++++++++++---- python/sglang/srt/lora/lora_manager.py | 24 +++++++++++---- python/sglang/srt/lora/mem_pool.py | 12 ++++---- python/sglang/srt/lora/utils.py | 30 +++++++++++-------- python/sglang/srt/models/phi4mm.py | 8 +++++ python/sglang/srt/server_args.py | 2 +- 6 files changed, 86 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index b0db40d6a..38cc08f71 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -92,11 +92,12 @@ class LoRAAdapter(nn.Module): for i in range(self.base_hf_config.num_hidden_layers): layer = self.layers[i] weight_names = [name for name, _ in layer.weights.items()] - self.stack_qkv_proj(weight_names, layer.weights) - self.stack_gate_up_proj(weight_names, layer.weights) - - def stack_qkv_proj(self, weight_names: List[str], weights: Dict[str, torch.Tensor]): + self.normalize_qkv_proj(weight_names, layer.weights) + self.normalize_gate_up_proj(weight_names, layer.weights) + def normalize_qkv_proj( + self, weight_names: List[str], weights: Dict[str, torch.Tensor] + ): # Collect target q/k/v modules. This process is necessary since there might be no lora attached to k_proj target_module = set() for weight_name in weight_names: @@ -106,6 +107,8 @@ class LoRAAdapter(nn.Module): target_module.add("q_proj") if "v_proj" in weight_name: target_module.add("v_proj") + if "qkv_proj" in weight_name: + target_module.add("qkv_proj") if len(target_module) == 0: return @@ -148,8 +151,30 @@ class LoRAAdapter(nn.Module): if "k_proj" in target_module: weights.pop(k_name) weights.pop(v_name) + elif "qkv_proj" in weight_name: + # If qkv_proj is already stacked, we normalize it following the SGL convention. 
+ qkv_name = weight_name + q_name = weight_name.replace("qkv_proj", "q_proj") + k_name = weight_name.replace("qkv_proj", "k_proj") + v_name = weight_name.replace("qkv_proj", "v_proj") + kv_name = weight_name.replace("qkv_proj", "kv_proj") + if "lora_A" in weight_name: + weights[qkv_name] = weights[qkv_name].repeat(3, 1) + else: + head_size = ( + self.base_hf_config.hidden_size + // self.base_hf_config.num_attention_heads + ) + weights[q_name], weights[kv_name] = torch.split( + weights[qkv_name], + [ + head_size * self.base_hf_config.num_attention_heads, + head_size * self.base_hf_config.num_key_value_heads * 2, + ], + dim=0, + ) - def stack_gate_up_proj( + def normalize_gate_up_proj( self, weight_names: List[str], weights: Dict[str, torch.Tensor] ): for weight_name in weight_names: @@ -179,3 +204,9 @@ class LoRAAdapter(nn.Module): weights.pop(weight_name) if up_name in weights: weights.pop(up_name) + elif "gate_up_proj" in weight_name: + # If gate_up_proj is already stacked, we normalize it following the SGL convention + gate_up_name = weight_name + if "lora_A" in weight_name: + weights[gate_up_name] = weights[gate_up_name].repeat(2, 1) + # else: "lora_B" is already stacked, no operations is needed. diff --git a/python/sglang/srt/lora/lora_manager.py b/python/sglang/srt/lora/lora_manager.py index 68b2a3621..45050df53 100644 --- a/python/sglang/srt/lora/lora_manager.py +++ b/python/sglang/srt/lora/lora_manager.py @@ -32,7 +32,7 @@ from sglang.srt.lora.utils import ( LoRAType, get_customized_names_from_hf_names, get_layer_id, - get_stacked_name, + get_normalized_lora_weight_names, get_weight_name, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -101,10 +101,13 @@ class LoRAManager: self.hf_target_names.update(self.configs[name].target_modules) # Target lora weight names for lora_a and lora_b modules respectively. 
- # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")} - self.lora_weight_names: Set[Tuple[str]] = set( - [get_stacked_name(module) for module in self.hf_target_names] - ) + weights_A: List[str] = [] + weights_B: List[str] = [] + for module in self.hf_target_names: + lora_A, lora_B = get_normalized_lora_weight_names(module) + weights_A += lora_A + weights_B += lora_B + self.lora_weight_names: Tuple[Set[str]] = set(weights_A), set(weights_B) # load all weights to cpu self.loras: Dict[str, LoRAAdapter] = {} @@ -263,7 +266,18 @@ class LoRAManager: self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = { i: [] for i in range(self.base_hf_config.num_hidden_layers) } + for module_name, module in self.base_model.named_modules(): + # TODO (lifuhuang): in the future, we should consider generalizing the + # should_apply_lora function to support mapping by full module name instead + # of just the last part (e.g., "qkv_proj") to support scenarios with multiple + # attention stacks (e.g., multimodal models). 
+ # See: https://github.com/sgl-project/sglang/issues/6608 + if getattr( + self.base_model, "should_apply_lora", None + ) and not self.base_model.should_apply_lora(module_name): + continue + # The module should be converted if it is included in target_names if module_name.split(".")[-1] in customized_target_names: layer_id = get_layer_id(module_name) diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py index 71495acca..6db4e14f3 100644 --- a/python/sglang/srt/lora/mem_pool.py +++ b/python/sglang/srt/lora/mem_pool.py @@ -91,18 +91,16 @@ class LoRAMemoryPool: def init_buffers( self, - lora_weight_names: Set[Tuple[str]], + lora_weight_names: Tuple[Set[str]], base_model: torch.nn.Module, ): # lora_weight_names is a set of name pairs indicating each pair of lora modules to load # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")} - self.lora_weight_names: Set[Tuple[str]] = lora_weight_names + self.lora_weight_names: Tuple[Set[str]] = lora_weight_names device = next(base_model.parameters()).device - lora_module_A_names = set([name[0] for name in lora_weight_names]) - lora_module_B_names = set([name[1] for name in lora_weight_names]) # Init A tensor, column_major=False - for module_A in lora_module_A_names: + for module_A in lora_weight_names[0]: lora_A_shape = self.get_lora_A_shape(module_A, base_model) self.A_buffer[module_A] = [ torch.empty( @@ -110,10 +108,10 @@ class LoRAMemoryPool: dtype=self.dtype, device=device, ) - for i in range(self.num_layer) + for _ in range(self.num_layer) ] # Init B tensor, column_major=True - for module_B in lora_module_B_names: + for module_B in lora_weight_names[1]: lora_B_shape = self.get_lora_B_shape(module_B, base_model) self.B_buffer[module_B] = [ torch.empty( diff --git a/python/sglang/srt/lora/utils.py b/python/sglang/srt/lora/utils.py index 3f1f3558d..445b89604 100644 --- a/python/sglang/srt/lora/utils.py +++ b/python/sglang/srt/lora/utils.py @@ -1,7 +1,7 @@ import re from 
dataclasses import dataclass from enum import Enum -from typing import Optional, Set, Tuple +from typing import List, Optional, Set, Tuple import torch @@ -106,18 +106,22 @@ def get_hidden_dim( raise NotImplementedError() -def get_stacked_name(name: str) -> Tuple[str]: +def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]: """ - Mapping a target module name to (stacked name for Lora A, stacked name for Lora B) + Mapping a target module name to names of the normalized LoRA weights. + Returned tuple contains (name for Lora A, name for Lora B) """ params_mapping = { - "q_proj": ("qkv_proj", "q_proj"), - "k_proj": ("qkv_proj", "kv_proj"), - "v_proj": ("qkv_proj", "kv_proj"), - "gate_proj": ("gate_up_proj", "gate_up_proj"), - "up_proj": ("gate_up_proj", "gate_up_proj"), + "q_proj": (["qkv_proj"], ["q_proj"]), + "k_proj": (["qkv_proj"], ["kv_proj"]), + "v_proj": (["qkv_proj"], ["kv_proj"]), + "gate_proj": (["gate_up_proj"], ["gate_up_proj"]), + "up_proj": (["gate_up_proj"], ["gate_up_proj"]), + "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]), + "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]), } - return params_mapping.get(name, (name, name)) + stacked = params_mapping.get(name, ([name], [name])) + return stacked def get_stacked_multiply(module_name: str) -> int: @@ -133,7 +137,7 @@ def get_stacked_multiply(module_name: str) -> int: def get_weight_name( - target_name: str, lora_weight_names: Set[Tuple[str]], lora_type: LoRAType + target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType ) -> Optional[str]: """ target_name is name of a given module, @@ -142,9 +146,9 @@ def get_weight_name( Else raise ValueError. 
""" idx = 0 if lora_type == LoRAType.LORA_A else 1 - for weight_name_pair in lora_weight_names: - if weight_name_pair[idx] in target_name: - return weight_name_pair[idx] + for weight_name in lora_weight_names[idx]: + if weight_name in target_name: + return weight_name raise ValueError( f"Cannot find weight name for {target_name} in {lora_weight_names}" ) diff --git a/python/sglang/srt/models/phi4mm.py b/python/sglang/srt/models/phi4mm.py index 5b57ef203..a574dc27c 100644 --- a/python/sglang/srt/models/phi4mm.py +++ b/python/sglang/srt/models/phi4mm.py @@ -17,6 +17,7 @@ import logging import math +import re from collections.abc import Iterable from typing import List, Optional, Tuple @@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module): "gate_up_proj": ["gate_proj", "up_proj"], } + lora_pattern = re.compile( + r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)" + ) + def __init__( self, config: PretrainedConfig, @@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module): pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id]) return pattern.pad_input_tokens(input_ids, mm_inputs) + def should_apply_lora(self, module_name: str) -> Optional[str]: + return self.lora_pattern.match(module_name) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9dba0bb97..c38b69c09 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1473,7 +1473,7 @@ class ServerArgs: self.max_loras_per_batch > 0 # FIXME and (self.lora_paths is None or self.disable_radix_cache) - ), "compatibility of lora and cuda graph and radix attention is in progress" + ), "compatibility of lora and radix attention is in progress" assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative" assert self.gpu_id_step >= 1, "gpu_id_step must be 
positive"