Implement LRU eviction policy for LoRA adapters (#11041)
This commit is contained in:
@@ -35,6 +35,8 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
|
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max-loras-per-batch`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"* `lora_eviction_policy`: LoRA adapter eviction policy when GPU memory pool is full. `lru`: Least Recently Used (default, better cache efficiency). `fifo`: First-In-First-Out.\n",
|
||||||
|
"\n",
|
||||||
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we support Triton LoRA backend (`triton`) and Chunked SGMV backend (`csgmv`). In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. Currently we support Triton LoRA backend (`triton`) and Chunked SGMV backend (`csgmv`). In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
||||||
|
|||||||
@@ -213,6 +213,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool} | None |
|
| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool} | None |
|
||||||
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
|
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, including base-only requests. | 8 |
|
||||||
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
|
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
|
||||||
|
| `--lora-eviction-policy` | LoRA adapter eviction policy when GPU memory pool is full. `lru`: Least Recently Used (better cache efficiency). `fifo`: First-In-First-Out. | lru |
|
||||||
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
||||||
|
|
||||||
## Kernel backend
|
## Kernel backend
|
||||||
|
|||||||
139
python/sglang/srt/lora/eviction_policy.py
Normal file
139
python/sglang/srt/lora/eviction_policy.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
Eviction policies for LoRA adapter memory management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EvictionPolicy(ABC):
    """Interface that every LoRA adapter eviction policy must implement.

    A policy is told which adapters are used (``mark_used``), is asked to
    pick one adapter to evict from a set of candidates (``select_victim``),
    and is told when an adapter should no longer be tracked (``remove``).
    Adapter identifiers are strings; ``None`` is also accepted as a uid.
    """

    @abstractmethod
    def mark_used(self, uid: Optional[str]) -> None:
        """Record that the adapter identified by ``uid`` has been used."""

    @abstractmethod
    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Choose and return the uid to evict from ``candidates``."""

    @abstractmethod
    def remove(self, uid: Optional[str]) -> None:
        """Stop tracking the adapter identified by ``uid``."""
|
||||||
|
|
||||||
|
|
||||||
|
class LRUEvictionPolicy(EvictionPolicy):
    """LRU eviction policy - evicts the least recently used adapter.

    Adapters are kept in an ``OrderedDict`` ordered from least recently used
    (front) to most recently used (back). A uid of ``None`` (currently the
    base model) is never tracked, and is always selected first when it is
    among the eviction candidates.

    Attributes:
        access_order: Maps uid -> ``time.monotonic()`` value of its last use.
        total_accesses: Number of ``mark_used`` calls with a non-None uid.
        eviction_count: Number of victims returned by ``select_victim``.
    """

    def __init__(self):
        self.access_order = OrderedDict()  # key=uid, value=last_access_time
        self.total_accesses = 0
        self.eviction_count = 0

    def mark_used(self, uid: Optional[str]) -> None:
        """Record ``uid`` as the most recently used adapter.

        A uid of ``None`` is ignored.
        """
        if uid is None:
            return

        current_time = time.monotonic()
        # Store the new timestamp, then move the entry to the back so the
        # dict stays ordered from least to most recently used.
        self.access_order[uid] = current_time
        self.access_order.move_to_end(uid)
        self.total_accesses += 1
        logger.debug("LoRA %s marked as used at %s", uid, current_time)

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Select the least recently used adapter from candidates.

        Args:
            candidates: uids that are allowed to be evicted.

        Returns:
            ``None`` (the base model) if it is among the candidates, otherwise
            the tracked candidate that was used least recently.

        Raises:
            AssertionError: If no candidate can be selected, i.e. the base
                model is not a candidate and none of the candidates is
                currently tracked (this includes an empty candidate set).
        """
        # Base model (currently None, will be replaced with special UID in future)
        # always has lowest priority - evict it first if available
        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
        if BASE_MODEL_UID in candidates:
            logger.debug("Selected base model for eviction (LRU)")
            self.eviction_count += 1
            return BASE_MODEL_UID

        # Iterate through access_order (oldest first) to find LRU victim.
        # No copy of the keys is needed: we return before any mutation.
        for uid in self.access_order:
            if uid in candidates:
                logger.debug("Selected LoRA %s for eviction (LRU)", uid)
                self.eviction_count += 1
                return uid

        # Raised explicitly rather than via ``assert False``: assert statements
        # are stripped under ``python -O``, which would make this method fall
        # through and return None - indistinguishable from the base model uid.
        raise AssertionError(
            f"Failed to select LRU victim from candidates: {candidates}"
        )

    def remove(self, uid: Optional[str]) -> None:
        """Stop tracking ``uid``; unknown uids and ``None`` are ignored."""
        if uid is not None:
            self.access_order.pop(uid, None)
            logger.debug("Removed LoRA %s from LRU tracking", uid)
|
||||||
|
|
||||||
|
|
||||||
|
class FIFOEvictionPolicy(EvictionPolicy):
    """FIFO eviction policy - for backward compatibility.

    Adapters are evicted in the order in which they were first marked as
    used; later uses of an already tracked adapter do not change its position.
    A uid of ``None`` (currently the base model) is never tracked, and is
    always selected first when it is among the eviction candidates.

    Attributes:
        insertion_order: ``OrderedDict`` whose key order is the order of first
            use; the values are unused placeholders.
        eviction_count: Number of victims returned by ``select_victim``.
    """

    def __init__(self):
        # key=uid, OrderedDict maintains insertion order
        self.insertion_order = OrderedDict()
        self.eviction_count = 0

    def mark_used(self, uid: Optional[str]) -> None:
        """For FIFO, we only track insertion order (not access time).

        A uid of ``None`` or an already tracked uid is ignored.
        """
        if uid is not None and uid not in self.insertion_order:
            # Value unused, OrderedDict tracks insertion order
            self.insertion_order[uid] = True

    def select_victim(self, candidates: Set[Optional[str]]) -> Optional[str]:
        """Select the first inserted adapter from candidates.

        Args:
            candidates: uids that are allowed to be evicted.

        Returns:
            ``None`` (the base model) if it is among the candidates, otherwise
            the tracked candidate that was inserted earliest.

        Raises:
            AssertionError: If no candidate can be selected, i.e. the base
                model is not a candidate and none of the candidates is
                currently tracked (this includes an empty candidate set).
        """
        # Base model (currently None, will be replaced with special UID in future)
        # always has lowest priority - evict it first if available
        BASE_MODEL_UID = None  # TODO: Replace with special UID constant
        if BASE_MODEL_UID in candidates:
            logger.debug("Selected base model for eviction (FIFO)")
            self.eviction_count += 1
            return BASE_MODEL_UID

        # Iterate through insertion_order (oldest first) to find FIFO victim.
        # No copy of the keys is needed: we return before any mutation.
        for uid in self.insertion_order:
            if uid in candidates:
                logger.debug("Selected LoRA %s for eviction (FIFO)", uid)
                self.eviction_count += 1
                return uid

        # Raised explicitly rather than via ``assert False``: assert statements
        # are stripped under ``python -O``, which would make this method fall
        # through and return None - indistinguishable from the base model uid.
        raise AssertionError(
            f"Failed to select FIFO victim from candidates: {candidates}"
        )

    def remove(self, uid: Optional[str]) -> None:
        """Stop tracking ``uid``; unknown uids and ``None`` are ignored."""
        if uid is not None:
            self.insertion_order.pop(uid, None)
|
||||||
|
|
||||||
|
|
||||||
|
def get_eviction_policy(policy_name: str) -> EvictionPolicy:
    """Create a new eviction policy instance for the given policy name.

    Args:
        policy_name: Either ``"fifo"`` or ``"lru"``.

    Returns:
        A freshly constructed policy object of the matching class.

    Raises:
        ValueError: If ``policy_name`` is not one of the known names.
    """
    registry = {
        "fifo": FIFOEvictionPolicy,
        "lru": LRUEvictionPolicy,
    }
    policy_cls = registry.get(policy_name)
    if policy_cls is None:
        raise ValueError(f"Unknown eviction policy: {policy_name}")
    return policy_cls()
|
||||||
@@ -68,6 +68,9 @@ class LoRAManager:
|
|||||||
self.tp_size: int = tp_size
|
self.tp_size: int = tp_size
|
||||||
self.tp_rank: int = tp_rank
|
self.tp_rank: int = tp_rank
|
||||||
|
|
||||||
|
# Store eviction policy from server args
|
||||||
|
self.eviction_policy = server_args.lora_eviction_policy
|
||||||
|
|
||||||
# LoRA backend for running sgemm kernels
|
# LoRA backend for running sgemm kernels
|
||||||
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||||
backend_type = get_backend_from_name(lora_backend)
|
backend_type = get_backend_from_name(lora_backend)
|
||||||
@@ -131,6 +134,16 @@ class LoRAManager:
|
|||||||
lora_ref.lora_id not in self.loras
|
lora_ref.lora_id not in self.loras
|
||||||
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
|
), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
|
||||||
|
|
||||||
|
if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
|
||||||
|
return self.create_lora_update_result(
|
||||||
|
success=False,
|
||||||
|
error_message=(
|
||||||
|
f"Already have {self.num_pinned_loras} pinned adapters, "
|
||||||
|
f"max allowed is {self.max_loras_per_batch - 1} (reserving 1 slot for dynamic use). "
|
||||||
|
f"Please unpin some adapters or increase max_loras_per_batch."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# load configs
|
# load configs
|
||||||
new_adapter = LoRAConfig(lora_ref.lora_path)
|
new_adapter = LoRAConfig(lora_ref.lora_path)
|
||||||
@@ -420,6 +433,7 @@ class LoRAManager:
|
|||||||
max_lora_rank=self.max_lora_rank,
|
max_lora_rank=self.max_lora_rank,
|
||||||
target_modules=self.target_modules,
|
target_modules=self.target_modules,
|
||||||
base_model=self.base_model,
|
base_model=self.base_model,
|
||||||
|
eviction_policy=self.eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_lora_module(self, module_name, module):
|
def set_lora_module(self, module_name, module):
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.distributed import divide
|
from sglang.srt.distributed import divide
|
||||||
|
from sglang.srt.lora.eviction_policy import get_eviction_policy
|
||||||
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
||||||
from sglang.srt.lora.lora import LoRAAdapter
|
from sglang.srt.lora.lora import LoRAAdapter
|
||||||
from sglang.srt.lora.lora_config import LoRAConfig
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
@@ -54,6 +55,7 @@ class LoRAMemoryPool:
|
|||||||
max_lora_rank: int,
|
max_lora_rank: int,
|
||||||
target_modules: Set[str],
|
target_modules: Set[str],
|
||||||
base_model: torch.nn.Module,
|
base_model: torch.nn.Module,
|
||||||
|
eviction_policy: str,
|
||||||
):
|
):
|
||||||
self.base_hf_config: AutoConfig = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
self.num_layer: int = base_hf_config.num_hidden_layers
|
self.num_layer: int = base_hf_config.num_hidden_layers
|
||||||
@@ -64,6 +66,9 @@ class LoRAMemoryPool:
|
|||||||
self.max_lora_rank: int = max_lora_rank
|
self.max_lora_rank: int = max_lora_rank
|
||||||
self.target_modules: Set[str] = target_modules
|
self.target_modules: Set[str] = target_modules
|
||||||
|
|
||||||
|
# Initialize eviction policy
|
||||||
|
self.eviction_policy = get_eviction_policy(eviction_policy)
|
||||||
|
|
||||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||||
# A_buffer contains num_layer number of row-major tensors with shape
|
# A_buffer contains num_layer number of row-major tensors with shape
|
||||||
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
|
# (max_loras_per_batch, stacked_num * max_lora_dim, input_dim)
|
||||||
@@ -189,31 +194,50 @@ class LoRAMemoryPool:
|
|||||||
lora_refs: Dict[str, LoRARef],
|
lora_refs: Dict[str, LoRARef],
|
||||||
):
|
):
|
||||||
def get_available_buffer_slot():
|
def get_available_buffer_slot():
|
||||||
|
# 1. Prioritize empty slots
|
||||||
for buffer_id in range(self.max_loras_per_batch):
|
for buffer_id in range(self.max_loras_per_batch):
|
||||||
# Prioritize empty slots
|
|
||||||
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
|
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
|
||||||
return buffer_id
|
return buffer_id
|
||||||
|
|
||||||
|
# 2. Memory pool is full, need to evict using policy
|
||||||
|
candidates = set()
|
||||||
|
|
||||||
for buffer_id in range(self.max_loras_per_batch):
|
for buffer_id in range(self.max_loras_per_batch):
|
||||||
uid = self.buffer_id_to_uid[buffer_id]
|
uid = self.buffer_id_to_uid[buffer_id]
|
||||||
|
|
||||||
# Evict unneeded lora
|
# Skip if this adapter is needed by current batch
|
||||||
if uid not in cur_uids:
|
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
|
||||||
# Skip pinned LoRAs
|
if uid in cur_uids:
|
||||||
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
|
continue
|
||||||
if uid is not None:
|
|
||||||
lora_ref = lora_refs.get(uid)
|
|
||||||
if lora_ref is not None and lora_ref.pinned:
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.uid_to_buffer_id.pop(uid)
|
# Skip if this adapter is pinned (base model cannot be pinned, so can be evicted)
|
||||||
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
|
if uid is not None:
|
||||||
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
|
lora_ref = lora_refs.get(uid)
|
||||||
return buffer_id
|
if lora_ref and lora_ref.pinned:
|
||||||
|
continue
|
||||||
|
candidates.add(uid)
|
||||||
|
|
||||||
raise ValueError(
|
if not candidates:
|
||||||
"No available buffer slots found. Please ensure the number of active loras is less than max_loras_per_batch."
|
raise ValueError(
|
||||||
|
"No available buffer slots found. Please ensure the number of active (pinned) loras is less than max_loras_per_batch."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Select victim using eviction policy
|
||||||
|
victim_uid = self.eviction_policy.select_victim(candidates)
|
||||||
|
|
||||||
|
# Evict the selected victim
|
||||||
|
victim_buffer_id = self.uid_to_buffer_id[victim_uid]
|
||||||
|
self.uid_to_buffer_id.pop(victim_uid)
|
||||||
|
self.eviction_policy.remove(victim_uid)
|
||||||
|
self.buffer_id_to_uid[victim_buffer_id] = EMPTY_SLOT
|
||||||
|
logger.debug(
|
||||||
|
f"Evicting LoRA {victim_uid} from buffer slot {victim_buffer_id}."
|
||||||
)
|
)
|
||||||
|
return victim_buffer_id
|
||||||
|
|
||||||
|
# Mark all adapters in current batch as used (for LRU tracking)
|
||||||
|
for uid in cur_uids:
|
||||||
|
self.eviction_policy.mark_used(uid)
|
||||||
|
|
||||||
for uid in cur_uids:
|
for uid in cur_uids:
|
||||||
if uid not in self.uid_to_buffer_id:
|
if uid not in self.uid_to_buffer_id:
|
||||||
|
|||||||
@@ -122,6 +122,8 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
|
|||||||
|
|
||||||
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
|
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
|
||||||
|
|
||||||
|
DEFAULT_LORA_EVICTION_POLICY = "lru"
|
||||||
|
|
||||||
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
|
NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
|
||||||
|
|
||||||
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
|
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
|
||||||
@@ -304,6 +306,7 @@ class ServerArgs:
|
|||||||
] = None
|
] = None
|
||||||
max_loaded_loras: Optional[int] = None
|
max_loaded_loras: Optional[int] = None
|
||||||
max_loras_per_batch: int = 8
|
max_loras_per_batch: int = 8
|
||||||
|
lora_eviction_policy: str = DEFAULT_LORA_EVICTION_POLICY
|
||||||
lora_backend: str = "triton"
|
lora_backend: str = "triton"
|
||||||
max_lora_chunk_size: Optional[int] = 16
|
max_lora_chunk_size: Optional[int] = 16
|
||||||
|
|
||||||
@@ -2127,6 +2130,13 @@ class ServerArgs:
|
|||||||
default=ServerArgs.max_loaded_loras,
|
default=ServerArgs.max_loaded_loras,
|
||||||
help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
|
help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-eviction-policy",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_LORA_EVICTION_POLICY,
|
||||||
|
choices=["lru", "fifo"],
|
||||||
|
help="LoRA adapter eviction policy when memory pool is full. 'lru': Least Recently Used (default, better cache efficiency). 'fifo': First-In-First-Out.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-backend",
|
"--lora-backend",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -519,6 +519,7 @@ class SRTRunner:
|
|||||||
lora_target_modules: Optional[List[str]] = None,
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
enable_lora: Optional[bool] = None,
|
enable_lora: Optional[bool] = None,
|
||||||
max_loaded_loras: Optional[int] = None,
|
max_loaded_loras: Optional[int] = None,
|
||||||
|
lora_eviction_policy: str = "lru",
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.is_generation = model_type == "generation"
|
self.is_generation = model_type == "generation"
|
||||||
@@ -565,6 +566,7 @@ class SRTRunner:
|
|||||||
lora_target_modules=lora_target_modules,
|
lora_target_modules=lora_target_modules,
|
||||||
enable_lora=enable_lora,
|
enable_lora=enable_lora,
|
||||||
max_loaded_loras=max_loaded_loras,
|
max_loaded_loras=max_loaded_loras,
|
||||||
|
lora_eviction_policy=lora_eviction_policy,
|
||||||
**spec_kwargs,
|
**spec_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
190
test/srt/lora/test_lora_eviction_policy.py
Normal file
190
test/srt/lora/test_lora_eviction_policy.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
Unit tests for LoRA eviction policies.
|
||||||
|
Tests LRU and FIFO eviction behavior.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from sglang.srt.lora.eviction_policy import get_eviction_policy
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoRAEvictionPolicy(unittest.TestCase):
    """Checks victim selection of the LRU and FIFO LoRA eviction policies."""

    def _test_eviction_policy(
        self, policy_name, access_sequence, candidates, expected_victim
    ):
        """
        Replay an access pattern on a fresh policy and verify its victim.

        Args:
            policy_name: Name of eviction policy ("lru" or "fifo")
            access_sequence: Adapter IDs, in the order they are marked as used
            candidates: Set of adapter IDs the policy may choose from
            expected_victim: Adapter ID the policy is expected to return
        """
        policy = get_eviction_policy(policy_name)

        # Feed the accesses in order so the policy builds up its history.
        for uid in access_sequence:
            policy.mark_used(uid)

        # Ask the policy for a victim and compare with the expectation.
        chosen = policy.select_victim(candidates)
        failure_message = (
            f"{policy_name.upper()}: Expected {expected_victim}, got {chosen}"
        )
        self.assertEqual(chosen, expected_victim, failure_message)

    def test_lru_basic(self):
        """Without reuse, LRU evicts the adapter that was used first."""
        history = ["lora1", "lora2", "lora3", "lora4"]
        self._test_eviction_policy(
            "lru",
            access_sequence=history,
            candidates=set(history),
            expected_victim="lora1",
        )

    def test_lru_with_reuse(self):
        """Reusing an adapter moves it away from the eviction position."""
        history = ["lora1", "lora2", "lora3", "lora4", "lora1"]
        self._test_eviction_policy(
            "lru",
            access_sequence=history,
            candidates=set(history),
            expected_victim="lora2",
        )

    def test_lru_multiple_reuse(self):
        """Several reuses leave the single non-reused adapter as victim."""
        history = ["lora1", "lora2", "lora3", "lora1", "lora2"]
        self._test_eviction_policy(
            "lru",
            access_sequence=history,
            candidates=set(history),
            expected_victim="lora3",
        )

    def test_lru_with_subset_candidates(self):
        """LRU only considers adapters that are among the candidates."""
        history = ["lora1", "lora2", "lora3", "lora4"]
        self._test_eviction_policy(
            "lru",
            access_sequence=history,
            candidates=set(history[1:]),
            expected_victim="lora2",
        )

    def test_lru_base_model_priority(self):
        """LRU picks the base model (None) whenever it is a candidate."""
        history = ["lora1", "lora2", "lora3"]
        self._test_eviction_policy(
            "lru",
            access_sequence=history,
            candidates={None, *history},
            expected_victim=None,
        )

    def test_fifo_basic(self):
        """FIFO evicts the adapter that was inserted first."""
        history = ["lora1", "lora2", "lora3", "lora4"]
        self._test_eviction_policy(
            "fifo",
            access_sequence=history,
            candidates=set(history),
            expected_victim="lora1",
        )

    def test_fifo_ignores_reuse(self):
        """Reusing adapters does not change the FIFO eviction order."""
        first_pass = ["lora1", "lora2", "lora3", "lora4"]
        # Second pass touches the same adapters in reverse order.
        history = first_pass + first_pass[::-1]
        self._test_eviction_policy(
            "fifo",
            access_sequence=history,
            candidates=set(first_pass),
            expected_victim="lora1",
        )

    def test_fifo_with_subset_candidates(self):
        """FIFO only considers adapters that are among the candidates."""
        history = ["lora1", "lora2", "lora3", "lora4"]
        self._test_eviction_policy(
            "fifo",
            access_sequence=history,
            candidates=set(history[1:]),
            expected_victim="lora2",
        )

    def test_fifo_base_model_priority(self):
        """FIFO picks the base model (None) whenever it is a candidate."""
        history = ["lora1", "lora2", "lora3"]
        self._test_eviction_policy(
            "fifo",
            access_sequence=history,
            candidates={None, *history},
            expected_victim=None,
        )

    def test_policy_remove(self):
        """A removed adapter is no longer selectable as a victim."""
        policy = get_eviction_policy("lru")
        for uid in ("lora1", "lora2", "lora3"):
            policy.mark_used(uid)

        # With lora1 dropped from tracking, lora2 is the oldest tracked entry.
        policy.remove("lora1")
        chosen = policy.select_victim({"lora1", "lora2", "lora3"})
        self.assertEqual(chosen, "lora2")

    def test_eviction_policy_factory(self):
        """The factory builds known policies and rejects unknown names."""
        # Known names produce an object.
        created = [get_eviction_policy(name) for name in ("lru", "fifo")]
        for policy in created:
            self.assertIsNotNone(policy)

        # Unknown names are rejected.
        with self.assertRaises(ValueError):
            get_eviction_policy("invalid_policy")

    def test_lru_vs_fifo_behavior(self):
        """The same access pattern yields different LRU and FIFO victims."""
        access_sequence = ["lora1", "lora2", "lora3", "lora1"]
        candidates = {"lora1", "lora2", "lora3"}

        victims = {}
        for name in ("lru", "fifo"):
            policy = get_eviction_policy(name)
            for uid in access_sequence:
                policy.mark_used(uid)
            victims[name] = policy.select_victim(candidates)

        self.assertNotEqual(victims["lru"], victims["fifo"])
        self.assertEqual(victims["lru"], "lora2")
        self.assertEqual(victims["fifo"], "lora1")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
||||||
@@ -20,6 +20,8 @@ suites = {
|
|||||||
TestFile("hicache/test_hicache_mla.py", 127),
|
TestFile("hicache/test_hicache_mla.py", 127),
|
||||||
TestFile("hicache/test_hicache_storage.py", 127),
|
TestFile("hicache/test_hicache_storage.py", 127),
|
||||||
TestFile("lora/test_lora.py", 200),
|
TestFile("lora/test_lora.py", 200),
|
||||||
|
TestFile("lora/test_lora_eviction.py", 200),
|
||||||
|
TestFile("lora/test_lora_eviction_policy.py", 200),
|
||||||
TestFile("lora/test_lora_backend.py", 99),
|
TestFile("lora/test_lora_backend.py", 99),
|
||||||
TestFile("lora/test_lora_eviction.py", 200),
|
TestFile("lora/test_lora_eviction.py", 200),
|
||||||
TestFile("lora/test_lora_qwen3.py", 97),
|
TestFile("lora/test_lora_qwen3.py", 97),
|
||||||
|
|||||||
Reference in New Issue
Block a user