Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)

This commit is contained in:
Lifu Huang
2025-07-13 18:36:01 -07:00
committed by GitHub
parent b5dd5e8741
commit e2ed9d049a
10 changed files with 840 additions and 227 deletions

View File

@@ -16,7 +16,7 @@
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, Set, Tuple
from typing import Dict, Iterable, Optional, Set, Tuple
import torch
@@ -53,6 +53,8 @@ class LoRAManager:
lora_backend: str = "triton",
tp_size: int = 1,
tp_rank: int = 0,
max_lora_rank: Optional[int] = None,
target_modules: Optional[Iterable[str]] = None,
):
self.base_model: torch.nn.Module = base_model
self.base_hf_config: AutoConfig = base_hf_config
@@ -62,6 +64,10 @@ class LoRAManager:
self.device: torch.device = next(self.base_model.parameters()).device
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: Optional[int] = max_lora_rank
self.target_modules: Optional[Set[str]] = (
set(target_modules) if target_modules else None
)
# LoRA backend for running sgemm kernels
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -153,7 +159,9 @@ class LoRAManager:
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
try:
self.configs[lora_name] = LoRAConfig(lora_path)
new_adapter = LoRAConfig(lora_path)
self.validate_new_adapter(lora_name, new_adapter)
self.configs[lora_name] = new_adapter
except Exception as e:
success = False
error_message = (
@@ -168,6 +176,21 @@ class LoRAManager:
error_message=error_message,
)
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
"""
Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
"""
incompatible = self.memory_pool and not self.memory_pool.can_support(
lora_config
)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration."
" We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
"you can specify the expected configs via --max-lora-rank and --lora-target-modules."
)
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
"""
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
@@ -214,7 +237,7 @@ class LoRAManager:
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
if lora_path is not None:
lora = self.loras[lora_path]
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
lora_ranks[weight_indices[i]] = lora.config.r
scalings[weight_indices[i]] = lora.scaling
# Use pinned memory to avoid synchronizations during host-to-device transfer
@@ -319,7 +342,7 @@ class LoRAManager:
)
else:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
)
module.set_lora_info(
self.memory_pool.get_tensor(
@@ -351,58 +374,67 @@ class LoRAManager:
i: {} for i in range(self.base_hf_config.num_hidden_layers)
}
# Initialize memory pool
self.memory_pool = LoRAMemoryPool(
self.base_hf_config,
self.max_loras_per_batch,
self.dtype,
self.tp_size,
self.tp_rank,
)
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
# It is initialized lazily when the first LoRA adapter is loaded.
self.memory_pool: Optional[LoRAMemoryPool] = None
def update_state_from_configs(self):
"""
Update the internal state of the LoRAManager based on the current `self.configs`. This method
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
This includes:
- Initializing LoRA adapters if they are not already loaded.
- Collecting all LoRA weight names based on the currently loaded adapters.
- Lazily monkey-patching the base model to use LoRA layers where applicable.
- Preparing the GPU buffer pool for active LoRA weights.
"""
# Target module names in huggingface lora configs.
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
hf_target_module_names: Set[str] = set()
for config in self.configs.values():
hf_target_module_names.update(config.target_modules)
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
# Loads / unloads LoRA adapters based on the latest configs.
self.update_lora_adapters()
# Apply the latest LoRA configurations to the internal state for inferencing.
self.apply_lora_configs()
# Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
#
# Please note that the following update operations are "monotonic" by design, meaning that we update
# multiple places to support the new weight names when the first adapter targeting such weight names
# is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
# even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
# list of LoRA weight names is expected to be extremely finite and stable.
self.update_lora_weight_names(hf_target_module_names)
self.update_lora_modules(hf_target_module_names)
self.update_memory_buffers(max_lora_dim)
def apply_lora_configs(self):
"""
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
def update_lora_weight_names(self, hf_target_names: Set[str]):
Notes:
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing this work in
early CY25H2.
"""
if self.memory_pool is None:
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
if self.target_modules is None:
self.target_modules = set()
for config in self.configs.values():
self.target_modules.update(config.target_modules)
if self.max_lora_rank is None:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
default=0,
)
self.update_lora_weight_names()
self.update_lora_modules()
self.update_memory_buffers()
else:
# No-op if the memory pool can support the current LoRA configurations.
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
# module is changed once FlashInfer backend is deprecated.
assert self.memory_pool.can_support(self.configs.values()), (
"LoRA memory pool cannot support the current LoRA configuration. "
"This should never happen as we should have validated adapter compatibility. "
"Please create a Github issue to report.",
)
def update_lora_weight_names(self):
"""
Add new LoRA weight names if needed based on the current `self.configs`.
"""
# Target lora weight names for lora_a and lora_b modules respectively.
for module in hf_target_names:
lora_A, lora_B = get_normalized_lora_weight_names(module)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
self.lora_weight_names[0].update(lora_A)
self.lora_weight_names[1].update(lora_B)
def update_lora_adapters(self):
"""
@@ -434,21 +466,23 @@ class LoRAManager:
# Additional checks for flashinfer backend
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
if self.lora_backend == "flashinfer":
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
lora_dims = set(x.r for x in self.configs.values())
scalings = set(x.scaling for x in self.loras.values())
assert (
len(lora_dims) == 1 and len(scalings) == 1
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
def update_memory_buffers(self, max_lora_dim: int):
"""
Update the LoRA memory pool buffers based on the current LoRA configurations and update
LoRA modules to use the new buffers. This method should be called after the LoRA configurations
are set or updated.
"""
self.memory_pool.init_buffers(
self.lora_weight_names, self.base_model, max_lora_dim
def update_memory_buffers(self):
"""(Re)initialize the LoRA memory pool based on the current configurations."""
self.memory_pool = LoRAMemoryPool(
base_hf_config=self.base_hf_config,
max_loras_per_batch=self.max_loras_per_batch,
dtype=self.dtype,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
max_lora_rank=self.max_lora_rank,
lora_weight_names=self.lora_weight_names,
base_model=self.base_model,
)
def set_lora_module(self, module_name, module):
@@ -456,11 +490,11 @@ class LoRAManager:
replace_submodule(self.base_model, module_name, lora_module)
return lora_module
def update_lora_modules(self, hf_target_names: Set[str]):
def update_lora_modules(self):
# Target module names of customized layers defined in python/sglang/srt/layers
# e.g., {"qkv_proj", "o_proj"}
customized_target_names = get_customized_names_from_hf_names(
hf_target_names, self.base_model
self.target_modules, self.base_model
)
for module_name, module in self.base_model.named_modules():

View File

@@ -1,4 +1,4 @@
from typing import Callable, Dict, List, Optional, Set, Tuple
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
get_hidden_dim,
get_normalized_lora_weight_names,
get_stacked_multiply,
get_weight_name,
)
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
dtype: torch.dtype,
tp_size: int,
tp_rank: int,
max_lora_rank: int,
lora_weight_names: Tuple[Set[str], Set[str]],
base_model: torch.nn.Module,
):
self.base_hf_config: AutoConfig = base_hf_config
self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
self.dtype: torch.dtype = dtype
self.tp_size: int = tp_size
self.tp_rank: int = tp_rank
self.max_lora_rank: int = max_lora_rank
# lora weight names for LoRA A and B respectively.
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
# A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
self.init_buffers(base_model)
def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
"""
Check if the memory pool can support the given LoRA adapters.
"""
def _can_support(config: LoRAConfig) -> bool:
"""
Check if the memory pool can support a single LoRA adapter.
"""
if config.r > self.max_lora_rank:
return False
weights_a, weights_b = get_normalized_lora_weight_names(
config.target_modules
)
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
self.lora_weight_names[1]
)
if isinstance(config, LoRAConfig):
return _can_support(config)
else:
return all(_can_support(x) for x in config)
def get_lora_A_shape(
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
) -> Tuple[int]:
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
max_lora_dim,
)
def init_buffers(
self,
lora_weight_names: Tuple[Set[str]],
base_model: torch.nn.Module,
max_lora_dim: int,
):
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
def init_buffers(self, base_model: torch.nn.Module):
device = next(base_model.parameters()).device
def update_buffer(
def init_buffer(
buffer: Dict[str, List[torch.Tensor]],
lora_weight_names: Set[str],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
):
new_weight_names = lora_weight_names - buffer.keys()
for module_name in new_weight_names:
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
for module_name in lora_weight_names:
lora_shape = get_lora_shape_fn(
module_name, base_model, self.max_lora_rank
)
buffer[module_name] = [
torch.empty(
lora_shape,
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
for _ in range(self.num_layer)
]
update_buffer(
init_buffer(
self.A_buffer,
lora_weight_names[0],
self.lora_weight_names[0],
self.get_lora_A_shape,
)
update_buffer(
init_buffer(
self.B_buffer,
lora_weight_names[1],
self.lora_weight_names[1],
self.get_lora_B_shape,
)

View File

@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Set, Tuple
from typing import Iterable, Optional, Set, Tuple
import torch
@@ -106,9 +106,11 @@ def get_hidden_dim(
raise NotImplementedError()
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
def get_normalized_lora_weight_names(
target_modules: Iterable[str],
) -> Tuple[set[str], set[str]]:
"""
Mapping a target module name to names of the normalized LoRA weights.
Map a list of target module names to names of the normalized LoRA weights.
Returned tuple contains (name for Lora A, name for Lora B)
"""
params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
}
stacked = params_mapping.get(name, ([name], [name]))
return stacked
result = (set(), set())
for name in target_modules:
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
result[0].update(lora_a)
result[1].update(lora_b)
return result
def get_stacked_multiply(module_name: str) -> int:

View File

@@ -891,6 +891,8 @@ class ModelRunner:
lora_backend=self.server_args.lora_backend,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
max_lora_rank=self.server_args.max_lora_rank,
target_modules=self.server_args.lora_target_modules,
)
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
if result.success:

View File

@@ -134,6 +134,8 @@ class ServerArgs:
preferred_sampling_params: Optional[str] = None
# LoRA
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List[str]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"
@@ -1129,6 +1131,28 @@ class ServerArgs:
)
# LoRA
parser.add_argument(
"--max-lora-rank",
default=ServerArgs.max_lora_rank,
type=int,
help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
)
parser.add_argument(
"--lora-target-modules",
type=str,
choices=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
nargs="*",
default=None,
help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
)
parser.add_argument(
"--lora-paths",
type=str,

View File

@@ -505,6 +505,8 @@ class SRTRunner:
torchao_config: Optional[str] = None,
cuda_graph_max_bs: int = 4,
sleep_on_idle=False,
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
):
self.model_type = model_type
self.is_generation = model_type == "generation"
@@ -543,6 +545,8 @@ class SRTRunner:
cuda_graph_max_bs=cuda_graph_max_bs,
disable_custom_all_reduce=disable_custom_all_reduce,
sleep_on_idle=sleep_on_idle,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
**spec_kwargs,
)