Implement LRU eviction policy for LoRA adapters (#11041)

Chenxi Li
2025-10-13 20:18:25 -07:00
committed by GitHub
parent 88a6f9dab5
commit 28f80b1244
9 changed files with 399 additions and 15 deletions

View File

@@ -0,0 +1,139 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Eviction policies for LoRA adapter memory management.
"""
import logging
import time
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(__name__)


class EvictionPolicy(ABC):
    """Abstract base class for LoRA adapter eviction policies."""

    @abstractmethod
    def mark_used(self, uid: Optional[str]) -> None:
        """Marks an adapter as used."""
        pass

    @abstractmethod
    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Selects an adapter to evict from candidates."""
        pass

    @abstractmethod
    def remove(self, uid: Optional[str]) -> None:
        """Removes an adapter from the policy's tracking."""
        pass


class LRUEvictionPolicy(EvictionPolicy):
    """LRU eviction policy - evicts the least recently used adapter."""

    def __init__(self):
        self.access_order = OrderedDict()  # key=uid, value=last_access_time
        self.total_accesses = 0
        self.eviction_count = 0

    def mark_used(self, uid: Optional[str]) -> None:
        if uid is not None:
            current_time = time.monotonic()
            # Remove and re-add to move to end (most recent)
            self.access_order.pop(uid, None)
            self.access_order[uid] = current_time
            self.total_accesses += 1
            logger.debug(f"LoRA {uid} marked as used at {current_time}")

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Select the least recently used adapter from candidates."""
        # Base model (currently None, will be replaced with special UID in future)
        # always has lowest priority - evict it first if available
        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
        if BASE_MODEL_UID in candidates:
            logger.debug("Selected base model for eviction (LRU)")
            self.eviction_count += 1
            return BASE_MODEL_UID

        # Iterate through access_order (oldest first) to find LRU victim
        for uid in list(self.access_order.keys()):
            if uid in candidates:
                logger.debug(f"Selected LoRA {uid} for eviction (LRU)")
                self.eviction_count += 1
                return uid

        # Should never reach here if candidates is non-empty
        assert False, f"Failed to select LRU victim from candidates: {candidates}"

    def remove(self, uid: Optional[str]) -> None:
        if uid is not None:
            self.access_order.pop(uid, None)
            logger.debug(f"Removed LoRA {uid} from LRU tracking")


class FIFOEvictionPolicy(EvictionPolicy):
    """FIFO eviction policy - for backward compatibility."""

    def __init__(self):
        self.insertion_order = (
            OrderedDict()
        )  # key=uid, OrderedDict maintains insertion order
        self.eviction_count = 0

    def mark_used(self, uid: Optional[str]) -> None:
        """For FIFO, we only track insertion order (not access time)."""
        if uid is not None and uid not in self.insertion_order:
            self.insertion_order[uid] = (
                True  # Value unused, OrderedDict tracks insertion order
            )

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Select the first inserted adapter from candidates."""
        # Base model (currently None, will be replaced with special UID in future)
        # always has lowest priority - evict it first if available
        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
        if BASE_MODEL_UID in candidates:
            logger.debug("Selected base model for eviction (FIFO)")
            self.eviction_count += 1
            return BASE_MODEL_UID

        # Iterate through insertion_order (oldest first) to find FIFO victim
        for uid in list(self.insertion_order.keys()):
            if uid in candidates:
                logger.debug(f"Selected LoRA {uid} for eviction (FIFO)")
                self.eviction_count += 1
                return uid

        # Should never reach here if candidates is non-empty
        assert False, f"Failed to select FIFO victim from candidates: {candidates}"

    def remove(self, uid: Optional[str]) -> None:
        if uid is not None:
            self.insertion_order.pop(uid, None)


def get_eviction_policy(policy_name: str) -> EvictionPolicy:
    """Factory function to create eviction policy instances."""
    policies = {
        "fifo": FIFOEvictionPolicy,
        "lru": LRUEvictionPolicy,
    }
    if policy_name not in policies:
        raise ValueError(f"Unknown eviction policy: {policy_name}")
    return policies[policy_name]()
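
The policy interface above is small enough to exercise standalone. A minimal sketch (the adapter UIDs are invented for illustration; `None` stands in for the base model, as in the code above):

```python
from sglang.srt.lora.eviction_policy import get_eviction_policy

policy = get_eviction_policy("lru")  # or "fifo"

# Simulate two adapters being used, with "adapter_a" touched most recently.
policy.mark_used("adapter_a")
policy.mark_used("adapter_b")
policy.mark_used("adapter_a")

# The base model (None) always has the lowest priority and is evicted first.
assert policy.select_victim({None, "adapter_a", "adapter_b"}) is None

# Otherwise the least recently used candidate wins: here, "adapter_b".
assert policy.select_victim({"adapter_a", "adapter_b"}) == "adapter_b"

# When an adapter is unloaded, drop it from the policy's bookkeeping.
policy.remove("adapter_b")
```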

View File

@@ -68,6 +68,9 @@ class LoRAManager:
        self.tp_size: int = tp_size
        self.tp_rank: int = tp_rank

        # Store eviction policy from server args
        self.eviction_policy = server_args.lora_eviction_policy

        # LoRA backend for running sgemm kernels
        logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
        backend_type = get_backend_from_name(lora_backend)
@@ -131,6 +134,16 @@ class LoRAManager:
            lora_ref.lora_id not in self.loras
        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

        if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
            return self.create_lora_update_result(
                success=False,
                error_message=(
                    f"Already have {self.num_pinned_loras} pinned adapters, "
                    f"max allowed is {self.max_loras_per_batch - 1} (reserving 1 slot for dynamic use). "
                    f"Please unpin some adapters or increase max_loras_per_batch."
                ),
            )

        try:
            # load configs
            new_adapter = LoRAConfig(lora_ref.lora_path)
@@ -420,6 +433,7 @@ class LoRAManager:
            max_lora_rank=self.max_lora_rank,
            target_modules=self.target_modules,
            base_model=self.base_model,
            eviction_policy=self.eviction_policy,
        )

    def set_lora_module(self, module_name, module):
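
The new pinning guard reserves one buffer slot for dynamic (evictable) adapters: with the default `max_loras_per_batch` of 8, at most 7 adapters can be pinned at once. A small sketch of the arithmetic (the concrete numbers are illustrative, not from the diff):

```python
max_loras_per_batch = 8  # default from ServerArgs
num_pinned_loras = 7     # adapters already pinned

# Mirrors the check above: an eighth pin request would leave no slot for
# dynamic adapters, so it is rejected with an error result instead.
should_reject = num_pinned_loras >= max_loras_per_batch - 1
print(should_reject)  # True
```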

View File

@@ -4,6 +4,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
from sglang.srt.distributed import divide
from sglang.srt.lora.eviction_policy import get_eviction_policy
from sglang.srt.lora.layers import BaseLayerWithLoRA
from sglang.srt.lora.lora import LoRAAdapter
from sglang.srt.lora.lora_config import LoRAConfig
@@ -54,6 +55,7 @@ class LoRAMemoryPool:
        max_lora_rank: int,
        target_modules: Set[str],
        base_model: torch.nn.Module,
        eviction_policy: str,
    ):
        self.base_hf_config: AutoConfig = base_hf_config
        self.num_layer: int = base_hf_config.num_hidden_layers
@@ -64,6 +66,9 @@ class LoRAMemoryPool:
        self.max_lora_rank: int = max_lora_rank
        self.target_modules: Set[str] = target_modules

        # Initialize eviction policy
        self.eviction_policy = get_eviction_policy(eviction_policy)

        # Both A_buffer and B_buffer maps lora weight names to its buffer space.
        # A_buffer contains num_layer number of row-major tensors with shape
        # (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
@@ -189,31 +194,50 @@ class LoRAMemoryPool:
         lora_refs: Dict[str, LoRARef],
     ):
         def get_available_buffer_slot():
+            # 1. Prioritize empty slots
             for buffer_id in range(self.max_loras_per_batch):
-                # Prioritize empty slots
                 if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id
 
+            # 2. Memory pool is full, need to evict using policy
+            candidates = set()
             for buffer_id in range(self.max_loras_per_batch):
                 uid = self.buffer_id_to_uid[buffer_id]
-                # Evict unneeded lora
-                if uid not in cur_uids:
-                    # Skip pinned LoRAs
-                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
-                    if uid is not None:
-                        lora_ref = lora_refs.get(uid)
-                        if lora_ref is not None and lora_ref.pinned:
-                            continue
-                    self.uid_to_buffer_id.pop(uid)
-                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
-                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
-                    return buffer_id
+                # Skip if this adapter is needed by current batch
+                # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                if uid in cur_uids:
+                    continue
+                # Skip if this adapter is pinned (base model cannot be pinned, so can be evicted)
+                if uid is not None:
+                    lora_ref = lora_refs.get(uid)
+                    if lora_ref and lora_ref.pinned:
+                        continue
+                candidates.add(uid)
 
-            raise ValueError(
-                "No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
-            )
+            if not candidates:
+                raise ValueError(
+                    "No available buffer slots found. Please ensure the number of active (pinned) loras is less than max_loras_per_batch."
+                )
+
+            # Select victim using eviction policy
+            victim_uid = self.eviction_policy.select_victim(candidates)
+
+            # Evict the selected victim
+            victim_buffer_id = self.uid_to_buffer_id[victim_uid]
+            self.uid_to_buffer_id.pop(victim_uid)
+            self.eviction_policy.remove(victim_uid)
+            self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
+            logger.debug(
+                f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
+            )
+            return victim_buffer_id
+
+        # Mark all adapters in current batch as used (for LRU tracking)
+        for uid in cur_uids:
+            self.eviction_policy.mark_used(uid)
+
         for uid in cur_uids:
             if uid not in self.uid_to_buffer_id:
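
To make the new slot-selection flow concrete, here is a minimal standalone model of it (plain dicts instead of the real pool, invented adapter UIDs; the real code additionally skips pinned adapters):

```python
from sglang.srt.lora.eviction_policy import get_eviction_policy

policy = get_eviction_policy("lru")

# A full three-slot pool: buffer slot -> adapter uid.
buffer_id_to_uid = {0: "a", 1: "b", 2: "c"}
uid_to_buffer_id = {"a": 0, "b": 1, "c": 2}
cur_uids = {"c"}  # needed by the current batch, so never an eviction candidate

# Usage history across earlier batches: "a" was used longest ago.
policy.mark_used("a")
policy.mark_used("b")
policy.mark_used("c")

# Candidate collection mirrors get_available_buffer_slot: skip current-batch uids.
candidates = {uid for uid in buffer_id_to_uid.values() if uid not in cur_uids}

victim = policy.select_victim(candidates)  # -> "a", the least recently used
slot = uid_to_buffer_id.pop(victim)        # slot 0 is freed for the incoming adapter
policy.remove(victim)
```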

View File

@@ -122,6 +122,8 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
DEFAULT_LORA_EVICTION_POLICY = "lru"
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
@@ -304,6 +306,7 @@ class ServerArgs:
    ] = None
    max_loaded_loras: Optional[int] = None
    max_loras_per_batch: int = 8
    lora_eviction_policy: str = DEFAULT_LORA_EVICTION_POLICY
    lora_backend: str = "triton"
    max_lora_chunk_size: Optional[int] = 16
@@ -2127,6 +2130,13 @@ class ServerArgs:
            default=ServerArgs.max_loaded_loras,
            help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
        )
        parser.add_argument(
            "--lora-eviction-policy",
            type=str,
            default=DEFAULT_LORA_EVICTION_POLICY,
            choices=["lru", "fifo"],
            help="LoRA adapter eviction policy when memory pool is full. 'lru': Least Recently Used (default, better cache efficiency). 'fifo': First-In-First-Out.",
        )
        parser.add_argument(
            "--lora-backend",
            type=str,
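
The flag can also be set programmatically through ServerArgs. A rough sketch (the model path and other values are placeholders, not part of this change):

```python
from sglang.srt.server_args import ServerArgs

# Equivalent to passing `--lora-eviction-policy fifo` on the command line;
# omitting the field keeps the new default, "lru".
args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
    max_loras_per_batch=8,
    lora_eviction_policy="fifo",
)
print(args.lora_eviction_policy)  # "fifo"
```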

View File

@@ -519,6 +519,7 @@ class SRTRunner:
        lora_target_modules: Optional[List[str]] = None,
        enable_lora: Optional[bool] = None,
        max_loaded_loras: Optional[int] = None,
        lora_eviction_policy: str = "lru",
    ):
        self.model_type = model_type
        self.is_generation = model_type == "generation"
@@ -565,6 +566,7 @@ class SRTRunner:
                lora_target_modules=lora_target_modules,
                enable_lora=enable_lora,
                max_loaded_loras=max_loaded_loras,
                lora_eviction_policy=lora_eviction_policy,
                **spec_kwargs,
            )
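
A rough sketch of exercising the new knob from a test via SRTRunner (constructor arguments other than lora_eviction_policy are illustrative and may differ from the helper's exact signature):

```python
import torch

from sglang.test.runners import SRTRunner

# Run a generation model with a couple of LoRA adapters and FIFO eviction
# instead of the new LRU default; model and adapter paths are placeholders.
with SRTRunner(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16,
    model_type="generation",
    lora_paths=["path/to/adapter_a", "path/to/adapter_b"],
    max_loras_per_batch=1,
    lora_eviction_policy="fifo",
) as runner:
    ...
```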