Support GPU pinning for LoRA (#8697)
This commit is contained in:
@@ -492,12 +492,13 @@ class Engine(EngineBase):
|
||||
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||
)
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: str):
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
|
||||
"""Load a new LoRA adapter without re-launching the engine."""
|
||||
|
||||
obj = LoadLoRAAdapterReqInput(
|
||||
lora_name=lora_name,
|
||||
lora_path=lora_path,
|
||||
pinned=pinned,
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -144,6 +144,7 @@ class LoRAManager:
|
||||
|
||||
# keep metadata for displayed messages
|
||||
self.lora_refs[lora_ref.lora_id] = lora_ref
|
||||
self.num_pinned_loras += int(lora_ref.pinned)
|
||||
except Exception as e:
|
||||
return self.create_lora_update_result(
|
||||
success=False,
|
||||
@@ -157,13 +158,22 @@ class LoRAManager:
|
||||
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
|
||||
"""
|
||||
|
||||
# Check if the LoRA adapter shape is compatible with the current LoRA memory pool configuration.
|
||||
memory_pool = getattr(self, "memory_pool", None)
|
||||
incompatible = memory_pool and not memory_pool.can_support(lora_config)
|
||||
if incompatible:
|
||||
raise ValueError(
|
||||
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
|
||||
"Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
|
||||
"included in `--enable_lora_modules`."
|
||||
f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current "
|
||||
"LoRA memory pool configuration. Please ensure that the LoRA adapter's rank is within the configured "
|
||||
"`--max-lora-rank` and that the target modules are included in `--lora-target-modules`."
|
||||
)
|
||||
|
||||
# Ensure pinned LoRA adapters does not exceed maximal limit or cause starvation.
|
||||
if lora_ref.pinned and self.num_pinned_loras >= self.max_loras_per_batch - 1:
|
||||
raise ValueError(
|
||||
f"Failed to load LoRA adapter {lora_ref.lora_name} as a pinned adapter. It is not allowed to pin all slots "
|
||||
"in the LoRA memory pool to avoid starvation for unpinned adapters and base models. Please increase your "
|
||||
"`--max-loras-per-batch` or load it as unpinned LoRA adapters."
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
|
||||
@@ -172,15 +182,17 @@ class LoRAManager:
|
||||
delete the corresponding LoRA modules.
|
||||
"""
|
||||
|
||||
adapter = self.configs.get(lora_ref.lora_id, None)
|
||||
adapter = self.configs.get(lora_ref.lora_id)
|
||||
lora_ref = self.lora_refs.get(lora_ref.lora_id)
|
||||
assert (
|
||||
adapter is not None
|
||||
adapter is not None and lora_ref is not None
|
||||
), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
|
||||
|
||||
try:
|
||||
del self.configs[lora_ref.lora_id]
|
||||
del self.loras[lora_ref.lora_id]
|
||||
del self.lora_refs[lora_ref.lora_id]
|
||||
self.num_pinned_loras -= int(lora_ref.pinned)
|
||||
except Exception as e:
|
||||
return self.create_lora_update_result(
|
||||
success=False,
|
||||
@@ -189,11 +201,49 @@ class LoRAManager:
|
||||
|
||||
return self.create_lora_update_result(success=True)
|
||||
|
||||
def validate_lora_batch(self, lora_ids: set[str]) -> bool:
|
||||
"""
|
||||
Validate if the LoRA IDs in the batch can be loaded into the current LoRA memory pool.
|
||||
"""
|
||||
if len(lora_ids) > self.max_loras_per_batch:
|
||||
return False
|
||||
|
||||
# skip pinned LoRA check if no pinned LoRA adapters are loaded.
|
||||
if self.num_pinned_loras == 0:
|
||||
return True
|
||||
|
||||
# counting the number of pinned LoRA adapters in the batch.
|
||||
pinned_loras_in_batch = 0
|
||||
for lora_id in lora_ids:
|
||||
if lora_id is not None:
|
||||
lora_ref = self.lora_refs.get(lora_id)
|
||||
assert (
|
||||
lora_ref is not None
|
||||
), f"LoRA ID {lora_id} not found in lora_refs."
|
||||
pinned_loras_in_batch += int(lora_ref.pinned)
|
||||
|
||||
assert pinned_loras_in_batch <= self.num_pinned_loras, (
|
||||
f"Number of pinned LoRA adapters in the batch ({pinned_loras_in_batch}) exceeds the total number of pinned adapters "
|
||||
f"({self.num_pinned_loras}). This indicates a bug in the LoRA loading logic."
|
||||
)
|
||||
|
||||
required_slots = len(lora_ids) - pinned_loras_in_batch
|
||||
mem_pool_vacancy = self.memory_pool.max_loras_per_batch - self.num_pinned_loras
|
||||
|
||||
return required_slots <= mem_pool_vacancy
|
||||
|
||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||
|
||||
# Load active loras into lora memory pool
|
||||
cur_uids = set(forward_batch.lora_ids)
|
||||
|
||||
assert len(cur_uids) <= self.max_loras_per_batch
|
||||
self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
|
||||
self.memory_pool.prepare_lora_batch(
|
||||
cur_uids=cur_uids,
|
||||
lora_adapters=self.loras,
|
||||
lora_modules=self.lora_modules,
|
||||
lora_refs=self.lora_refs.copy(), # copy snapshot of current lora_refs to avoid mutation during the batch preparation.
|
||||
)
|
||||
|
||||
# set up batch info shared by all lora modules
|
||||
bs = forward_batch.batch_size
|
||||
@@ -366,6 +416,9 @@ class LoRAManager:
|
||||
# Mapping from LoRA ID to LoRARef object.
|
||||
self.lora_refs: Dict[str, LoRARef] = {}
|
||||
|
||||
# Count of pinned LoRA adapters.
|
||||
self.num_pinned_loras: int = 0
|
||||
|
||||
if lora_paths:
|
||||
for lora_ref in lora_paths.values():
|
||||
result = self.load_lora_adapter(lora_ref)
|
||||
@@ -399,7 +452,7 @@ class LoRAManager:
|
||||
self.max_lora_rank = max_lora_rank
|
||||
else:
|
||||
self.max_lora_rank = max(
|
||||
[x.hf_config["r"] for x in self.configs.values()],
|
||||
[x.r for x in self.configs.values()],
|
||||
default=0,
|
||||
)
|
||||
|
||||
|
||||
@@ -28,14 +28,15 @@ class LoRARef:
|
||||
"""
|
||||
Reference record for a LoRA model.
|
||||
|
||||
This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
|
||||
eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
|
||||
This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
|
||||
The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
|
||||
keys (e.g., radix cache).
|
||||
"""
|
||||
|
||||
lora_id: str = field(default_factory=lambda: uuid4().hex)
|
||||
lora_name: Optional[str] = None
|
||||
lora_path: Optional[str] = None
|
||||
pinned: Optional[bool] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lora_id is None:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
||||
from sglang.srt.lora.lora import LoRAAdapter
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.lora.lora_registry import LoRARef
|
||||
from sglang.srt.lora.utils import (
|
||||
ROW_PARALLELISM_LINEAR_LORA_NAMES,
|
||||
LoRAType,
|
||||
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
|
||||
get_weight_name,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmptySlot:
|
||||
"""
|
||||
Singleton class to represent an empty slot in the memory pool.
|
||||
This is used to improve readability by not using special str as a placeholder.
|
||||
"""
|
||||
|
||||
__slots__ = ()
|
||||
|
||||
def __repr__(self):
|
||||
return "|EMPTY|"
|
||||
|
||||
def __new__(cls):
|
||||
if not hasattr(cls, "_instance"):
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
|
||||
EMPTY_SLOT = EmptySlot()
|
||||
|
||||
|
||||
class LoRAMemoryPool:
|
||||
"""Class for memory pool management of lora modules"""
|
||||
@@ -54,9 +78,11 @@ class LoRAMemoryPool:
|
||||
self.uid_to_buffer_id: Dict[Optional[str], int] = {}
|
||||
|
||||
# Buffer idx -> lora uid in memory pool
|
||||
# All uids are initialized as empty strings for empty buffer slots
|
||||
# All uids are initialized as `EmptySlot` for empty buffer slots
|
||||
# Here we don't initialize to None since None is a valid uid
|
||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||
self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
|
||||
EMPTY_SLOT
|
||||
] * self.max_loras_per_batch
|
||||
|
||||
self.init_buffers(base_model)
|
||||
|
||||
@@ -154,17 +180,29 @@ class LoRAMemoryPool:
|
||||
cur_uids: Set[Optional[str]],
|
||||
lora_adapters: Dict[str, LoRAAdapter],
|
||||
lora_modules: List[Dict[str, BaseLayerWithLoRA]],
|
||||
lora_refs: Dict[str, LoRARef],
|
||||
):
|
||||
def get_available_buffer_slot():
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
# Prioritize empty slots
|
||||
if self.buffer_id_to_uid[buffer_id] == "":
|
||||
if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
|
||||
return buffer_id
|
||||
|
||||
for buffer_id in range(self.max_loras_per_batch):
|
||||
uid = self.buffer_id_to_uid[buffer_id]
|
||||
|
||||
# Evict unneeded lora
|
||||
if self.buffer_id_to_uid[buffer_id] not in cur_uids:
|
||||
self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
|
||||
if uid not in cur_uids:
|
||||
# Skip pinned LoRAs
|
||||
# TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
|
||||
if uid is not None:
|
||||
lora_ref = lora_refs.get(uid)
|
||||
if lora_ref is not None and lora_ref.pinned:
|
||||
continue
|
||||
|
||||
self.uid_to_buffer_id.pop(uid)
|
||||
logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
|
||||
self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
|
||||
return buffer_id
|
||||
|
||||
raise ValueError(
|
||||
|
||||
@@ -1082,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
|
||||
lora_name: str
|
||||
# The path of loading.
|
||||
lora_path: str
|
||||
# Whether to pin the LoRA adapter in memory.
|
||||
pinned: bool = False
|
||||
# The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
|
||||
lora_id: Optional[str] = None
|
||||
|
||||
@@ -1090,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
|
||||
lora_id=self.lora_id,
|
||||
lora_name=self.lora_name,
|
||||
lora_path=self.lora_path,
|
||||
pinned=self.pinned,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1538,14 +1538,11 @@ class Scheduler(
|
||||
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
if (
|
||||
self.enable_lora
|
||||
and len(
|
||||
lora_set
|
||||
| set([req.lora_id for req in adder.can_run_list])
|
||||
| set([req.lora_id])
|
||||
)
|
||||
> self.max_loras_per_batch
|
||||
|
||||
if self.enable_lora and not self.tp_worker.can_run_lora_batch(
|
||||
lora_set
|
||||
| set([req.lora_id for req in adder.can_run_list])
|
||||
| set([req.lora_id])
|
||||
):
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
@@ -1129,6 +1129,7 @@ class TokenizerManager:
|
||||
new_adapter = LoRARef(
|
||||
lora_name=obj.lora_name,
|
||||
lora_path=obj.lora_path,
|
||||
pinned=obj.pinned,
|
||||
)
|
||||
|
||||
# Trigger the actual loading operation at the backend processes.
|
||||
@@ -1186,7 +1187,7 @@ class TokenizerManager:
|
||||
|
||||
return result
|
||||
except ValueError as e:
|
||||
return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
|
||||
return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
|
||||
@@ -311,3 +311,6 @@ class TpModelWorker:
|
||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||
result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
|
||||
return result
|
||||
|
||||
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
||||
return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
|
||||
|
||||
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
|
||||
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
|
||||
return self.worker.unload_lora_adapter(recv_req)
|
||||
|
||||
def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
|
||||
return self.worker.can_run_lora_batch(lora_ids)
|
||||
|
||||
def __delete__(self):
|
||||
self.input_queue.put((None, None))
|
||||
self.copy_queue.put((None, None, None))
|
||||
|
||||
@@ -2067,21 +2067,23 @@ class ServerArgs:
|
||||
|
||||
if self.enable_lora:
|
||||
# Normalize lora_paths to a dictionary if it is a list.
|
||||
# TODO (lifuhuang): support specifying pinned adapters in server_args.
|
||||
if isinstance(self.lora_paths, list):
|
||||
lora_paths = self.lora_paths
|
||||
self.lora_paths = {}
|
||||
for lora_path in lora_paths:
|
||||
if "=" in lora_path:
|
||||
name, path = lora_path.split("=", 1)
|
||||
self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
|
||||
self.lora_paths[name] = LoRARef(
|
||||
lora_name=name, lora_path=path, pinned=False
|
||||
)
|
||||
else:
|
||||
self.lora_paths[lora_path] = LoRARef(
|
||||
lora_name=lora_path,
|
||||
lora_path=lora_path,
|
||||
lora_name=lora_path, lora_path=lora_path, pinned=False
|
||||
)
|
||||
elif isinstance(self.lora_paths, dict):
|
||||
self.lora_paths = {
|
||||
k: LoRARef(lora_name=k, lora_path=v)
|
||||
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
|
||||
for k, v in self.lora_paths.items()
|
||||
}
|
||||
elif self.lora_paths is None:
|
||||
|
||||
@@ -568,8 +568,8 @@ class SRTRunner:
|
||||
else:
|
||||
self.tokenizer = None
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: str):
|
||||
return self.engine.load_lora_adapter(lora_name, lora_path)
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
|
||||
return self.engine.load_lora_adapter(lora_name, lora_path, pinned)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
return self.engine.unload_lora_adapter(lora_name)
|
||||
|
||||
Reference in New Issue
Block a user