Support dynamic LoRA loading / unloading in engine/server API (#7446)
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )

@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
             else:
                 self.weights[name] = loaded_weight.cpu()

-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
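
The comment rename ("stack" → "normalize") matches what the helpers do: separate per-projection LoRA tensors (q/k/v, gate/up) are merged under the stacked names used by the fused base-model layers. Below is a minimal sketch of the idea for gate_up_proj, assuming a flat name-to-tensor dict; it is illustrative only, not the actual normalize_gate_up_proj implementation, and the stacking layout is an assumption:

import torch
from typing import Dict

def stack_gate_up(weights: Dict[str, torch.Tensor]) -> None:
    # Illustrative sketch: fuse gate_proj/up_proj LoRA weights into a single
    # gate_up_proj entry so they line up with the fused base-model layer.
    # Assumes both projections are present for every gate_proj key.
    for name in list(weights.keys()):  # snapshot keys: we mutate the dict below
        if "gate_proj" not in name:
            continue
        up_name = name.replace("gate_proj", "up_proj")
        fused_name = name.replace("gate_proj", "gate_up_proj")
        # The real code may treat lora_A and lora_B differently depending on
        # how the fused matmul is sliced; dim=0 stacking is an assumption.
        weights[fused_name] = torch.stack([weights[name], weights[up_name]], dim=0)
        del weights[name], weights[up_name]

Note the same list(...) snapshot idiom the diff itself adopts before mutating a dict during iteration.
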
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
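
The new import is the result type returned by the load/unload APIs in the next hunk. Its actual definition lives in sglang.srt.managers.io_struct; the shape below is inferred purely from the fields passed to it in create_lora_update_result and may differ (it could, for instance, be a pydantic model rather than a dataclass):

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class LoRAUpdateResult:
    # Whether the load/unload operation (or every operation in a batch) succeeded.
    success: bool
    # Accumulated human-readable error text; empty on success.
    error_message: str = ""
    # Snapshot of currently loaded adapters: name -> path.
    loaded_adapters: Dict[str, str] = field(default_factory=dict)
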
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )

-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.

         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                 If a LoRA adapter is already loaded, it will be skipped with a warning.
         """

+        results = []
         for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
-
-            self.configs[lora_name] = LoRAConfig(lora_path)
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)

         self.update_state_from_configs()

-    def unload_lora_adapters(self, lora_names: Set[str]):
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
+
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
+            self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
+
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
-        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
-        delete the corresponding LoRA modules.
+        Unload a LoRA adapter by its name. This will remove the adapter from the memory pool and
+        delete the corresponding LoRA module.

         Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
+            lora_name (str): The name of the LoRA adapter to unload.
         """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False

         self.update_state_from_configs()

+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
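
Returning LoRAUpdateResult from both paths lets the server report per-adapter failures to clients instead of only logging warnings. A hedged usage sketch of the dynamic loading flow this PR title implies; the endpoint names, JSON keys, and port below are assumptions based on the method signatures above, not confirmed by this diff:

import requests

BASE_URL = "http://localhost:30000"  # assumed default sglang server address

# Dynamically load an adapter; the payload keys mirror the
# LoRAManager.load_lora_adapter signature in the hunk above.
resp = requests.post(
    f"{BASE_URL}/load_lora_adapter",
    json={"lora_name": "my-adapter", "lora_path": "/path/to/adapter"},
)
result = resp.json()
if not result.get("success", False):
    print("load failed:", result.get("error_message"))

# Unload it again by name; the manager reports unknown names as failures.
resp = requests.post(f"{BASE_URL}/unload_lora_adapter", json={"lora_name": "my-adapter"})
print(resp.json())
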
@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter

-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
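
The list(self.loras) snapshot is the actual bug fix in this hunk: deleting entries from a dict while iterating it directly raises a RuntimeError in Python 3. A standalone illustration:

loras = {"a": 1, "b": 2}

# Iterating the dict directly while deleting raises
# "RuntimeError: dictionary changed size during iteration".
try:
    for name in loras:
        del loras[name]
except RuntimeError as e:
    print(e)

# Snapshotting the keys first makes deletion safe.
loras = {"a": 1, "b": 2}
for name in list(loras):
    del loras[name]
print(loras)  # {}
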