Support GPU pinning for LoRA (#8697)

This commit is contained in:
Lifu Huang
2025-08-06 19:39:45 -07:00
committed by GitHub
parent 6ad6c8c9e6
commit 6210e2c4f0
13 changed files with 425 additions and 134 deletions

View File

@@ -492,12 +492,13 @@ class Engine(EngineBase):
self.tokenizer_manager.get_weights_by_name(obj, None)
)
def load_lora_adapter(self, lora_name: str, lora_path: str):
def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
"""Load a new LoRA adapter without re-launching the engine."""
obj = LoadLoRAAdapterReqInput(
lora_name=lora_name,
lora_path=lora_path,
pinned=pinned,
)
loop = asyncio.get_event_loop()

View File

@@ -144,6 +144,7 @@ class LoRAManager:
# keep metadata for displayed messages
self.lora_refs[lora_ref.lora_id] = lora_ref
self.num_pinned_loras += int(lora_ref.pinned)
except Exception as e:
return self.create_lora_update_result(
success=False,
@@ -157,13 +158,22 @@ class LoRAManager:
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
"""
# Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
memory_pool = getattr(self, "memory_pool", None)
incompatible = memory_pool and not memory_pool.can_support(lora_config)
if incompatible:
raise ValueError(
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
"included in `--enable_lora_modules`."
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
"LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
"`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
)
# Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
raise ValueError(
f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
"in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
)
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
@@ -172,15 +182,17 @@ class LoRAManager:
delete the corresponding LoRA modules.
"""
adapter = self.configs.get(lora_ref.lora_id, None)
adapter = self.configs.get(lora_ref.lora_id)
lora_ref = self.lora_refs.get(lora_ref.lora_id)
assert (
adapter is not None
adapter is not None and lora_ref is not None
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
try:
del self.configs[lora_ref.lora_id]
del self.loras[lora_ref.lora_id]
del self.lora_refs[lora_ref.lora_id]
self.num_pinned_loras -= int(lora_ref.pinned)
except Exception as e:
return self.create_lora_update_result(
success=False,
@@ -189,11 +201,49 @@ class LoRAManager:
return self.create_lora_update_result(success=True)
def validate_lora_batch(self, lora_ids: set[str]) -> bool:
"""
Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
"""
if len(lora_ids) > self.max_loras_per_batch:
return False
# skip pinned LoRA check if no pinned LoRA adapters are loaded.
if self.num_pinned_loras == 0:
return True
# counting the number of pinned LoRA adapters in the batch.
pinned_loras_in_batch = 0
for lora_id in lora_ids:
if lora_id is not None:
lora_ref = self.lora_refs.get(lora_id)
assert (
lora_ref is not None
), f"LoRA ID {lora_id} not found in lora_refs."
pinned_loras_in_batch += int(lora_ref.pinned)
assert pinned_loras_in_batch <= self.num_pinned_loras, (
f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
)
required_slots = len(lora_ids) - pinned_loras_in_batch
mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
return required_slots <= mem_pool_vacancy
def prepare_lora_batch(self, forward_batch: ForwardBatch):
# Load active loras into lora memory pool
cur_uids = set(forward_batch.lora_ids)
assert len(cur_uids) <= self.max_loras_per_batch
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
self.memory_pool.prepare_lora_batch(
cur_uids=cur_uids,
lora_adapters=self.loras,
lora_modules=self.lora_modules,
lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
)
# set up batch info shared by all lora modules
bs = forward_batch.batch_size
@@ -366,6 +416,9 @@ class LoRAManager:
# Mapping from LoRA ID to LoRARef object.
self.lora_refs: Dict[str, LoRARef] = {}
# Count of pinned LoRA adapters.
self.num_pinned_loras: int = 0
if lora_paths:
for lora_ref in lora_paths.values():
result = self.load_lora_adapter(lora_ref)
@@ -399,7 +452,7 @@ class LoRAManager:
self.max_lora_rank = max_lora_rank
else:
self.max_lora_rank = max(
[x.hf_config["r"] for x in self.configs.values()],
[x.r for x in self.configs.values()],
default=0,
)

View File

@@ -28,14 +28,15 @@ class LoRARef:
"""
Reference record for a LoRA model.
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
keys (e.g., radix cache).
"""
lora_id: str = field(default_factory=lambda: uuid4().hex)
lora_name: Optional[str] = None
lora_path: Optional[str] = None
pinned: Optional[bool] = None
def __post_init__(self):
if self.lora_id is None:

View File

@@ -1,3 +1,4 @@
import logging
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.lora.utils import (
ROW_PARALLELISM_LINEAR_LORA_NAMES,
LoRAType,
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
get_weight_name,
)
logger = logging.getLogger(__name__)
class EmptySlot:
"""
Singleton class to represent an empty slot in the memory pool.
This is used to improve readability by not using special str as a placeholder.
"""
__slots__ = ()
def __repr__(self):
return "|EMPTY|"
def __new__(cls):
if not hasattr(cls, "_instance"):
cls._instance = super().__new__(cls)
return cls._instance
EMPTY_SLOT = EmptySlot()
class LoRAMemoryPool:
"""Class for memory pool management of lora modules"""
@@ -54,9 +78,11 @@ class LoRAMemoryPool:
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
# Buffer idx -> lora uid in memory pool
# All uids are initialized as empty strings for empty buffer slots
# All uids are initialized as `EmptySlot` for empty buffer slots
# Here we don't initialize to None since None is a valid uid
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
EMPTY_SLOT
] * self.max_loras_per_batch
self.init_buffers(base_model)
@@ -154,17 +180,29 @@ class LoRAMemoryPool:
cur_uids: Set[Optional[str]],
lora_adapters: Dict[str, LoRAAdapter],
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
lora_refs: Dict[str, LoRARef],
):
def get_available_buffer_slot():
for buffer_id in range(self.max_loras_per_batch):
# Prioritize empty slots
if self.buffer_id_to_uid[buffer_id] == "":
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
return buffer_id
for buffer_id in range(self.max_loras_per_batch):
uid = self.buffer_id_to_uid[buffer_id]
# Evict unneeded lora
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
if uid not in cur_uids:
# Skip pinned LoRAs
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
if uid is not None:
lora_ref = lora_refs.get(uid)
if lora_ref is not None and lora_ref.pinned:
continue
self.uid_to_buffer_id.pop(uid)
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
return buffer_id
raise ValueError(

View File

@@ -1082,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
lora_name: str
# The path of loading.
lora_path: str
# Whether to pin the LoRA adapter in memory.
pinned: bool = False
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
lora_id: Optional[str] = None
@@ -1090,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
lora_id=self.lora_id,
lora_name=self.lora_name,
lora_path=self.lora_path,
pinned=self.pinned,
)

View File

@@ -1538,14 +1538,11 @@ class Scheduler(
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.enable_lora
and len(
lora_set
| set([req.lora_id for req in adder.can_run_list])
| set([req.lora_id])
)
> self.max_loras_per_batch
if self.enable_lora and not self.tp_worker.can_run_lora_batch(
lora_set
| set([req.lora_id for req in adder.can_run_list])
| set([req.lora_id])
):
self.running_batch.batch_is_full = True
break

View File

@@ -1129,6 +1129,7 @@ class TokenizerManager:
new_adapter = LoRARef(
lora_name=obj.lora_name,
lora_path=obj.lora_path,
pinned=obj.pinned,
)
# Trigger the actual loading operation at the backend processes.
@@ -1186,7 +1187,7 @@ class TokenizerManager:
return result
except ValueError as e:
return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
async def get_weights_by_name(
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None

View File

@@ -311,3 +311,6 @@ class TpModelWorker:
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
return result
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)

View File

@@ -288,6 +288,9 @@ class TpModelWorkerClient:
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
return self.worker.unload_lora_adapter(recv_req)
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
return self.worker.can_run_lora_batch(lora_ids)
def __delete__(self):
self.input_queue.put((None, None))
self.copy_queue.put((None, None, None))

View File

@@ -2067,21 +2067,23 @@ class ServerArgs:
if self.enable_lora:
# Normalize lora_paths to a dictionary if it is a list.
# TODO (lifuhuang): support specifying pinned adapters in server_args.
if isinstance(self.lora_paths, list):
lora_paths = self.lora_paths
self.lora_paths = {}
for lora_path in lora_paths:
if "=" in lora_path:
name, path = lora_path.split("=", 1)
self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
self.lora_paths[name] = LoRARef(
lora_name=name, lora_path=path, pinned=False
)
else:
self.lora_paths[lora_path] = LoRARef(
lora_name=lora_path,
lora_path=lora_path,
lora_name=lora_path, lora_path=lora_path, pinned=False
)
elif isinstance(self.lora_paths, dict):
self.lora_paths = {
k: LoRARef(lora_name=k, lora_path=v)
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
for k, v in self.lora_paths.items()
}
elif self.lora_paths is None:

View File

@@ -568,8 +568,8 @@ class SRTRunner:
else:
self.tokenizer = None
def load_lora_adapter(self, lora_name: str, lora_path: str):
return self.engine.load_lora_adapter(lora_name, lora_path)
def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
return self.engine.load_lora_adapter(lora_name, lora_path, pinned)
def unload_lora_adapter(self, lora_name: str):
return self.engine.unload_lora_adapter(lora_name)