Support dynamic LoRA loading / unloading in engine/server API (#7446)
@@ -48,6 +48,14 @@ class EngineBase(ABC):
         """Update model weights with in-memory tensor data."""
         pass
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+        pass
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+        pass
+
     @abstractmethod
     def release_memory_occupation(self):
         """Release GPU memory occupation temporarily."""
@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     ImageDataItem,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -478,6 +480,29 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter without re-launching the engine."""
+
+        obj = LoadLoRAAdapterReqInput(
+            lora_name=lora_name,
+            lora_path=lora_path,
+        )
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.load_lora_adapter(obj, None)
+        )
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter without re-launching the engine."""
+
+        obj = UnloadLoRAAdapterReqInput(lora_name=lora_name)
+
+        loop = asyncio.get_event_loop()
+        return loop.run_until_complete(
+            self.tokenizer_manager.unload_lora_adapter(obj, None)
+        )
+
     def release_memory_occupation(self, tags: Optional[List[str]] = None):
         obj = ReleaseMemoryOccupationReqInput(tags=tags)
         loop = asyncio.get_event_loop()
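To make the new engine-level API concrete, here is a minimal offline sketch. It assumes the engine is launched with LoRA enabled (for example, via an initial `lora_paths` mapping) and that requests select adapters by name through sglang's existing `lora_path` request field; the model name and adapter paths are placeholders.

```python
# Minimal sketch of the new Engine API; paths, model name, and the
# lora_paths/lora_path arguments are assumptions based on sglang's
# existing LoRA request interface and may differ by version.
import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    lora_paths={"init_adapter": "/path/to/init_adapter"},
)

# Dynamically load a second adapter without re-launching the engine.
result = engine.load_lora_adapter(lora_name="my_adapter", lora_path="/path/to/my_adapter")
assert result.success, result.error_message

# Route a request through the new adapter, then drop it again.
print(engine.generate("Hello", lora_path="my_adapter")["text"])
engine.unload_lora_adapter(lora_name="my_adapter")
engine.shutdown()
```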
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
     OpenSessionReqInput,
     ParseFunctionCallReq,
     ProfileReqInput,
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
     SeparateReasoningReqInput,
     SetInternalStateReq,
     SlowDownReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/load_lora_adapter", methods=["POST"])
+async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
+    """Load a new LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
+@app.api_route("/unload_lora_adapter", methods=["POST"])
+async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
+    """Unload a LoRA adapter without re-launching the server."""
+    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)
+
+    if result.success:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return ORJSONResponse(
+            result,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
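The server endpoints mirror the engine API; the request bodies are just the `io_struct` dataclasses below serialized as JSON. A sketch against a locally launched server (the base URL and adapter paths are placeholders):

```python
# Sketch: exercising the new endpoints; the base URL and paths are placeholders.
import requests

base_url = "http://127.0.0.1:30000"

resp = requests.post(
    f"{base_url}/load_lora_adapter",
    json={"lora_name": "my_adapter", "lora_path": "/path/to/my_adapter"},
)
# 200 with the current loaded_adapters map on success, 400 on failure.
print(resp.status_code, resp.json())

resp = requests.post(f"{base_url}/unload_lora_adapter", json={"lora_name": "my_adapter"})
print(resp.status_code, resp.json())
```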
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for i in range(base_hf_config.num_hidden_layers)
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
         else:
             self.weights[name] = loaded_weight.cpu()
 
-        # stack kv_proj and gate_up_proj
-        for i in range(self.base_hf_config.num_hidden_layers):
-            layer = self.layers[i]
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )
 
-    def load_lora_adapters(self, lora_paths: Dict[str, str]):
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                 If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
 
+        results = []
         for lora_name, lora_path in lora_paths.items():
-            if lora_name in self.loras:
-                logger.warning(
-                    f"LoRA adapter {lora_name} is already loaded."
-                    "If you want to reload it, please unload it first."
-                )
-                continue
-
-            self.configs[lora_name] = LoRAConfig(lora_path)
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
 
         self.update_state_from_configs()
 
-    def unload_lora_adapters(self, lora_names: Set[str]):
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
+
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
+            self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
+
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
-        Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
-        delete the corresponding LoRA modules.
+        Unload a LoRA adapter by name. This will remove the adapter from the memory pool and
+        delete the corresponding LoRA modules.
 
         Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
+            lora_name (str): The name of the LoRA adapter to unload.
         """
-        for lora_name in lora_names:
-            if lora_name in self.loras:
-                del self.configs[lora_name]
-            else:
-                logger.warning(f"LoRA adapter {lora_name} is not loaded.")
 
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
 
         self.update_state_from_configs()
 
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
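Note how the batch path reuses the single-adapter path with `update_state=False` and then aggregates the per-adapter results: the batch succeeds only if every adapter loaded, and failure messages are joined with newlines. A standalone sketch of that aggregation semantics:

```python
# Standalone sketch of the aggregation in load_lora_adapters; FakeResult is a
# stand-in for LoRAUpdateResult, for illustration only.
from dataclasses import dataclass

@dataclass
class FakeResult:
    success: bool
    error_message: str = ""

results = [FakeResult(True), FakeResult(False, "adapter_b: config not found")]
success = all(r.success for r in results)  # False: one adapter failed
error_message = "\n".join(r.error_message for r in results if not r.success)
print(success, repr(error_message))  # False 'adapter_b: config not found'
```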
@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
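The `list(self.loras)` snapshot is the standard fix for mutating a dict while iterating over it, which raises at runtime:

```python
# Why the snapshot matters: deleting from a dict during iteration raises.
loras = {"a": 1, "b": 2}
try:
    for name in loras:
        del loras[name]
except RuntimeError as e:
    print(e)  # dictionary changed size during iteration

loras = {"a": 1, "b": 2}
for name in list(loras):  # iterate over a snapshot of the keys
    del loras[name]
print(loras)  # {}
```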
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
 
 from sglang.srt.multimodal.mm_utils import has_valid_data
@@ -1002,3 +1002,27 @@ class RpcReqInput:
 class RpcReqOutput:
     success: bool
     message: str
+
+
+@dataclass
+class LoadLoRAAdapterReqInput:
+    # The name under which the new LoRA module will be loaded.
+    lora_name: str
+    # The path to load the LoRA adapter from.
+    lora_path: str
+
+
+@dataclass
+class UnloadLoRAAdapterReqInput:
+    # The name of the LoRA module to unload.
+    lora_name: str
+
+
+@dataclass
+class LoRAUpdateResult:
+    success: bool
+    error_message: Optional[str] = None
+    loaded_adapters: Dict[str, str] = field(default_factory=dict)
+
+
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
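Since both output aliases resolve to the same `LoRAUpdateResult` class, load and unload responses can be handled uniformly:

```python
# The aliases are the same class, so response handling is uniform.
result = LoRAUpdateResult(
    success=True,
    loaded_adapters={"my_adapter": "/path/to/my_adapter"},
)
assert LoadLoRAAdapterReqOutput is UnloadLoRAAdapterReqOutput is LoRAUpdateResult
assert isinstance(result, LoadLoRAAdapterReqOutput)
```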
@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -519,6 +523,8 @@ class Scheduler(
                 (SetInternalStateReq, self.set_internal_state),
                 (RpcReqInput, self.handle_rpc_request),
                 (ExpertDistributionReq, self.expert_distribution_handle),
+                (LoadLoRAAdapterReqInput, self.load_lora_adapter),
+                (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
             ]
         )
@@ -2241,6 +2247,36 @@ class Scheduler(
             logger.error(message)
         return UpdateWeightFromDiskReqOutput(success, message, 0)
 
+    def load_lora_adapter(
+        self, recv_req: LoadLoRAAdapterReqInput
+    ) -> LoadLoRAAdapterReqOutput:
+        """Load a new LoRA adapter in place from disk or Hugging Face."""
+
+        result = self.tp_worker.load_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert flush_cache_success, "Cache flush failed after loading LoRA adapter."
+        else:
+            logger.error(result.error_message)
+        return result
+
+    def unload_lora_adapter(
+        self, recv_req: UnloadLoRAAdapterReqInput
+    ) -> UnloadLoRAAdapterReqOutput:
+        """Unload the LoRA adapter."""
+
+        result = self.tp_worker.unload_lora_adapter(recv_req)
+
+        if result.success:
+            flush_cache_success = self.flush_cache()
+            assert (
+                flush_cache_success
+            ), "Cache flush failed after unloading LoRA weights"
+        else:
+            logger.error(result.error_message)
+        return result
+
     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
         """Initialize the online model parameter update group."""
         success, message = self.tp_worker.init_weights_update_group(recv_req)
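One design point worth noting: both handlers flush the cache after a successful update, presumably because cached prefixes were computed against the previous set of adapters, and reusing them after a load or unload could mix results from different LoRA states.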
@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
     HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
+    LoadLoRAAdapterReqInput,
+    LoadLoRAAdapterReqOutput,
+    LoRAUpdateResult,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
     SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
+    UnloadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
     UpdateWeightsFromDistributedReqInput,
@@ -311,6 +316,9 @@ class TokenizerManager:
         self.expert_distribution_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.update_lora_adapter_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
 
         self._result_dispatcher = TypeBasedDispatcher(
             [
@@ -377,6 +385,10 @@ class TokenizerManager:
                     ExpertDistributionReqOutput,
                     self.expert_distribution_communicator.handle_recv,
                 ),
+                (
+                    LoRAUpdateResult,
+                    self.update_lora_adapter_communicator.handle_recv,
+                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
@@ -960,6 +972,49 @@ class TokenizerManager:
         result = (await self.update_weights_from_tensor_communicator(obj))[0]
         return result.success, result.message
 
+    async def load_lora_adapter(
+        self,
+        obj: LoadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> LoadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic LoRA loading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic LoRA loading"
+        logger.info(
+            "Start to load LoRA adapter. Lora name=%s, path=%s",
+            obj.lora_name,
+            obj.lora_path,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            return result
+
+    async def unload_lora_adapter(
+        self,
+        obj: UnloadLoRAAdapterReqInput,
+        _: Optional[fastapi.Request] = None,
+    ) -> UnloadLoRAAdapterReqOutput:
+        self.auto_create_handle_loop()
+
+        # TODO (lifuhuang): Remove this after we verify that dynamic LoRA unloading works
+        # with dp_size > 1.
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for dynamic LoRA unloading"
+        logger.info(
+            "Start to unload LoRA adapter. Lora name=%s",
+            obj.lora_name,
+        )
+
+        async with self.model_update_lock.writer_lock:
+            result = (await self.update_lora_adapter_communicator(obj))[0]
+            return result
+
     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
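The `(await self.update_lora_adapter_communicator(obj))[0]` pattern collects one reply per data-parallel rank; with the `dp_size == 1` assertion above, the returned list always has exactly one element, which is why the handlers index `[0]`. A toy sketch of the pattern (not the real `_Communicator` implementation, which differs in detail):

```python
# Toy sketch of the communicator pattern used above; illustration only.
import asyncio

class ToyCommunicator:
    def __init__(self, send_fn, num_ranks: int):
        self._send_fn = send_fn
        self._num_ranks = num_ranks
        self._results: asyncio.Queue = asyncio.Queue()

    async def __call__(self, obj):
        self._send_fn(obj)  # broadcast the request to the scheduler(s)
        # Await one reply per rank; callers index [0] when num_ranks == 1.
        return [await self._results.get() for _ in range(self._num_ranks)]

    def handle_recv(self, result):
        # Registered with the result dispatcher; feeds replies back in.
        self._results.put_nowait(result)
```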
@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -275,3 +277,13 @@ class TpModelWorker:
             recv_req.name, recv_req.truncate_size
         )
         return parameter
+
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        result = self.model_runner.load_lora_adapter(
+            recv_req.lora_name, recv_req.lora_path
+        )
+        return result
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        result = self.model_runner.unload_lora_adapter(recv_req.lora_name)
+        return result
@@ -26,6 +26,8 @@ import torch
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
+    LoadLoRAAdapterReqInput,
+    UnloadLoRAAdapterReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -268,6 +270,12 @@ class TpModelWorkerClient:
     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
         return self.worker.get_weights_by_name(recv_req)
 
+    def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
+        return self.worker.load_lora_adapter(recv_req)
+
+    def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
+        return self.worker.unload_lora_adapter(recv_req)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
@@ -26,7 +26,6 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
-from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
@@ -819,8 +818,47 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
-        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-        logger.info("LoRA manager ready.")
+        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
+        if result.success:
+            logger.info(
+                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
+            )
+        else:
+            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
+
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new LoRA adapter from disk or Hugging Face."""
+
+        logger.info(
+            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+
+        logger.info(
+            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a LoRA adapter that was previously loaded during initialization or dynamic loading."""
+
+        logger.info(
+            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.unload_lora_adapter(lora_name)
+
+        logger.info(
+            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
 
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
@@ -503,6 +503,7 @@ class SRTRunner:
         disable_overlap_schedule: bool = False,
         disable_custom_all_reduce: bool = False,
         torchao_config: Optional[str] = None,
+        cuda_graph_max_bs: int = 4,
         sleep_on_idle=False,
     ):
         self.model_type = model_type
@@ -539,7 +540,7 @@ class SRTRunner:
                 tokenizer_path=tokenizer_path,
                 enable_ep_moe=enable_ep_moe,
                 disable_overlap_schedule=disable_overlap_schedule,
-                cuda_graph_max_bs=4,
+                cuda_graph_max_bs=cuda_graph_max_bs,
                 disable_custom_all_reduce=disable_custom_all_reduce,
                 sleep_on_idle=sleep_on_idle,
                 **spec_kwargs,
@@ -552,6 +553,12 @@ class SRTRunner:
         else:
             self.tokenizer = None
 
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        return self.engine.load_lora_adapter(lora_name, lora_path)
+
+    def unload_lora_adapter(self, lora_name: str):
+        return self.engine.unload_lora_adapter(lora_name)
+
     def forward(
         self,
         prompts: Union[
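With these passthroughs in place, a dynamic-loading test can be sketched as follows; the constructor arguments are abbreviated and the adapter names and paths are placeholders.

```python
# Sketch of a test built on the new SRTRunner passthroughs; arguments are
# abbreviated (SRTRunner takes more parameters) and paths are placeholders.
runner = SRTRunner(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    model_type="generation",
    lora_paths={"base_adapter": "/path/to/base_adapter"},
)

result = runner.load_lora_adapter("extra_adapter", "/path/to/extra_adapter")
assert result.success, result.error_message
assert "extra_adapter" in result.loaded_adapters

result = runner.unload_lora_adapter("extra_adapter")
assert result.success
assert "extra_adapter" not in result.loaded_adapters
```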