Support GPU pinning for LoRA (#8697)

This commit is contained in:
Lifu Huang
2025-08-06 19:39:45 -07:00
committed by GitHub
parent 6ad6c8c9e6
commit 6210e2c4f0
13 changed files with 425 additions and 134 deletions

View File

@@ -144,6 +144,7 @@ class LoRAManager:
# keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref
self.num_pinned_loras += int(lora_ref.pinned)
except Exception as e:
return self.create_lora_update_result(
success=False,
@@ -157,13 +158,22 @@ class LoRAManager:
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
"""
# Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
"LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
"`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
)
# Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
raise ValueError(
f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
"in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
)
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
@@ -172,15 +182,17 @@ class LoRAManager:
delete the corresponding LoRA modules.
"""
adapter = self.configs.get(lora_ref.lora_id, None)
adapter = self.configs.get(lora_ref.lora_id)
lora_ref = self.lora_refs.get(lora_ref.lora_id)
assert (
adapter is not None
adapter is not None and lora_ref is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
self.num_pinned_loras -= int(lora_ref.pinned)
except Exception as e:
return self.create_lora_update_result(
success=False,
@@ -189,11 +201,49 @@ class LoRAManager:
return self.create_lora_update_result(success=True)
def validate_lora_batch(self, lora_ids: set[str]) -> bool:
"""
Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
"""
if len(lora_ids) > self.max_loras_per_batch:
return False
# skip pinned LoRA check if no pinned LoRA adapters are loaded.
if self.num_pinned_loras == 0:
return True
# counting the number of pinned LoRA adapters in the batch.
pinned_loras_in_batch = 0
for lora_id in lora_ids:
if lora_id is not None:
lora_ref = self.lora_refs.get(lora_id)
assert (
lora_ref is not None
), f"LoRA ID {lora_id} not found in lora_refs."
pinned_loras_in_batch += int(lora_ref.pinned)
assert pinned_loras_in_batch <= self.num_pinned_loras, (
f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
)
required_slots = len(lora_ids) - pinned_loras_in_batch
mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
return required_slots <= mem_pool_vacancy
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
self.memory_pool.prepare_lora_batch(
cur_uids=cur_uids,
lora_adapters=self.loras,
lora_modules=self.lora_modules,
lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
)
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
@@ -366,6 +416,9 @@ class LoRAManager:
# Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {}
# Count of pinned LoRA adapters.
self.num_pinned_loras: int = 0
if lora_paths:
for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref)
@@ -399,7 +452,7 @@ class LoRAManager:
self.max_lora_rank = max_lora_rank
else:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
[x.r for x in self.configs.values()],
default=0,
)

View File

@@ -28,14 +28,15 @@ class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
pinned: Optional[bool] = None
def __post_init__(self):
if self.lora_id is None:

View File

@@ -1,3 +1,4 @@
import logging
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
get_weight_name,
)
logger = logging.getLogger(__name__)
class EmptySlot:
"""
Singleton class to represent an empty slot in the memory pool.
This is used to improve readability by not using special str as a placeholder.
"""
__slots__ = ()
def __repr__(self):
return "|EMPTY|"
def __new__(cls):
if not hasattr(cls, "_instance"):
cls._instance = super().__new__(cls)
return cls._instance
EMPTY_SLOT = EmptySlot()
class LoRAMemoryPool:
"""Class for memory pool management of lora modules"""
@@ -54,9 +78,11 @@ class LoRAMemoryPool:
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots
# All uids are initialized as `EmptySlot` for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
EMPTY_SLOT
] * self.max_loras_per_batch
self.init_buffers(base_model)
@@ -154,17 +180,29 @@ class LoRAMemoryPool:
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
lora_refs: Dict[str, LoRARef],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == "":
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
return buffer_id
for buffer_id in range(self.max_loras_per_batch):
uid = self.buffer_id_to_uid[buffer_id]
# Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
if uid not in cur_uids:
# Skip pinned LoRAs
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
if uid is not None:
lora_ref = lora_refs.get(uid)
if lora_ref is not None and lora_ref.pinned:
continue
self.uid_to_buffer_id.pop(uid)
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
return buffer_id
raise ValueError(