Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, Set, Tuple
+from typing import Dict, Iterable, Optional, Set, Tuple

 import torch

@@ -53,6 +53,8 @@ class LoRAManager:
         lora_backend: str = "triton",
         tp_size: int = 1,
         tp_rank: int = 0,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,6 +64,10 @@ class LoRAManager:
         self.device: torch.device = next(self.base_model.parameters()).device
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: Optional[int] = max_lora_rank
+        self.target_modules: Optional[Set[str]] = (
+            set(target_modules) if target_modules else None
+        )

         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -153,7 +159,9 @@ class LoRAManager:
             error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."

         try:
-            self.configs[lora_name] = LoRAConfig(lora_path)
+            new_adapter = LoRAConfig(lora_path)
+            self.validate_new_adapter(lora_name, new_adapter)
+            self.configs[lora_name] = new_adapter
         except Exception as e:
             success = False
             error_message = (
@@ -168,6 +176,21 @@ class LoRAManager:
             error_message=error_message,
         )

+    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+        """
+        Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
+        """
+
+        incompatible = self.memory_pool and not self.memory_pool.can_support(
+            lora_config
+        )
+        if incompatible:
+            raise ValueError(
+                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                "We are still working on supporting dynamic updates to LoRA shapes. If you expect to use adapters of different shapes, "
+                "you can specify the expected configs via --max-lora-rank and --lora-target-modules."
+            )
+
     def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
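As a side note, the `self.memory_pool and not self.memory_pool.can_support(...)` expression above is deliberately falsy while the pool is still uninitialized, so the very first adapter always loads and drives pool creation. A minimal standalone sketch of that gate (the pool and config objects below are hypothetical stand-ins, not the real classes):

from types import SimpleNamespace

# Hypothetical stand-ins: only the attributes used by the gate are modeled.
new_cfg = SimpleNamespace(r=16, target_modules={"q_proj", "v_proj"})

memory_pool = None  # pool not initialized yet -> gate is falsy, adapter is accepted
assert not (memory_pool and not memory_pool.can_support(new_cfg))

class FakePool:  # pretend the pool was sized for max_lora_rank=8
    def can_support(self, cfg):
        return cfg.r <= 8

memory_pool = FakePool()
# rank 16 exceeds the pool's capacity -> validate_new_adapter would raise ValueError
assert memory_pool and not memory_pool.can_support(new_cfg)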
@@ -214,7 +237,7 @@ class LoRAManager:
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
             if lora_path is not None:
                 lora = self.loras[lora_path]
-                lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                lora_ranks[weight_indices[i]] = lora.config.r
                 scalings[weight_indices[i]] = lora.scaling

         # Use pinned memory to avoid synchronizations during host-to-device transfer
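The pinned-memory comment above refers to the standard PyTorch pattern of staging host tensors in page-locked memory so the host-to-device copy can be issued asynchronously; a minimal, self-contained sketch (values are placeholders):

import torch

# Page-locked host staging buffer; the copy below can then overlap with other work.
use_cuda = torch.cuda.is_available()
lora_ranks_host = torch.tensor([0, 0, 8, 0], dtype=torch.int64, pin_memory=use_cuda)
if use_cuda:
    lora_ranks_gpu = lora_ranks_host.to("cuda", non_blocking=True)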
@@ -319,7 +342,7 @@ class LoRAManager:
                 )
             else:
                 weight_name = get_weight_name(
-                    module_name, self.lora_weight_names, LoRAType.LORA_A
+                    module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                 )
                 module.set_lora_info(
                     self.memory_pool.get_tensor(
@@ -351,58 +374,67 @@ class LoRAManager:
             i: {} for i in range(self.base_hf_config.num_hidden_layers)
         }

-        # Initialize memory pool
-        self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
-            self.max_loras_per_batch,
-            self.dtype,
-            self.tp_size,
-            self.tp_rank,
-        )
+        # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
+        # It is initialized lazily when the first LoRA adapter is loaded.
+        self.memory_pool: Optional[LoRAMemoryPool] = None

     def update_state_from_configs(self):
         """
         Update the internal state of the LoRAManager based on the current `self.configs`. This method
         should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).

         This includes:
         - Initializing LoRA adapters if they are not already loaded.
         - Collecting all LoRA weight names based on the currently loaded adapters.
         - Lazily monkey-patching the base model to use LoRA layers where applicable.
         - Preparing the GPU buffer pool for active LoRA weights.
         """

-        # Target module names in huggingface lora configs.
-        # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-        hf_target_module_names: Set[str] = set()
-        for config in self.configs.values():
-            hf_target_module_names.update(config.target_modules)
-        max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
         # Loads / unloads LoRA adapters based on the latest configs.
         self.update_lora_adapters()
+        # Apply the latest LoRA configurations to the internal state for inferencing.
+        self.apply_lora_configs()

-        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
-        #
-        # Please note that the following update operations are "monotonic" by design, meaning that we update
-        # multiple places to support the new weight names when the first adapter targeting such weight names
-        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-        # list of LoRA weight names is expected to be extremely finite and stable.
-        self.update_lora_weight_names(hf_target_module_names)
-        self.update_lora_modules(hf_target_module_names)
-        self.update_memory_buffers(max_lora_dim)
-
-    def update_lora_weight_names(self, hf_target_names: Set[str]):
+    def apply_lora_configs(self):
+        """
+        Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
+
+        Notes:
+        - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager,
+          as we do not yet support dynamically updating adapter shape configs, which depends on (1) FlashInfer
+          LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completion of this
+          work in early CY25H2.
+        """
+
+        if self.memory_pool is None:
+            # Infer max_lora_rank and target_modules if not explicitly specified in server args.
+            if self.target_modules is None:
+                self.target_modules = set()
+                for config in self.configs.values():
+                    self.target_modules.update(config.target_modules)
+
+            if self.max_lora_rank is None:
+                self.max_lora_rank = max(
+                    [x.hf_config["r"] for x in self.configs.values()],
+                    default=0,
+                )
+
+            self.update_lora_weight_names()
+            self.update_lora_modules()
+            self.update_memory_buffers()
+        else:
+            # No-op if the memory pool can support the current LoRA configurations.
+            # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
+            # module is changed once the FlashInfer backend is deprecated.
+            assert self.memory_pool.can_support(self.configs.values()), (
+                "LoRA memory pool cannot support the current LoRA configuration. "
+                "This should never happen as we should have validated adapter compatibility. "
+                "Please create a GitHub issue to report it."
+            )
+
+    def update_lora_weight_names(self):
         """
         Add new LoRA weight names if needed based on the current `self.configs`.
         """

         # Target lora weight names for lora_a and lora_b modules respectively.
-        for module in hf_target_names:
-            lora_A, lora_B = get_normalized_lora_weight_names(module)
-            self.lora_weight_names[0].update(lora_A)
-            self.lora_weight_names[1].update(lora_B)
+        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
+        self.lora_weight_names[0].update(lora_A)
+        self.lora_weight_names[1].update(lora_B)

     def update_lora_adapters(self):
         """
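To make the lazy-initialization branch above concrete, here is a small standalone sketch (adapter configs are invented stand-ins) of how `max_lora_rank` and `target_modules` are inferred from the currently loaded configs on first use:

from types import SimpleNamespace

# Hypothetical adapter configs; only the fields read by the inference step are modeled.
configs = {
    "adapter_a": SimpleNamespace(hf_config={"r": 8}, target_modules=["q_proj", "v_proj"]),
    "adapter_b": SimpleNamespace(hf_config={"r": 16}, target_modules=["o_proj"]),
}

target_modules = set()
for config in configs.values():
    target_modules.update(config.target_modules)

max_lora_rank = max([x.hf_config["r"] for x in configs.values()], default=0)

print(sorted(target_modules))  # ['o_proj', 'q_proj', 'v_proj']
print(max_lora_rank)           # 16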
@@ -434,21 +466,23 @@ class LoRAManager:
         # Additional checks for flashinfer backend
         # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
         if self.lora_backend == "flashinfer":
-            lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+            lora_dims = set(x.r for x in self.configs.values())
             scalings = set(x.scaling for x in self.loras.values())
             assert (
                 len(lora_dims) == 1 and len(scalings) == 1
             ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters."

-    def update_memory_buffers(self, max_lora_dim: int):
-        """
-        Update the LoRA memory pool buffers based on the current LoRA configurations and update
-        LoRA modules to use the new buffers. This method should be called after the LoRA configurations
-        are set or updated.
-        """
-
-        self.memory_pool.init_buffers(
-            self.lora_weight_names, self.base_model, max_lora_dim
+    def update_memory_buffers(self):
+        """(Re)initialize the LoRA memory pool based on the current configurations."""
+        self.memory_pool = LoRAMemoryPool(
+            base_hf_config=self.base_hf_config,
+            max_loras_per_batch=self.max_loras_per_batch,
+            dtype=self.dtype,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            max_lora_rank=self.max_lora_rank,
+            lora_weight_names=self.lora_weight_names,
+            base_model=self.base_model,
         )

     def set_lora_module(self, module_name, module):
@@ -456,11 +490,11 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def update_lora_modules(self, hf_target_names: Set[str]):
+    def update_lora_modules(self):
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-            hf_target_names, self.base_model
+            self.target_modules, self.base_model
         )

         for module_name, module in self.base_model.named_modules():

@@ -1,4 +1,4 @@
-from typing import Callable, Dict, List, Optional, Set, Tuple
+from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

 import torch

@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
+from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
     get_hidden_dim,
+    get_normalized_lora_weight_names,
     get_stacked_multiply,
     get_weight_name,
 )
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
         dtype: torch.dtype,
         tp_size: int,
         tp_rank: int,
+        max_lora_rank: int,
+        lora_weight_names: Tuple[Set[str], Set[str]],
+        base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
         self.num_layer: int = base_hf_config.num_hidden_layers
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
         self.dtype: torch.dtype = dtype
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
+        self.max_lora_rank: int = max_lora_rank
+
+        # lora weight names for LoRA A and B respectively.
+        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names

         # Both A_buffer and B_buffer map lora weight names to their buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
         # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

+        self.init_buffers(base_model)
+
+    def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
+        """
+        Check if the memory pool can support the given LoRA adapters.
+        """
+
+        def _can_support(config: LoRAConfig) -> bool:
+            """
+            Check if the memory pool can support a single LoRA adapter.
+            """
+            if config.r > self.max_lora_rank:
+                return False
+            weights_a, weights_b = get_normalized_lora_weight_names(
+                config.target_modules
+            )
+            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
+                self.lora_weight_names[1]
+            )
+
+        if isinstance(config, LoRAConfig):
+            return _can_support(config)
+        else:
+            return all(_can_support(x) for x in config)
+
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
     ) -> Tuple[int]:
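In other words, `can_support` reduces to two checks per adapter: a rank bound and a subset test on the normalized LoRA A/B weight-name sets. A standalone sketch with invented pool parameters (the real method additionally accepts an iterable of configs, as shown above):

# Hypothetical pool shape: rank capacity and the weight names its buffers were allocated for.
pool_max_rank = 16
pool_weight_names = ({"qkv_proj", "o_proj"}, {"q_proj", "kv_proj", "o_proj"})

def can_support(rank, weights_a, weights_b):
    if rank > pool_max_rank:
        return False
    return weights_a.issubset(pool_weight_names[0]) and weights_b.issubset(pool_weight_names[1])

# Adapter whose targets normalize to the stacked qkv_proj buffers: supported.
print(can_support(8, {"qkv_proj"}, {"q_proj", "kv_proj"}))    # True
# Adapter that needs gate_up_proj buffers the pool never allocated: rejected.
print(can_support(8, {"gate_up_proj"}, {"gate_up_proj"}))     # False
# Rank larger than the pool was sized for: rejected.
print(can_support(32, {"qkv_proj"}, {"q_proj", "kv_proj"}))   # False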
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
             max_lora_dim,
         )

-    def init_buffers(
-        self,
-        lora_weight_names: Tuple[Set[str]],
-        base_model: torch.nn.Module,
-        max_lora_dim: int,
-    ):
-        # lora_weight_names is a set of name pairs indicating each pair of lora modules to load
-        # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
-        self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
+    def init_buffers(self, base_model: torch.nn.Module):
         device = next(base_model.parameters()).device

-        def update_buffer(
+        def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             lora_weight_names: Set[str],
             get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
         ):
-            new_weight_names = lora_weight_names - buffer.keys()
-            for module_name in new_weight_names:
-                lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
+            for module_name in lora_weight_names:
+                lora_shape = get_lora_shape_fn(
+                    module_name, base_model, self.max_lora_rank
+                )
                 buffer[module_name] = [
                     torch.empty(
                         lora_shape,
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
                 for _ in range(self.num_layer)
             ]

-        update_buffer(
+        init_buffer(
             self.A_buffer,
-            lora_weight_names[0],
+            self.lora_weight_names[0],
             self.get_lora_A_shape,
         )

-        update_buffer(
+        init_buffer(
             self.B_buffer,
-            lora_weight_names[1],
+            self.lora_weight_names[1],
             self.get_lora_B_shape,
         )

@@ -1,7 +1,7 @@
 import re
 from dataclasses import dataclass
 from enum import Enum
-from typing import List, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple

 import torch

@@ -106,9 +106,11 @@ def get_hidden_dim(
         raise NotImplementedError()


-def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
+def get_normalized_lora_weight_names(
+    target_modules: Iterable[str],
+) -> Tuple[set[str], set[str]]:
     """
-    Mapping a target module name to names of the normalized LoRA weights.
+    Map a list of target module names to the names of the normalized LoRA weights.
     Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
         "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
         "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
     }
-    stacked = params_mapping.get(name, ([name], [name]))
-    return stacked
+
+    result = (set(), set())
+    for name in target_modules:
+        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
+        result[0].update(lora_a)
+        result[1].update(lora_b)
+    return result


 def get_stacked_multiply(module_name: str) -> int:
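For example, given the portion of `params_mapping` visible in this hunk (the stacked `qkv_proj` and `gate_up_proj` entries, with unmapped names falling through unchanged), the reworked function would behave roughly as follows:

# Illustrative usage, assuming only the mapping entries shown above:
lora_a, lora_b = get_normalized_lora_weight_names(["qkv_proj", "gate_up_proj"])
assert lora_a == {"qkv_proj", "gate_up_proj"}
assert lora_b == {"q_proj", "kv_proj", "gate_up_proj"}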
@@ -891,6 +891,8 @@ class ModelRunner:
             lora_backend=self.server_args.lora_backend,
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
+            max_lora_rank=self.server_args.max_lora_rank,
+            target_modules=self.server_args.lora_target_modules,
         )
         result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         if result.success:
@@ -134,6 +134,8 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None

     # LoRA
+    max_lora_rank: Optional[int] = None
+    lora_target_modules: Optional[List[str]] = None
     lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
@@ -1129,6 +1131,28 @@ class ServerArgs:
         )

         # LoRA
+        parser.add_argument(
+            "--max-lora-rank",
+            default=ServerArgs.max_lora_rank,
+            type=int,
+            help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
+        )
+        parser.add_argument(
+            "--lora-target-modules",
+            type=str,
+            choices=[
+                "q_proj",
+                "k_proj",
+                "v_proj",
+                "o_proj",
+                "gate_proj",
+                "up_proj",
+                "down_proj",
+            ],
+            nargs="*",
+            default=None,
+            help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
+        )
         parser.add_argument(
             "--lora-paths",
             type=str,
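A minimal, self-contained reproduction of just these two new arguments (parser construction elided from the real code), showing what a typical invocation parses to:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--max-lora-rank", default=None, type=int)
parser.add_argument(
    "--lora-target-modules",
    type=str,
    choices=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    nargs="*",
    default=None,
)

args = parser.parse_args(["--max-lora-rank", "64", "--lora-target-modules", "q_proj", "v_proj"])
print(args.max_lora_rank)        # 64
print(args.lora_target_modules)  # ['q_proj', 'v_proj']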
@@ -505,6 +505,8 @@ class SRTRunner:
         torchao_config: Optional[str] = None,
         cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
@@ -543,6 +545,8 @@ class SRTRunner:
             cuda_graph_max_bs=cuda_graph_max_bs,
             disable_custom_all_reduce=disable_custom_all_reduce,
             sleep_on_idle=sleep_on_idle,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
             **spec_kwargs,
         )
