Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, Iterable, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple

 import torch

@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
@@ -55,6 +56,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -64,10 +66,6 @@ class LoRAManager:
         self.device: torch.device = next(self.base_model.parameters()).device
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.max_lora_rank: Optional[int] = max_lora_rank
-        self.target_modules: Optional[Set[str]] = (
-            set(target_modules) if target_modules else None
-        )

         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -75,7 +73,11 @@ class LoRAManager:
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

         # Initialize mutable internal state of the LoRAManager.
-        self.init_state()
+        self.init_state(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+            lora_paths=lora_paths,
+        )

     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -112,108 +114,87 @@ class LoRAManager:
             success=success,
             error_message=error_message,
             loaded_adapters={
-                name: config.path for name, config in self.configs.items()
+                lora_ref.lora_name: lora_ref.lora_path
+                for lora_ref in self.lora_refs.values()
             },
         )

-    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
-        """
-        Load LoRA adapters from the specified paths.
-
-        Args:
-            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
-                If a LoRA adapter is already loaded, it will be skipped with a warning.
-        """
-
-        results = []
-        for lora_name, lora_path in lora_paths.items():
-            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
-            results.append(result)
-
-        self.update_state_from_configs()
-
-        return self.create_lora_update_result(
-            success=all(result.success for result in results),
-            error_message="\n".join(
-                result.error_message for result in results if not result.success
-            ),
-        )
-
-    def load_lora_adapter(
-        self, lora_name: str, lora_path: str, update_state: bool = True
-    ) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Load a single LoRA adapter from the specified path.

         Args:
-            lora_name (str): The name of the LoRA adapter.
-            lora_path (str): The file path to the LoRA adapter.
-            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
         """
-        success = True
-        error_message = ""
-
-        if lora_name in self.loras:
-            success = False
-            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+        assert (
+            lora_ref.lora_name is not None and lora_ref.lora_path is not None
+        ), "LoRARef must have both lora_name and lora_path set for loading."
+        assert (
+            lora_ref.lora_id not in self.loras
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

         try:
-            new_adapter = LoRAConfig(lora_path)
-            self.validate_new_adapter(lora_name, new_adapter)
-            self.configs[lora_name] = new_adapter
+            # load configs
+            new_adapter = LoRAConfig(lora_ref.lora_path)
+            self.validate_new_adapter(new_adapter, lora_ref)
+            self.configs[lora_ref.lora_id] = new_adapter
+
+            # load weights
+            self.load_lora_weights(lora_ref)
+
+            # keep metadata for displayed messages
+            self.lora_refs[lora_ref.lora_id] = lora_ref
         except Exception as e:
-            success = False
-            error_message = (
-                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
             )

-        if update_state:
-            self.update_state_from_configs()
-
-        return self.create_lora_update_result(
-            success=success,
-            error_message=error_message,
-        )
+        return self.create_lora_update_result(success=True)

-    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
         """
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """

-        incompatible = self.memory_pool and not self.memory_pool.can_support(
-            lora_config
-        )
+        memory_pool = getattr(self, "memory_pool", None)
+        incompatible = memory_pool and not memory_pool.can_support(lora_config)
         if incompatible:
             raise ValueError(
-                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
                 "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
                 "included in `--enable_lora_modules`."
             )

-    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
         """

-        success = True
-        error_message = ""
-        if lora_name in self.loras:
-            del self.configs[lora_name]
-        else:
-            error_message = f"LoRA adapter {lora_name} is not loaded."
-            success = False
+        adapter = self.configs.get(lora_ref.lora_id, None)
+        assert (
+            adapter is not None
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."

-        self.update_state_from_configs()
+        try:
+            del self.configs[lora_ref.lora_id]
+            del self.loras[lora_ref.lora_id]
+            del self.lora_refs[lora_ref.lora_id]
+        except Exception as e:
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
+            )

-        return self.create_lora_update_result(
-            success=success,
-            error_message=error_message,
-        )
+        return self.create_lora_update_result(success=True)

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        # load active loras into lora memory pool
+        # Load active loras into lora memory pool
+        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
+        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
+        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
+        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
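To make the new contract concrete, here is a minimal sketch of how a caller might drive the ID-keyed `load_lora_adapter` / `unload_lora_adapter` API introduced above. The adapter name and path are hypothetical, and construction of the `LoRAManager` itself is omitted:

```python
from sglang.srt.lora.lora_registry import LoRARef


def load_then_unload(manager) -> None:
    """Drive the ID-keyed API of an already-constructed LoRAManager (sketch)."""
    # The adapter name/path below are placeholders; lora_id is auto-generated.
    ref = LoRARef(lora_name="sql-adapter", lora_path="/loras/sql-adapter")

    result = manager.load_lora_adapter(ref)
    assert result.success, result.error_message

    # Unloading takes the same LoRARef; internal dicts are keyed by ref.lora_id.
    result = manager.unload_lora_adapter(ref)
    assert result.success
```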
@@ -233,10 +214,10 @@ class LoRAManager:
         weight_indices = [0] * len(forward_batch.lora_paths)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i, lora_path in enumerate(forward_batch.lora_paths):
-            weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-            if lora_path is not None:
-                lora = self.loras[lora_path]
+        for i, uid in enumerate(forward_batch.lora_paths):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
                 lora_ranks[weight_indices[i]] = lora.config.r
                 scalings[weight_indices[i]] = lora.scaling

@@ -326,7 +307,7 @@ class LoRAManager:
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
-        for layer_id, layer_modules in self.lora_modules.items():
+        for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
@@ -353,115 +334,94 @@ class LoRAManager:
                 ),
             )

-    def init_state(self):
+    def init_state(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
+    ):
         """
         Initialize the internal (mutable) state of the LoRAManager.

-        These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+        When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
+        the target modules and max_lora_rank.
         """

-        # Configs of all active LoRA adapters.
+        assert lora_paths or (
+            max_lora_rank is not None and target_modules is not None
+        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
+
+        self.init_lora_adapters(lora_paths)
+        self.init_lora_shapes(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+        )
+        self.init_lora_weight_names()
+        self.init_lora_modules()
+        self.init_memory_pool()
+
+    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        # Configs of all active LoRA adapters, indexed by LoRA ID.
         self.configs: Dict[str, LoRAConfig] = {}

-        # LoRA adapter weights cached in CPU memory.
+        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
         self.loras: Dict[str, LoRAAdapter] = {}

-        # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
-        self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+        # Mapping from LoRA ID to LoRARef object.
+        self.lora_refs: Dict[str, LoRARef] = {}

-        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
-        self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
-            i: {} for i in range(self.base_hf_config.num_hidden_layers)
-        }
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                result = self.load_lora_adapter(lora_ref)
+                if not result.success:
+                    raise RuntimeError(
+                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
+                    )

-        # The LoRA memory pool that manages the GPU buffers for active LoRA weights.
-        # It is initialized lazily when the first LoRA adapter is loaded.
-        self.memory_pool: Optional[LoRAMemoryPool] = None
+    def init_lora_shapes(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+    ):
+        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

-    def update_state_from_configs(self):
-        """
-        Update the internal state of the LoRAManager based on the current `self.configs`. This method
-        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-        """
-
-        # Loads / unloads LoRA adapters based on the latest configs.
-        self.update_lora_adapters()
-        # Apply the latest LoRA configurations to the internal state for inferencing.
-        self.apply_lora_configs()
-
-    def apply_lora_configs(self):
-        """
-        Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
-
-        Notes:
-            - Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
-              we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
-              LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
-              early CY25H2.
-        """
-
-        if self.memory_pool is None:
-            # Infer max_lora_rank and target_modules if not explicitly specified in server args.
-            if self.target_modules is None:
-                self.target_modules = set()
-                for config in self.configs.values():
-                    self.target_modules.update(config.target_modules)
-
-            if self.max_lora_rank is None:
-                self.max_lora_rank = max(
-                    [x.hf_config["r"] for x in self.configs.values()],
-                    default=0,
-                )
-
-            self.update_lora_weight_names()
-            self.update_lora_modules()
-            self.update_memory_buffers()
+        if target_modules is not None:
+            self.target_modules = set(target_modules)
         else:
-            # No-op if the memory pool can support the current LoRA configurations.
-            # TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
-            # module is changed once FlashInfer backend is deprecated.
-            assert self.memory_pool.can_support(self.configs.values()), (
-                "LoRA memory pool cannot support the current LoRA configuration. "
-                "This should never happen as we should have validated adapter compatibility. "
-                "Please create a Github issue to report.",
+            self.target_modules = set()
+            for config in self.configs.values():
+                self.target_modules.update(config.target_modules)
+
+        if max_lora_rank is not None:
+            self.max_lora_rank = max_lora_rank
+        else:
+            self.max_lora_rank = max(
+                [x.hf_config["r"] for x in self.configs.values()],
+                default=0,
             )

-    def update_lora_weight_names(self):
+    def init_lora_weight_names(self):
         """
         Add new LoRA weight names if needed based on the current `self.configs`.
         """

         # Target lora weight names for lora_a and lora_b modules respectively.
         lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
-        self.lora_weight_names[0].update(lora_A)
-        self.lora_weight_names[1].update(lora_B)
+        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))

-    def update_lora_adapters(self):
+    def load_lora_weights(self, lora_ref: LoRARef):
         """
-        Update the LoRA adapters in CPU memory based on the current `self.configs`.
-        It loads any new adapters that are not already loaded, and unloads any adapters
-        that are no longer in `self.configs` (e.g., unloaded).
+        Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
         """
-
-        # Load new adapter weights to cpu
-        for name, config in self.configs.items():
-            if name not in self.loras:
-                logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
-                lora_adapter = LoRAAdapter(
-                    name,
-                    config,
-                    self.base_hf_config,
-                    self.load_config,
-                    self.lora_backend,
-                )
-                lora_adapter.initialize_weights()
-                self.loras[name] = lora_adapter
-
-        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
-        for name in list(self.loras):
-            if name not in self.configs:
-                logger.info(f"Unloading LoRA adapter {name}")
-                del self.loras[name]
+        lora_adapter = LoRAAdapter(
+            lora_ref.lora_id,
+            self.configs[lora_ref.lora_id],
+            self.base_hf_config,
+            self.load_config,
+            self.lora_backend,
+        )
+        lora_adapter.initialize_weights()
+        self.loras[lora_ref.lora_id] = lora_adapter

         # Additional checks for flashinfer backend
         # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
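The shape-inference rule in `init_lora_shapes` (union of the adapters' target modules, maximum adapter rank, defaulting to 0 when no adapters are loaded) can be illustrated in isolation. The sketch below uses plain dicts in place of real `LoRAConfig` objects, so the key names are only illustrative:

```python
# Standalone illustration of the inference used when --max-lora-rank /
# --lora-target-modules are not given; adapter configs are hypothetical dicts.
def infer_lora_shapes(adapter_configs):
    target_modules = set()
    for cfg in adapter_configs:
        target_modules.update(cfg["target_modules"])
    max_lora_rank = max((cfg["r"] for cfg in adapter_configs), default=0)
    return target_modules, max_lora_rank


mods, rank = infer_lora_shapes(
    [
        {"r": 8, "target_modules": ["q_proj", "v_proj"]},
        {"r": 16, "target_modules": ["q_proj", "o_proj"]},
    ]
)
# mods == {"q_proj", "v_proj", "o_proj"}, rank == 16
```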
@@ -472,7 +432,7 @@ class LoRAManager:
             len(lora_dims) == 1 and len(scalings) == 1
         ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-    def update_memory_buffers(self):
+    def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""
         self.memory_pool = LoRAMemoryPool(
             base_hf_config=self.base_hf_config,
@@ -490,7 +450,12 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def update_lora_modules(self):
+    def init_lora_modules(self):
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
+            {} for _ in range(self.base_hf_config.num_hidden_layers)
+        ]
+
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
@@ -511,7 +476,6 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-                if module_name not in self.lora_modules[layer_id]:
-                    self.lora_modules[layer_id][module_name] = self.set_lora_module(
-                        module_name, module
-                    )
+                self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                    module_name, module
+                )
python/sglang/srt/lora/lora_registry.py (new file, 124 lines)
@@ -0,0 +1,124 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import asyncio
+from dataclasses import dataclass, field, fields
+from typing import Dict, List, Optional, Union
+from uuid import uuid4
+
+
+@dataclass(frozen=True, slots=True)
+class LoRARef:
+    """
+    Reference record for a LoRA model.
+
+    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
+    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    keys (e.g., radix cache).
+    """
+
+    lora_id: str = field(default_factory=lambda: uuid4().hex)
+    lora_name: Optional[str] = None
+    lora_path: Optional[str] = None
+
+    def __post_init__(self):
+        if self.lora_id is None:
+            raise ValueError("lora_id cannot be None")
+
+    def __str__(self) -> str:
+        parts = [
+            f"{f.name}={value}"
+            for f in fields(self)
+            if (value := getattr(self, f.name)) is not None
+        ]
+        return f"{self.__class__.__name__}({', '.join(parts)})"
+
+
+class LoRARegistry:
+    """
+    The central registry to keep track of available LoRA adapters.
+
+    TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
+    to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
+    """
+
+    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        assert lora_paths is None or all(
+            isinstance(lora, LoRARef) for lora in lora_paths.values()
+        ), (
+            "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
+            "Please file an issue if you see this error."
+        )
+
+        # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
+        self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
+
+    async def register(self, lora_ref: LoRARef):
+        """
+        Register a new LoRARef object in the registry.
+
+        Args:
+            lora_ref (LoRARef): The LoRARef object to register.
+        """
+        if lora_ref.lora_name in self._registry:
+            raise ValueError(
+                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
+            )
+        self._registry[lora_ref.lora_name] = lora_ref
+
+    async def unregister(self, lora_name: str) -> str:
+        """
+        Unregister a LoRARef object from the registry and returns the removed LoRA ID.
+
+        Args:
+            lora_name (str): The name of the LoRA model to unregister.
+        """
+        lora_ref = self._registry.get(lora_name, None)
+        if lora_ref is None:
+            raise ValueError(
+                f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+            )
+        del self._registry[lora_name]
+
+        return lora_ref.lora_id
+
+    async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
+        """
+        Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
+        by incrementing its counter.
+
+        TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
+        """
+
+        async def _acquire_single(name: str) -> str:
+            lora_ref = self._registry.get(name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"The following requested LoRA adapters are not loaded: {name}\n"
+                    f"Loaded adapters: {self._registry.keys()}."
+                )
+            # await self._counters[lora_ref.lora_id].increment()
+            return lora_ref.lora_id
+
+        if isinstance(lora_name, str):
+            lora_id = await _acquire_single(lora_name)
+            return lora_id
+        elif isinstance(lora_name, list):
+            lora_ids = await asyncio.gather(
+                *[_acquire_single(name) for name in lora_name]
+            )
+            return lora_ids
+        else:
+            raise TypeError("lora_name must be either a string or a list of strings.")
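A short usage sketch of the new registry; the adapter name and path shown are placeholders:

```python
import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def demo():
    registry = LoRARegistry()

    # Register an adapter under a user-facing name; the ID is generated automatically.
    ref = LoRARef(lora_name="my-adapter", lora_path="/path/to/adapter")
    await registry.register(ref)

    # Requests refer to adapters by name; acquire() resolves names to stable IDs
    # (a single name or a list of names is accepted).
    lora_id = await registry.acquire("my-adapter")
    assert lora_id == ref.lora_id

    # Unregister returns the ID so the backend can evict the right adapter.
    removed_id = await registry.unregister("my-adapter")
    assert removed_id == lora_id


asyncio.run(demo())
```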
@@ -153,7 +153,7 @@ class LoRAMemoryPool:
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@ class LoRAMemoryPool:
         uid: str,
         buffer_id: int,
         lora_adapter: LoRAAdapter,
-        lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def load_lora_weight_tensor(
             buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
@@ -22,6 +22,7 @@ from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -1067,19 +1068,36 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
     lora_path: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+            lora_path=self.lora_path,
+        )


 @dataclass
 class UnloadLoRAAdapterReqInput:
     # The name of lora module to unload.
     lora_name: str
+    # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
+    lora_id: Optional[str] = None
+
+    def to_ref(self) -> LoRARef:
+        return LoRARef(
+            lora_id=self.lora_id,
+            lora_name=self.lora_name,
+        )


 @dataclass
 class LoRAUpdateResult:
     success: bool
     error_message: Optional[str] = None
-    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+    loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)


 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
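As a rough illustration of how these request structs travel, the `TokenizerManager` mints the `lora_id` and downstream workers rebuild a `LoRARef` from it via `to_ref()`; the concrete values below are hypothetical and the ID-assignment step is simplified:

```python
from sglang.srt.lora.lora_registry import LoRARef
from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

# Hypothetical request arriving at the load-adapter endpoint.
req = LoadLoRAAdapterReqInput(lora_name="my-adapter", lora_path="/loras/my-adapter")

# The TokenizerManager assigns a fresh, stable ID before forwarding the request.
req.lora_id = LoRARef(lora_name=req.lora_name, lora_path=req.lora_path).lora_id

# Scheduler / TP worker side: rebuild the LoRARef that the LoRAManager consumes.
ref = req.to_ref()
assert ref.lora_id == req.lora_id and ref.lora_path == req.lora_path
```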
@@ -247,7 +247,7 @@ class Scheduler(
         self.pp_size = server_args.pp_size
         self.dp_size = server_args.dp_size
         self.schedule_policy = server_args.schedule_policy
-        self.lora_paths = server_args.lora_paths
+        self.enable_lora = server_args.enable_lora
         self.max_loras_per_batch = server_args.max_loras_per_batch
         self.enable_overlap = not server_args.disable_overlap_schedule
         self.skip_tokenizer_init = server_args.skip_tokenizer_init
@@ -1706,13 +1706,13 @@ class Scheduler(
             self.chunked_req.init_next_round_input()
             self.chunked_req = adder.add_chunked_req(self.chunked_req)

-        if self.lora_paths:
+        if self.enable_lora:
             lora_set = set([req.lora_path for req in self.running_batch.reqs])

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
-                self.lora_paths
+                self.enable_lora
                 and len(
                     lora_set
                     | set([req.lora_path for req in adder.can_run_list])
@@ -2466,12 +2466,6 @@ class Scheduler(
         """In-place loading a new lora adapter from disk or huggingface."""

         result = self.tp_worker.load_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after loading lora adapter."
-        else:
-            logger.error(result.error_message)
         return result

     def unload_lora_adapter(
@@ -2480,14 +2474,6 @@ class Scheduler(
         """Unload the lora adapter."""

         result = self.tp_worker.unload_lora_adapter(recv_req)
-
-        if result.success:
-            flush_cache_success = self.flush_cache()
-            assert (
-                flush_cache_success
-            ), "Cache flush failed after unloading LoRA weights"
-        else:
-            logger.error(result.error_message)
         return result

     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
@@ -62,6 +62,7 @@ from sglang.srt.hf_transformers_utils import (
     get_tokenizer,
     get_tokenizer_from_processor,
 )
+from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -242,11 +243,11 @@ class TokenizerManager:
             revision=server_args.revision,
         )

-        # Initialize loaded loRA adapters with the initial lora paths in the server_args.
-        # This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
-        self.loaded_lora_adapters: Dict[str, str] = dict(
-            self.server_args.lora_paths or {}
-        )
+        # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
+        # The registry dynamically updates as adapters are loaded / unloaded during runtime. It
+        # serves as the source of truth for available adapters and maps user-friendly LoRA names
+        # to internally used unique LoRA IDs.
+        self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})

         # Store states
         self.no_create_loop = False
@@ -523,6 +524,10 @@ class TokenizerManager:
         else:
             mm_inputs = None

+        if self.server_args.enable_lora and obj.lora_path:
+            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
+
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -574,8 +579,6 @@ class TokenizerManager:
                 "The server is not configured to enable custom logit processor. "
                 "Please set `--enable-custom-logits-processor` to enable this feature."
             )
-        if self.server_args.enable_lora and obj.lora_path:
-            self._validate_lora_adapters(obj)

     def _validate_input_ids_in_vocab(
         self, input_ids: List[int], vocab_size: int
@@ -689,21 +692,6 @@ class TokenizerManager:
                 "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
             )

-    def _validate_lora_adapters(self, obj: GenerateReqInput):
-        """Validate that the requested LoRA adapters are loaded."""
-        requested_adapters = (
-            set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
-        )
-        loaded_adapters = (
-            self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
-        )
-        unloaded_adapters = requested_adapters - loaded_adapters
-        if unloaded_adapters:
-            raise ValueError(
-                f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
-                f"Loaded adapters: {loaded_adapters}."
-            )
-
     def _send_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -1054,8 +1042,18 @@ class TokenizerManager:
             )

         async with self.model_update_lock.writer_lock:
+            # Generate new uniquely identifiable LoRARef object.
+            new_adapter = LoRARef(
+                lora_name=obj.lora_name,
+                lora_path=obj.lora_path,
+            )
+
+            # Register the new adapter in the registry.
+            obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+            if result.success:
+                await self.lora_registry.register(new_adapter)
+
             return result

     async def unload_lora_adapter(
@@ -1069,6 +1067,10 @@ class TokenizerManager:
                 "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
             )

+        assert (
+            obj.lora_name is not None
+        ), "lora_name must be provided to unload LoRA adapter"
+
         # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
         # with dp_size > 1.
         assert (
@@ -1080,8 +1082,9 @@ class TokenizerManager:
             )

         async with self.model_update_lock.writer_lock:
+            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
             result = (await self.update_lora_adapter_communicator(obj))[0]
-            self.loaded_lora_adapters = result.loaded_adapters
+
             return result

     async def get_weights_by_name(
@@ -1309,7 +1312,7 @@ class TokenizerManager:
             filename = os.path.join(
                 self.crash_dump_folder,
                 os.getenv("HOSTNAME", None),
-                f'crash_dump_{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}.pkl',
+                f"crash_dump_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.pkl",
             )

             os.makedirs(os.path.dirname(filename), exist_ok=True)
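The net effect on the generate path can be summarized with a small hedged sketch (not the actual `TokenizerManager` code): user-facing adapter names are resolved to stable IDs before the request is sent onward, so the scheduler and radix cache only ever see IDs:

```python
async def resolve_lora_names(tokenizer_manager, obj):
    """Sketch: swap user-facing LoRA names for stable LoRA IDs on the request path."""
    if tokenizer_manager.server_args.enable_lora and obj.lora_path:
        # acquire() accepts a single name or a list of names for a batched request.
        obj.lora_path = await tokenizer_manager.lora_registry.acquire(obj.lora_path)
    return obj
```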
@@ -293,11 +293,9 @@ class TpModelWorker:
         return parameter

     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-        result = self.model_runner.load_lora_adapter(
-            recv_req.lora_name, recv_req.lora_path
-        )
+        result = self.model_runner.load_lora_adapter(recv_req.to_ref())
         return result

     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
@@ -68,6 +68,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
@@ -890,44 +891,38 @@ class ModelRunner:
             tp_rank=self.tp_rank,
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
+            lora_paths=self.server_args.lora_paths,
         )
-        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths or {})
-        if result.success:
-            logger.info(
-                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
-            )
-        else:
-            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")

-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_ref: LoRARef):
         """Load a new lora adapter from disk or huggingface."""

         logger.info(
-            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"LoRA adapter loading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+        result = self.lora_manager.load_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"LoRA adapter loading completes: {lora_ref}. "
            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

         return result

-    def unload_lora_adapter(self, lora_name: str):
+    def unload_lora_adapter(self, lora_ref: LoRARef):
         """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""

         logger.info(
-            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"LoRA adapter unloading starts: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

-        result = self.lora_manager.unload_lora_adapter(lora_name)
+        result = self.lora_manager.unload_lora_adapter(lora_ref)

         logger.info(
-            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"LoRA adapter unloading completes: {lora_ref}. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )

@@ -20,10 +20,10 @@ import logging
 import os
 import random
 import tempfile
-from token import OP
 from typing import List, Literal, Optional, Union

 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     LORA_TARGET_ALL_MODULES,
@@ -145,7 +145,7 @@ class ServerArgs:
     enable_lora: Optional[bool] = None
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[Union[set[str], List[str]]] = None
-    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
+    lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"

@@ -1843,9 +1843,24 @@ class ServerArgs:
             for lora_path in lora_paths:
                 if "=" in lora_path:
                     name, path = lora_path.split("=", 1)
-                    self.lora_paths[name] = path
+                    self.lora_paths[name] = LoRARef(lora_name=name, lora_path=path)
                 else:
-                    self.lora_paths[lora_path] = lora_path
+                    self.lora_paths[lora_path] = LoRARef(
+                        lora_name=lora_path,
+                        lora_path=lora_path,
+                    )
+        elif isinstance(self.lora_paths, dict):
+            self.lora_paths = {
+                k: LoRARef(lora_name=k, lora_path=v)
+                for k, v in self.lora_paths.items()
+            }
+        elif self.lora_paths is None:
+            self.lora_paths = {}
+        else:
+            raise ValueError(
+                f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
+                "Expected a list or a dictionary."
+            )

         # Expand target modules
         if self.lora_target_modules:
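For reference, a sketch of what the `--lora-paths` normalization above produces; the adapter names and paths are made up:

```python
from sglang.srt.lora.lora_registry import LoRARef

# Equivalent of passing `--lora-paths chat=/loras/chat /loras/sql` on the CLI.
lora_paths = ["chat=/loras/chat", "/loras/sql"]
normalized = {}
for item in lora_paths:
    if "=" in item:
        name, path = item.split("=", 1)
        normalized[name] = LoRARef(lora_name=name, lora_path=path)
    else:
        normalized[item] = LoRARef(lora_name=item, lora_path=item)

# Each entry now carries a fresh UUID-hex lora_id, e.g. normalized["chat"].lora_id,
# which is what the scheduler and radix cache key on from here onward.
```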
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================

+import contextlib
 import multiprocessing as mp
 import unittest
 from typing import Dict, List, Tuple
@@ -39,6 +40,16 @@ ADAPTERS = [
 BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


+@contextlib.contextmanager
+def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
+    """A context manager to load and automatically unload a LoRA adapter."""
+    try:
+        runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
+        yield
+    finally:
+        runner.unload_lora_adapter(lora_name=lora_name)
+
+
 class TestLoRAEviction(CustomTestCase):
     def test_lora_eviction_with_different_target_modules(self):
         """
@@ -51,55 +62,80 @@ class TestLoRAEviction(CustomTestCase):
         self._run_test(ADAPTERS, output_history, reverse=False)
         self._run_test(ADAPTERS, output_history, reverse=True)

+    def test_lora_eviction_with_reused_lora_name(self):
+        """
+        Test LoRA eviction with reused LoRA names.
+
+        This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
+        works correctly when reusing LoRA names.
+        """
+        output_history = {}
+        self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
+        self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)
+
     def _run_test(
         self,
         lora_paths: List[str],
         output_history: Dict[Tuple[str, str], str],
-        reverse: bool,
+        reverse: bool = False,
         repeat: int = 2,
+        reuse_lora_name: bool = False,
     ):
+        REUSED_LORA_NAME = "lora"
         max_new_tokens = 256
         backend = "triton"
         torch_dtype = torch.float16
         base_path = BASE_MODEL
         assert len(lora_paths) >= 2

+        initial_lora_paths = lora_paths if not reuse_lora_name else None
         # Initialize runners
         with SRTRunner(
             base_path,
             torch_dtype=torch_dtype,
             model_type="generation",
-            lora_paths=lora_paths,
+            lora_paths=initial_lora_paths,
             max_loras_per_batch=1,
             lora_backend=backend,
             disable_radix_cache=True,
+            enable_lora=True,
+            max_lora_rank=256,
+            lora_target_modules=["all"],
         ) as srt_runner:
             adapter_sequence = lora_paths if not reverse else lora_paths[::-1]

             for i in range(repeat):
-                for j, adapter in enumerate(adapter_sequence):
+                for j, lora_path in enumerate(adapter_sequence):
                     print(
-                        f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---"
+                        f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
                     )
-                    for prompt in PROMPTS:
-                        print("\nprompt:\n", prompt)
-                        srt_outputs = srt_runner.forward(
-                            [prompt],
-                            max_new_tokens=max_new_tokens,
-                            lora_paths=[adapter],
-                        )
-                        output = srt_outputs.output_strs[0].strip()
-                        print("\noutput:\n", output)
-
-                        prev_output = output_history.get((adapter, prompt))
-                        if prev_output is not None:
-                            self.assertEqual(
-                                prev_output,
-                                output,
-                                f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
-                            )
-                        else:
-                            output_history[(adapter, prompt)] = output
+
+                    lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
+                    context = (
+                        dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
+                        if reuse_lora_name
+                        else contextlib.nullcontext()
+                    )
+                    with context:
+                        for prompt in PROMPTS:
+                            print("\nprompt:\n", prompt)
+                            srt_outputs = srt_runner.forward(
+                                [prompt],
+                                max_new_tokens=max_new_tokens,
+                                lora_paths=[lora_name],
+                            )
+                            output = srt_outputs.output_strs[0].strip()
+                            print("\noutput:\n", output)
+
+                            prev_output = output_history.get((lora_path, prompt))
+                            if prev_output is not None:
+                                self.assertEqual(
+                                    prev_output,
+                                    output,
+                                    f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
+                                )
+                            else:
+                                output_history[(lora_path, prompt)] = output


 if __name__ == "__main__":
@@ -14,7 +14,7 @@ class TestFile:
 suites = {
     "per-commit": [
         TestFile("models/lora/test_lora.py", 200),
-        TestFile("models/lora/test_lora_eviction.py", 120),
+        TestFile("models/lora/test_lora_eviction.py", 200),
         TestFile("models/lora/test_lora_backend.py", 99),
         TestFile("models/lora/test_multi_lora_backend.py", 60),
         TestFile("models/lora/test_lora_cuda_graph.py", 250),