# File: sglang/python/sglang/srt/lora/lora_manager.py (502 lines, 20 KiB, Python)

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
# and "Punica: Multi-Tenant LoRA Serving"
import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple
import torch
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.mem_pool import LoRAMemoryPool
from sglang.srt.lora.utils import (
LoRABatchInfo,
LoRAType,
get_layer_id,
get_normalized_lora_weight_names,
get_weight_name,
)
from sglang.srt.managers.io_struct import LoRAUpdateResult
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import replace_submodule
# Module-level logger for this file.
logger: logging.Logger = logging.getLogger(__name__)
class LoRAManager:
    """Manages LoRA adapters served on top of a single base model.

    Responsibilities visible in this class:

    * keeps, per LoRA ID, the adapter config (``self.configs``), the adapter
      weights cached in CPU memory (``self.loras``) and the adapter's
      ``LoRARef`` metadata (``self.lora_refs``);
    * replaces targeted submodules of ``base_model`` with LoRA-enabled layers
      (tracked per layer in ``self.lora_modules``);
    * owns the ``LoRAMemoryPool`` that the adapters used by a batch are loaded
      into, and points every LoRA layer at the pool's tensors;
    * builds the per-batch ``LoRABatchInfo`` handed to the kernel backend.
    """

    def __init__(
        self,
        base_model: torch.nn.Module,
        base_hf_config: AutoConfig,
        max_loras_per_batch: int,
        load_config: LoadConfig,
        dtype: torch.dtype,
        lora_backend: str = "triton",
        tp_size: int = 1,
        tp_rank: int = 0,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[Dict[str, LoRARef]] = None,
    ):
        """
        Create the manager and initialize its mutable LoRA state.

        Args:
            base_model: Model whose targeted submodules are replaced with
                LoRA-enabled layers.
            base_hf_config: HF config of the base model; its
                ``num_hidden_layers`` sizes the per-layer LoRA module table.
            max_loras_per_batch: Maximum number of distinct LoRA IDs in one
                batch; also passed to the memory pool as its capacity.
            load_config: Forwarded to ``LoRAAdapter`` when loading weights.
            dtype: dtype forwarded to the LoRA memory pool.
            lora_backend: Name of the LoRA kernel backend (default "triton").
            tp_size: Tensor-parallel world size, forwarded to the memory pool.
            tp_rank: Tensor-parallel rank, forwarded to the memory pool.
            max_lora_rank: Maximum LoRA rank; inferred from ``lora_paths``
                when None.
            target_modules: LoRA target module names; inferred from
                ``lora_paths`` when None.
            lora_paths: Adapters to load at startup. Only the ``LoRARef``
                values are used.
        """
        self.base_model: torch.nn.Module = base_model
        self.base_hf_config: AutoConfig = base_hf_config
        self.max_loras_per_batch: int = max_loras_per_batch
        self.load_config: LoadConfig = load_config
        self.dtype: torch.dtype = dtype
        # Device of the base model's first parameter; used for the per-batch
        # tensors allocated in `prepare_lora_batch`.
        self.device: torch.device = next(self.base_model.parameters()).device
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
        backend_type = get_backend_from_name(lora_backend)
        self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

        # Initialize mutable internal state of the LoRAManager.
        self.init_state(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
            lora_paths=lora_paths,
        )

    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
        """
        Pre-allocate a persistent ``LoRABatchInfo`` for CUDA-graph batches.

        Every request in a CUDA-graph batch has a segment length of 1, so
        ``seg_lens`` (all ones) and ``seg_indptr`` (their prefix sum) are
        filled once here. `prepare_lora_batch` later overwrites only
        ``weight_indices``, ``lora_ranks``, ``scalings``, ``bs`` and
        ``max_len`` in place.

        Args:
            max_bs_in_cuda_graph: Largest batch size captured by CUDA graphs.
        """
        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
        with torch.device("cuda"):
            self.cuda_graph_batch_info = LoRABatchInfo(
                bs=self.max_bs_in_cuda_graph,
                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
                seg_indptr=torch.zeros(
                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
                ),
                max_len=1,
                weight_indices=torch.zeros(
                    self.max_bs_in_cuda_graph, dtype=torch.int32
                ),
                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
            )

            # Initialize seg_lens and seg_indptr for CUDA graph as they remain
            # constant across batches.
            self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph].fill_(1)
            torch.cumsum(
                self.cuda_graph_batch_info.seg_lens[: self.max_bs_in_cuda_graph],
                dim=0,
                out=self.cuda_graph_batch_info.seg_indptr[
                    1 : self.max_bs_in_cuda_graph + 1
                ],
            )

    def create_lora_update_result(
        self, success: bool, error_message: str = ""
    ) -> LoRAUpdateResult:
        """
        Build a ``LoRAUpdateResult`` for a load/unload request.

        Args:
            success: Whether the operation succeeded.
            error_message: Error description; empty on success.

        Returns:
            The result, including a ``lora_name -> lora_path`` mapping of all
            adapters currently registered in ``self.lora_refs``.
        """
        return LoRAUpdateResult(
            success=success,
            error_message=error_message,
            loaded_adapters={
                lora_ref.lora_name: lora_ref.lora_path
                for lora_ref in self.lora_refs.values()
            },
        )

    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Load a single LoRA adapter from the specified path.

        Any exception raised while reading the config, validating it, or
        loading the weights is reported as a failed ``LoRAUpdateResult``. In
        that case all partially registered state for this ID is rolled back.

        Args:
            lora_ref (LoRARef): The LoRARef object containing the LoRA name,
                path, and ID.

        Returns:
            LoRAUpdateResult describing success/failure and the loaded adapters.
        """
        assert (
            lora_ref.lora_name is not None and lora_ref.lora_path is not None
        ), "LoRARef must have both lora_name and lora_path set for loading."
        assert (
            lora_ref.lora_id not in self.loras
        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

        try:
            # load configs
            new_adapter = LoRAConfig(lora_ref.lora_path)
            self.validate_new_adapter(new_adapter, lora_ref)
            self.configs[lora_ref.lora_id] = new_adapter

            # load weights
            self.load_lora_weights(lora_ref)

            # keep metadata for displayed messages
            self.lora_refs[lora_ref.lora_id] = lora_ref
            self.num_pinned_loras += int(lora_ref.pinned)
        except Exception as e:
            # Undo partial registration so a failed load does not leave a
            # stale config/weights entry behind. The ID was asserted absent
            # from `self.loras` above, and `self.lora_refs` is only populated
            # after `self.loras`, so nothing pre-existing is removed here.
            self.configs.pop(lora_ref.lora_id, None)
            self.loras.pop(lora_ref.lora_id, None)
            self.lora_refs.pop(lora_ref.lora_id, None)
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
        """
        Validate that an adapter can be loaded into the current LoRA memory
        pool.

        Raises:
            ValueError: If a memory pool exists and cannot support the
                adapter's shape, or if pinning this adapter would leave fewer
                than one unpinned slot (pinned adapters are capped at
                ``max_loras_per_batch - 1``).
        """
        # Check if the LoRA adapter shape is compatible with the current LoRA
        # memory pool configuration. Before `init_memory_pool` has run there is
        # no pool yet, so this check is skipped.
        memory_pool = getattr(self, "memory_pool", None)
        incompatible = memory_pool and not memory_pool.can_support(lora_config)
        if incompatible:
            raise ValueError(
                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
                "LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
                "`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
            )

        # Ensure pinned LoRA adapters do not exceed the maximal limit or cause
        # starvation: at least one slot must stay available for unpinned
        # adapters and the base model.
        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
            raise ValueError(
                f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
                "in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
                "`--max-loras-per-batch` or load it as unpinned LoRA adapters."
            )

    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
        """
        Unload a LoRA adapter by its ID.

        Removes the adapter's config, CPU-side weights and ``LoRARef``
        metadata, and updates the pinned-adapter count.

        Args:
            lora_ref: Reference whose ``lora_id`` identifies the adapter.

        Returns:
            LoRAUpdateResult describing success/failure and the loaded adapters.
        """
        lora_id = lora_ref.lora_id
        adapter = self.configs.get(lora_id)
        # Use a separate name so the incoming `lora_ref` stays usable (and
        # the assertion message below cannot dereference None).
        loaded_ref = self.lora_refs.get(lora_id)
        assert (
            adapter is not None and loaded_ref is not None
        ), f"LoRA adapter with ID {lora_id} is not loaded. This should have been verified before request is sent to the backend."

        try:
            del self.configs[lora_id]
            del self.loras[lora_id]
            del self.lora_refs[lora_id]
            # Use the stored ref: its `pinned` flag is what was counted on load.
            self.num_pinned_loras -= int(loaded_ref.pinned)
        except Exception as e:
            return self.create_lora_update_result(
                success=False,
                error_message=str(e),
            )

        return self.create_lora_update_result(success=True)

    def validate_lora_batch(self, lora_ids: set[str]) -> bool:
        """
        Check whether the LoRA IDs of a batch can be loaded into the current
        LoRA memory pool at the same time.

        Args:
            lora_ids: Distinct LoRA IDs of the batch. ``None`` entries are
                skipped when counting pinned adapters but still count toward
                the number of required slots.

        Returns:
            False if the batch has more than ``max_loras_per_batch`` IDs, or if
            its non-pinned entries need more slots than remain after reserving
            one slot per loaded pinned adapter; True otherwise.
        """
        if len(lora_ids) > self.max_loras_per_batch:
            return False

        # skip pinned LoRA check if no pinned LoRA adapters are loaded.
        if self.num_pinned_loras == 0:
            return True

        # counting the number of pinned LoRA adapters in the batch.
        pinned_loras_in_batch = 0
        for lora_id in lora_ids:
            if lora_id is not None:
                lora_ref = self.lora_refs.get(lora_id)
                assert (
                    lora_ref is not None
                ), f"LoRA ID {lora_id} not found in lora_refs."
                pinned_loras_in_batch += int(lora_ref.pinned)

        assert pinned_loras_in_batch <= self.num_pinned_loras, (
            f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
            f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
        )

        # Pinned adapters in the batch are counted as already occupying pool
        # slots; everything else needs one of the remaining slots.
        required_slots = len(lora_ids) - pinned_loras_in_batch
        mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras

        return required_slots <= mem_pool_vacancy

    def prepare_lora_batch(self, forward_batch: ForwardBatch):
        """
        Prepare LoRA state for one forward batch.

        Loads the adapters referenced by ``forward_batch.lora_ids`` into the
        memory pool, then builds a ``LoRABatchInfo`` and passes it to the
        backend via ``set_batch_info``. The info holds:

        * segment lengths and offsets;
        * the pool buffer index of each request;
        * the rank and scaling of each buffer slot.

        If `init_cuda_graph_batch_info` has been called, the batch fits the
        captured size, and the forward mode is a CUDA-graph mode, the
        pre-allocated ``self.cuda_graph_batch_info`` is updated in place
        instead of allocating new tensors.
        """
        # Load active loras into lora memory pool
        cur_uids = set(forward_batch.lora_ids)
        assert len(cur_uids) <= self.max_loras_per_batch
        self.memory_pool.prepare_lora_batch(
            cur_uids=cur_uids,
            lora_adapters=self.loras,
            lora_modules=self.lora_modules,
            lora_refs=self.lora_refs.copy(),  # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
        )

        # set up batch info shared by all lora modules
        bs = forward_batch.batch_size

        def transfer_adapter_info(
            weight_indices_out: torch.Tensor,
            lora_ranks_out: torch.Tensor,
            scalings_out: torch.Tensor,
        ):
            """
            Transfer adapter metadata (weight indices, LoRA rank, scalings)
            from host to device asynchronously.

            Writes the first ``bs`` entries of ``weight_indices_out`` and the
            first ``max_loras_per_batch`` entries of ``lora_ranks_out`` and
            ``scalings_out``. Slots without an adapter (e.g. a ``None`` uid)
            keep rank 0 and scaling 0.
            """
            # NOTE(review): assumes len(forward_batch.lora_ids) == bs; the
            # copy into weight_indices_out[:bs] requires matching sizes.
            weight_indices = [0] * len(forward_batch.lora_ids)
            lora_ranks = [0] * self.max_loras_per_batch
            scalings = [0] * self.max_loras_per_batch
            for i, uid in enumerate(forward_batch.lora_ids):
                weight_indices[i] = self.memory_pool.get_buffer_id(uid)
                if uid is not None:
                    lora = self.loras[uid]
                    lora_ranks[weight_indices[i]] = lora.config.r
                    scalings[weight_indices[i]] = lora.scaling

            # Use pinned memory to avoid synchronizations during host-to-device transfer
            weight_indices_tensor = torch.tensor(
                weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
            )
            lora_ranks_tensor = torch.tensor(
                lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
            )
            scalings_tensor = torch.tensor(
                scalings, dtype=torch.float, pin_memory=True, device="cpu"
            )

            # Copy to device tensors asynchronously
            weight_indices_out[:bs].copy_(weight_indices_tensor, non_blocking=True)
            lora_ranks_out[: self.max_loras_per_batch].copy_(
                lora_ranks_tensor, non_blocking=True
            )
            scalings_out[: self.max_loras_per_batch].copy_(
                scalings_tensor, non_blocking=True
            )

        if (
            hasattr(self, "max_bs_in_cuda_graph")
            and bs <= self.max_bs_in_cuda_graph
            and forward_batch.forward_mode.is_cuda_graph()
        ):
            # Do in-place updates when CUDA graph is enabled and the batch
            # forward mode could use CUDA graph. seg_lens/seg_indptr were
            # fixed in `init_cuda_graph_batch_info`.
            transfer_adapter_info(
                self.cuda_graph_batch_info.weight_indices,
                self.cuda_graph_batch_info.lora_ranks,
                self.cuda_graph_batch_info.scalings,
            )

            self.cuda_graph_batch_info.bs = bs
            self.cuda_graph_batch_info.max_len = 1
            batch_info = self.cuda_graph_batch_info
        else:
            weight_indices = torch.empty((bs,), dtype=torch.int32, device=self.device)
            # NOTE(review): ranks are int64 here but int32 in the CUDA-graph
            # buffer; kept as-is pending confirmation of backend expectations.
            lora_ranks = torch.zeros(
                (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
            )
            scalings = torch.zeros(
                (self.max_loras_per_batch,), dtype=torch.float, device=self.device
            )
            transfer_adapter_info(
                weight_indices,
                lora_ranks,
                scalings,
            )
            # Outside extend mode each request contributes a single token.
            # Build it as int32 so it matches the CUDA-graph `seg_lens` buffer
            # and `seg_indptr` (a bare `torch.ones` would be floating point).
            seg_lens = (
                forward_batch.extend_seq_lens
                if forward_batch.forward_mode.is_extend()
                else torch.ones(bs, dtype=torch.int32, device=self.device)
            )
            max_len = (
                # Calculate max_len from the CPU copy to avoid D2H transfer.
                max(forward_batch.extend_seq_lens_cpu)
                if forward_batch.forward_mode.is_extend()
                else 1
            )
            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)

            batch_info = LoRABatchInfo(
                bs=bs,
                seg_lens=seg_lens,
                seg_indptr=seg_indptr,
                max_len=max_len,
                weight_indices=weight_indices,
                lora_ranks=lora_ranks,
                scalings=scalings,
            )
        self.lora_backend.set_batch_info(batch_info)

    def update_lora_info(self):
        """
        Update all LoRA modules to associate them with the latest memory buffer.

        For each layer and each LoRA module in it, fetch the pool's LoRA_A and
        LoRA_B tensors for the module's normalized weight name and pass them
        to ``set_lora_info``.
        """
        for layer_id, layer_modules in enumerate(self.lora_modules):
            for module_name, module in layer_modules.items():
                weight_name = get_weight_name(
                    module_name, self.memory_pool.lora_weight_names
                )
                module.set_lora_info(
                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
                )

    def init_state(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
        lora_paths: Optional[Dict[str, LoRARef]] = None,
    ):
        """
        Initialize the internal (mutable) state of the LoRAManager.

        When `lora_paths` is provided and not empty, it might be used for
        inferring LoRA shape info such as the target modules and max_lora_rank.
        Steps, in order:

        1. load the initial adapters;
        2. infer or record the LoRA shapes;
        3. derive the weight names;
        4. wrap the target modules;
        5. allocate the memory pool;
        6. bind the modules to the pool.
        """
        assert lora_paths or (
            max_lora_rank is not None and target_modules is not None
        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

        # NOTE: on first initialization the memory pool does not exist yet
        # while these adapters are loaded, so `validate_new_adapter` skips the
        # pool compatibility check for them.
        self.init_lora_adapters(lora_paths)
        self.init_lora_shapes(
            max_lora_rank=max_lora_rank,
            target_modules=target_modules,
        )
        self.init_lora_weight_names()
        self.init_lora_modules()
        self.init_memory_pool()
        self.update_lora_info()

    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
        """
        Reset adapter bookkeeping and load the initial adapters, if any.

        Raises:
            RuntimeError: If any adapter in ``lora_paths`` fails to load.
        """
        # Configs of all active LoRA adapters, indexed by LoRA ID.
        self.configs: Dict[str, LoRAConfig] = {}

        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
        self.loras: Dict[str, LoRAAdapter] = {}

        # Mapping from LoRA ID to LoRARef object.
        self.lora_refs: Dict[str, LoRARef] = {}

        # Count of pinned LoRA adapters.
        self.num_pinned_loras: int = 0

        if lora_paths:
            for lora_ref in lora_paths.values():
                result = self.load_lora_adapter(lora_ref)
                if not result.success:
                    raise RuntimeError(
                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
                    )

    def init_lora_shapes(
        self,
        max_lora_rank: Optional[int] = None,
        target_modules: Optional[Iterable[str]] = None,
    ):
        """
        Infer LoRA target modules and max_lora_rank from loaded adapters if
        not provided.

        Target modules default to the union of the loaded configs'
        ``target_modules`` lists. The max rank defaults to the largest
        ``r`` among them, or 0 when no adapter is loaded.

        Raises:
            ValueError: If target modules must be inferred but an adapter's
                ``target_modules`` is not a list.
        """
        if target_modules is not None:
            self.target_modules = set(target_modules)
        else:
            self.target_modules = set()
            for config in self.configs.values():
                if not isinstance(config.target_modules, list):
                    raise ValueError(
                        "SGLang currently only supports inferring LoRA target modules when a list of "
                        "suffixes is provided in `target_modules` field of PEFT config. Please explicitly "
                        "specify `--lora-target-modules` during server startup. You can specify `all` to "
                        "enable all support modules types. "
                    )
                self.target_modules.update(config.target_modules)

        if max_lora_rank is not None:
            self.max_lora_rank = max_lora_rank
        else:
            self.max_lora_rank = max(
                [x.r for x in self.configs.values()],
                default=0,
            )

    def init_lora_weight_names(self):
        """
        Compute the normalized LoRA weight names from ``self.target_modules``.
        """
        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
            self.target_modules
        )

    def load_lora_weights(self, lora_ref: LoRARef):
        """
        Load the weights of a LoRA adapter to CPU memory.

        Builds a ``LoRAAdapter`` from the already-registered config in
        ``self.configs``, calls ``initialize_weights()`` and, only if that
        succeeds, registers the adapter in ``self.loras``.
        """
        lora_adapter = LoRAAdapter(
            lora_ref.lora_id,
            self.configs[lora_ref.lora_id],
            self.base_hf_config,
            self.load_config,
            self.lora_backend,
        )
        lora_adapter.initialize_weights()
        self.loras[lora_ref.lora_id] = lora_adapter

    def init_memory_pool(self):
        """(Re)initialize the LoRA memory pool based on the current configurations."""
        self.memory_pool = LoRAMemoryPool(
            base_hf_config=self.base_hf_config,
            max_loras_per_batch=self.max_loras_per_batch,
            dtype=self.dtype,
            tp_size=self.tp_size,
            tp_rank=self.tp_rank,
            max_lora_rank=self.max_lora_rank,
            lora_weight_names=self.lora_weight_names,
            base_model=self.base_model,
        )

    def set_lora_module(self, module_name, module):
        """
        Wrap ``module`` in its LoRA layer and swap it into the base model.

        Returns:
            The LoRA layer that now occupies ``module_name``.
        """
        lora_module = get_lora_layer(module, self.lora_backend)
        replace_submodule(self.base_model, module_name, lora_module)
        return lora_module

    def init_lora_modules(self):
        """
        Replace every targeted submodule of the base model with a LoRA layer.

        A submodule is converted when the model's optional
        ``should_apply_lora`` hook allows it and the last component of its
        dotted name is in ``self.lora_weight_names``.
        """
        # Look-up table that essentially maps (layer_index, module_name) to the
        # corresponding LoRA module.
        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
            {} for _ in range(self.base_hf_config.num_hidden_layers)
        ]

        # Optional model-specific filter; looked up once rather than per module.
        should_apply_lora = getattr(self.base_model, "should_apply_lora", None)

        for module_name, module in self.base_model.named_modules():
            # TODO (lifuhuang): in the future, we should consider generalizing the
            # should_apply_lora function to support mapping by full module name instead
            # of just the last part (e.g., "qkv_proj") to support scenarios with multiple
            # attention stacks (e.g., multimodal models).
            # See: https://github.com/sgl-project/sglang/issues/6608
            if should_apply_lora and not should_apply_lora(module_name):
                continue

            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in self.lora_weight_names:
                layer_id = get_layer_id(module_name)
                self.lora_modules[layer_id][module_name] = self.set_lora_module(
                    module_name, module
                )