Support dynamic LoRA loading / unloading in engine/server API (#7446)

Lifu Huang authored on 2025-06-27 21:00:27 -07:00, committed by GitHub
parent cfe2edac38
commit 49538d111b
14 changed files with 949 additions and 31 deletions
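
For orientation, here is a minimal client-side sketch of how the new dynamic loading/unloading might be exercised once a server is running. The /load_lora_adapter and /unload_lora_adapter route names, the port, and the payload fields are assumptions inferred from the LoRAManager methods in this diff, not confirmed by it:

    # Hypothetical usage sketch; endpoint names and payload shapes are assumed.
    import requests

    BASE_URL = "http://localhost:30000"  # assumed server address

    # Load an adapter at runtime.
    resp = requests.post(
        f"{BASE_URL}/load_lora_adapter",
        json={"lora_name": "my-adapter", "lora_path": "/path/to/adapter"},
    )
    print(resp.json())  # expected: success flag, error message, loaded adapters

    # Unload it again.
    resp = requests.post(
        f"{BASE_URL}/unload_lora_adapter",
        json={"lora_name": "my-adapter"},
    )
    print(resp.json())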

python/sglang/srt/lora/lora.py

@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
             else:
                 self.weights[name] = loaded_weight.cpu()
 
-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
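
Both hunks in this file swap index-based loops for direct iteration. As a standalone illustration (not part of this commit), nn.ModuleList is itself iterable, so the index bookkeeping the old code carried is unnecessary:

    # Standalone illustration: nn.ModuleList supports direct iteration,
    # so `for layer in self.layers` replaces manual index bookkeeping.
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

    for i in range(len(layers)):  # index-based (the old pattern)
        print(layers[i])

    for layer in layers:          # direct iteration (the new pattern)
        print(layer)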

python/sglang/srt/lora/lora_manager.py

@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )
 
-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                 If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
 
+        results = []
         for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
-
-            self.configs[lora_name] = LoRAConfig(lora_path)
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
 
         self.update_state_from_configs()
 
-    def unload_lora_adapters(self, lora_names: Set[str]):
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
+            self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
+
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
-        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
+        Unload a LoRA adapter by its name. This will remove the adapter from the memory pool and
         delete the corresponding LoRA modules.
 
         Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
+            lora_name (str): The name of the LoRA adapter to unload.
         """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
 
         self.update_state_from_configs()
 
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
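
The result object above is imported from sglang.srt.managers.io_struct; this diff only shows how it is constructed, not its definition. A standalone approximation, with field names mirrored from those constructor calls, shows the batch-aggregation contract used by load_lora_adapters:

    # Standalone approximation of the result contract; the real
    # LoRAUpdateResult lives in sglang.srt.managers.io_struct.
    from dataclasses import dataclass, field
    from typing import Dict

    @dataclass
    class LoRAUpdateResult:
        success: bool
        error_message: str = ""
        loaded_adapters: Dict[str, str] = field(default_factory=dict)

    # Batch loading aggregates per-adapter results as in the diff:
    results = [
        LoRAUpdateResult(success=True),
        LoRAUpdateResult(success=False, error_message="LoRA adapter x is not loaded."),
    ]
    batch = LoRAUpdateResult(
        success=all(r.success for r in results),
        error_message="\n".join(r.error_message for r in results if not r.success),
    )
    print(batch.success)        # False
    print(batch.error_message)  # LoRA adapter x is not loaded.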
@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
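
The list() snapshot in the last hunk matters because CPython forbids resizing a dict while iterating over it. A standalone demonstration of the failure mode and the fix:

    # Deleting from a dict during iteration raises RuntimeError;
    # iterating over a copied key list is safe.
    loras = {"a": 1, "b": 2}
    configs = {"a": 1}

    try:
        for name in loras:          # iterating the dict directly
            if name not in configs:
                del loras[name]     # mutates the dict mid-iteration
    except RuntimeError as e:
        print(f"direct iteration fails: {e}")

    loras = {"a": 1, "b": 2}
    for name in list(loras):        # snapshot of the keys
        if name not in configs:
            del loras[name]         # safe: the dict itself is not being iterated
    print(loras)                    # {'a': 1}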