Support dynamic LoRA loading / unloading in engine/server API (#7446)
This commit is contained in:
@@ -48,6 +48,14 @@ class EngineBase(ABC):
|
|||||||
"""Update model weights with in-memory tensor data."""
|
"""Update model weights with in-memory tensor data."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def load_lora_adapter(self, lora_name: str, lora_path: str):
    """Load a new LoRA adapter without re-launching the engine.

    This base implementation performs no work and returns ``None``.

    Args:
        lora_name: Name of the LoRA adapter.
        lora_path: Path of the LoRA adapter.
    """
    return None
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, lora_name: str):
    """Unload a LoRA adapter without re-launching the engine.

    This base implementation performs no work and returns ``None``.

    Args:
        lora_name: Name of the LoRA adapter.
    """
    return None
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def release_memory_occupation(self):
|
def release_memory_occupation(self):
|
||||||
"""Release GPU memory occupation temporarily."""
|
"""Release GPU memory occupation temporarily."""
|
||||||
|
|||||||
@@ -48,10 +48,12 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
ImageDataItem,
|
ImageDataItem,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
LoadLoRAAdapterReqInput,
|
||||||
ReleaseMemoryOccupationReqInput,
|
ReleaseMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
RpcReqInput,
|
RpcReqInput,
|
||||||
RpcReqOutput,
|
RpcReqOutput,
|
||||||
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
@@ -478,6 +480,29 @@ class Engine(EngineBase):
|
|||||||
self.tokenizer_manager.get_weights_by_name(obj, None)
|
self.tokenizer_manager.get_weights_by_name(obj, None)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def load_lora_adapter(self, lora_name: str, lora_path: str):
    """Load a new LoRA adapter without re-launching the engine.

    Wraps the arguments in a ``LoadLoRAAdapterReqInput`` and runs the
    tokenizer manager's ``load_lora_adapter`` coroutine to completion on the
    current asyncio event loop.

    Args:
        lora_name: Name of the LoRA adapter.
        lora_path: Path of the LoRA adapter.

    Returns:
        Whatever the tokenizer manager's ``load_lora_adapter`` coroutine returns.
    """
    request = LoadLoRAAdapterReqInput(lora_name=lora_name, lora_path=lora_path)

    event_loop = asyncio.get_event_loop()
    # No HTTP request object exists on the engine path, hence the explicit None.
    pending = self.tokenizer_manager.load_lora_adapter(request, None)
    return event_loop.run_until_complete(pending)
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, lora_name: str):
    """Unload a LoRA adapter without re-launching the engine.

    Wraps the name in an ``UnloadLoRAAdapterReqInput`` and runs the tokenizer
    manager's ``unload_lora_adapter`` coroutine to completion on the current
    asyncio event loop.

    Args:
        lora_name: Name of the LoRA adapter.

    Returns:
        Whatever the tokenizer manager's ``unload_lora_adapter`` coroutine returns.
    """
    request = UnloadLoRAAdapterReqInput(lora_name=lora_name)

    event_loop = asyncio.get_event_loop()
    # No HTTP request object exists on the engine path, hence the explicit None.
    pending = self.tokenizer_manager.unload_lora_adapter(request, None)
    return event_loop.run_until_complete(pending)
|
||||||
|
|
||||||
def release_memory_occupation(self, tags: Optional[List[str]] = None):
|
def release_memory_occupation(self, tags: Optional[List[str]] = None):
|
||||||
obj = ReleaseMemoryOccupationReqInput(tags=tags)
|
obj = ReleaseMemoryOccupationReqInput(tags=tags)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
GenerateReqInput,
|
GenerateReqInput,
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
LoadLoRAAdapterReqInput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
ParseFunctionCallReq,
|
ParseFunctionCallReq,
|
||||||
ProfileReqInput,
|
ProfileReqInput,
|
||||||
@@ -80,6 +81,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
SeparateReasoningReqInput,
|
SeparateReasoningReqInput,
|
||||||
SetInternalStateReq,
|
SetInternalStateReq,
|
||||||
SlowDownReqInput,
|
SlowDownReqInput,
|
||||||
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
@@ -595,6 +597,40 @@ async def slow_down(obj: SlowDownReqInput, request: Request):
|
|||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/load_lora_adapter", methods=["POST"])
async def load_lora_adapter(obj: LoadLoRAAdapterReqInput, request: Request):
    """Load a new LoRA adapter without re-launching the server.

    The tokenizer manager's result is returned as the response body, with
    HTTP 200 when the result reports success and HTTP 400 otherwise.
    """
    outcome = await _global_state.tokenizer_manager.load_lora_adapter(obj, request)

    # Both branches return the same payload; only the status code differs.
    status = HTTPStatus.OK if outcome.success else HTTPStatus.BAD_REQUEST
    return ORJSONResponse(outcome, status_code=status)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/unload_lora_adapter", methods=["POST"])
async def unload_lora_adapter(obj: UnloadLoRAAdapterReqInput, request: Request):
    """Unload a LoRA adapter without re-launching the server.

    The tokenizer manager's result is returned as the response body, with
    HTTP 200 when the result reports success and HTTP 400 otherwise.
    """
    result = await _global_state.tokenizer_manager.unload_lora_adapter(obj, request)

    # Both branches return the same payload; only the status code differs.
    status_code = HTTPStatus.OK if result.success else HTTPStatus.BAD_REQUEST
    return ORJSONResponse(result, status_code=status_code)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/open_session", methods=["GET", "POST"])
|
@app.api_route("/open_session", methods=["GET", "POST"])
|
||||||
async def open_session(obj: OpenSessionReqInput, request: Request):
|
async def open_session(obj: OpenSessionReqInput, request: Request):
|
||||||
"""Open a session, and return its unique session id."""
|
"""Open a session, and return its unique session id."""
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
|
|||||||
self.layers: List[LoRALayer] = nn.ModuleList(
|
self.layers: List[LoRALayer] = nn.ModuleList(
|
||||||
[
|
[
|
||||||
LoRALayer(config, base_hf_config)
|
LoRALayer(config, base_hf_config)
|
||||||
for i in range(base_hf_config.num_hidden_layers)
|
for _ in range(base_hf_config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.weights[name] = loaded_weight.cpu()
|
self.weights[name] = loaded_weight.cpu()
|
||||||
|
|
||||||
# stack kv_proj and gate_up_proj
|
# normalize kv_proj and gate_up_proj
|
||||||
for i in range(self.base_hf_config.num_hidden_layers):
|
for layer in self.layers:
|
||||||
layer = self.layers[i]
|
weight_names = list(layer.weights.keys())
|
||||||
weight_names = [name for name, _ in layer.weights.items()]
|
|
||||||
self.normalize_qkv_proj(weight_names, layer.weights)
|
self.normalize_qkv_proj(weight_names, layer.weights)
|
||||||
self.normalize_gate_up_proj(weight_names, layer.weights)
|
self.normalize_gate_up_proj(weight_names, layer.weights)
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
|
|||||||
get_normalized_lora_weight_names,
|
get_normalized_lora_weight_names,
|
||||||
get_weight_name,
|
get_weight_name,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.io_struct import LoRAUpdateResult
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.utils import replace_submodule
|
from sglang.srt.utils import replace_submodule
|
||||||
|
|
||||||
@@ -98,44 +99,96 @@ class LoRAManager:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_lora_adapters(self, lora_paths: Dict[str, str]):
|
def create_lora_update_result(
    self, success: bool, error_message: str = ""
) -> LoRAUpdateResult:
    """Build a ``LoRAUpdateResult`` that also reports the current adapters.

    Args:
        success: Whether the operation being reported succeeded.
        error_message: Error text to attach; empty by default.

    Returns:
        A ``LoRAUpdateResult`` whose ``loaded_adapters`` maps every name in
        ``self.configs`` to that config's ``path``.
    """
    adapters_by_name = {}
    for adapter_name, adapter_config in self.configs.items():
        adapters_by_name[adapter_name] = adapter_config.path

    return LoRAUpdateResult(
        success=success,
        error_message=error_message,
        loaded_adapters=adapters_by_name,
    )
|
||||||
|
|
||||||
|
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
|
||||||
"""
|
"""
|
||||||
Load LoRA adapters from the specified paths.
|
Load LoRA adapters from the specified paths.
|
||||||
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
|
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
|
||||||
If a LoRA adapter is already loaded, it will be skipped with a warning.
|
If a LoRA adapter is already loaded, it will be skipped with a warning.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
results = []
|
||||||
for lora_name, lora_path in lora_paths.items():
|
for lora_name, lora_path in lora_paths.items():
|
||||||
if lora_name in self.loras:
|
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
|
||||||
logger.warning(
|
results.append(result)
|
||||||
f"LoRA adapter {lora_name} is already loaded."
|
|
||||||
"If you want to reload it, please unload it first."
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
self.configs[lora_name] = LoRAConfig(lora_path)
|
|
||||||
|
|
||||||
self.update_state_from_configs()
|
self.update_state_from_configs()
|
||||||
|
|
||||||
def unload_lora_adapters(self, lora_names: Set[str]):
|
return self.create_lora_update_result(
|
||||||
|
success=all(result.success for result in results),
|
||||||
|
error_message="\n".join(
|
||||||
|
result.error_message for result in results if not result.success
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_lora_adapter(
    self, lora_name: str, lora_path: str, update_state: bool = True
) -> LoRAUpdateResult:
    """
    Load a single LoRA adapter from the specified path.

    An adapter whose name is already present in ``self.loras`` is rejected and
    its existing config is left untouched.

    Args:
        lora_name (str): The name of the LoRA adapter.
        lora_path (str): The file path to the LoRA adapter.
        update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.

    Returns:
        LoRAUpdateResult: The result built by ``create_lora_update_result``,
        carrying the success flag and any error message.
    """

    success = True
    error_message = ""

    if lora_name in self.loras:
        # Do not fall through to the config assignment: that would replace the
        # config of the adapter that is already loaded while still reporting
        # a failure to the caller.
        success = False
        error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
    else:
        try:
            self.configs[lora_name] = LoRAConfig(lora_path)
        except Exception as e:
            # Report config construction failures through the result object
            # instead of propagating them.
            success = False
            error_message = (
                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
            )

    if update_state:
        self.update_state_from_configs()

    return self.create_lora_update_result(
        success=success,
        error_message=error_message,
    )
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
|
||||||
"""
|
"""
|
||||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
Unload a LoRA adapter by its name. This will remove the adapter from the memory pool and
|
||||||
delete the corresponding LoRA modules.
|
delete the corresponding LoRA modules.
|
||||||
|
|
||||||
Args:
|
|
||||||
lora_names (Set[str]): A set of LoRA adapter names to unload.
|
|
||||||
"""
|
"""
|
||||||
for lora_name in lora_names:
|
|
||||||
if lora_name in self.loras:
|
success = True
|
||||||
del self.configs[lora_name]
|
error_message = ""
|
||||||
else:
|
if lora_name in self.loras:
|
||||||
logger.warning(f"LoRA adapter {lora_name} is not loaded.")
|
del self.configs[lora_name]
|
||||||
|
else:
|
||||||
|
error_message = f"LoRA adapter {lora_name} is not loaded."
|
||||||
|
success = False
|
||||||
|
|
||||||
self.update_state_from_configs()
|
self.update_state_from_configs()
|
||||||
|
|
||||||
|
return self.create_lora_update_result(
|
||||||
|
success=success,
|
||||||
|
error_message=error_message,
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
||||||
# load active loras into lora memory pool
|
# load active loras into lora memory pool
|
||||||
cur_uids = set(forward_batch.lora_paths)
|
cur_uids = set(forward_batch.lora_paths)
|
||||||
@@ -372,8 +425,8 @@ class LoRAManager:
|
|||||||
lora_adapter.initialize_weights()
|
lora_adapter.initialize_weights()
|
||||||
self.loras[name] = lora_adapter
|
self.loras[name] = lora_adapter
|
||||||
|
|
||||||
# Clean up unused LoRA adapters
|
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
|
||||||
for name in self.loras:
|
for name in list(self.loras):
|
||||||
if name not in self.configs:
|
if name not in self.configs:
|
||||||
logger.info(f"Unloading LoRA adapter {name}")
|
logger.info(f"Unloading LoRA adapter {name}")
|
||||||
del self.loras[name]
|
del self.loras[name]
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import copy
|
|||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
||||||
|
|
||||||
from sglang.srt.multimodal.mm_utils import has_valid_data
|
from sglang.srt.multimodal.mm_utils import has_valid_data
|
||||||
|
|
||||||
@@ -1002,3 +1002,27 @@ class RpcReqInput:
|
|||||||
class RpcReqOutput:
|
class RpcReqOutput:
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoadLoRAAdapterReqInput:
    """Request payload for loading a LoRA adapter."""

    # The name of the LoRA adapter to load.
    lora_name: str
    # The path to load the adapter from.
    lora_path: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UnloadLoRAAdapterReqInput:
    """Request payload for unloading a LoRA adapter."""

    # The name of the LoRA adapter to unload.
    lora_name: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LoRAUpdateResult:
    """Outcome of a LoRA adapter load or unload request."""

    # Whether the operation succeeded.
    success: bool
    # Error text for a failed operation; defaults to None.
    error_message: Optional[str] = None
    # Mapping of adapter names to strings; defaults to a fresh empty dict per instance.
    loaded_adapters: Dict[str, str] = field(default_factory=dict)


# Load and unload requests share one result type.
LoadLoRAAdapterReqOutput = LoRAUpdateResult
UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||||
|
|||||||
@@ -82,6 +82,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
|
LoadLoRAAdapterReqInput,
|
||||||
|
LoadLoRAAdapterReqOutput,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -99,6 +101,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
SlowDownReqOutput,
|
SlowDownReqOutput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
|
UnloadLoRAAdapterReqInput,
|
||||||
|
UnloadLoRAAdapterReqOutput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
@@ -519,6 +523,8 @@ class Scheduler(
|
|||||||
(SetInternalStateReq, self.set_internal_state),
|
(SetInternalStateReq, self.set_internal_state),
|
||||||
(RpcReqInput, self.handle_rpc_request),
|
(RpcReqInput, self.handle_rpc_request),
|
||||||
(ExpertDistributionReq, self.expert_distribution_handle),
|
(ExpertDistributionReq, self.expert_distribution_handle),
|
||||||
|
(LoadLoRAAdapterReqInput, self.load_lora_adapter),
|
||||||
|
(UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2241,6 +2247,36 @@ class Scheduler(
|
|||||||
logger.error(message)
|
logger.error(message)
|
||||||
return UpdateWeightFromDiskReqOutput(success, message, 0)
|
return UpdateWeightFromDiskReqOutput(success, message, 0)
|
||||||
|
|
||||||
|
def load_lora_adapter(
    self, recv_req: LoadLoRAAdapterReqInput
) -> LoadLoRAAdapterReqOutput:
    """Load a new LoRA adapter in place via the TP worker.

    On success the cache is flushed, and the flush is asserted to have
    succeeded. On failure the error message is logged. The worker's result
    is returned in both cases.
    """
    outcome = self.tp_worker.load_lora_adapter(recv_req)

    if not outcome.success:
        logger.error(outcome.error_message)
        return outcome

    flushed = self.flush_cache()
    assert flushed, "Cache flush failed after loading lora adapter."
    return outcome
|
||||||
|
|
||||||
|
def unload_lora_adapter(
    self, recv_req: UnloadLoRAAdapterReqInput
) -> UnloadLoRAAdapterReqOutput:
    """Unload a LoRA adapter via the TP worker.

    On success the cache is flushed, and the flush is asserted to have
    succeeded. On failure the error message is logged. The worker's result
    is returned in both cases.
    """
    outcome = self.tp_worker.unload_lora_adapter(recv_req)

    if not outcome.success:
        logger.error(outcome.error_message)
        return outcome

    flushed = self.flush_cache()
    assert flushed, "Cache flush failed after unloading LoRA weights"
    return outcome
|
||||||
|
|
||||||
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
|
||||||
"""Initialize the online model parameter update group."""
|
"""Initialize the online model parameter update group."""
|
||||||
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
success, message = self.tp_worker.init_weights_update_group(recv_req)
|
||||||
|
|||||||
@@ -83,6 +83,9 @@ from sglang.srt.managers.io_struct import (
|
|||||||
HealthCheckOutput,
|
HealthCheckOutput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
InitWeightsUpdateGroupReqOutput,
|
InitWeightsUpdateGroupReqOutput,
|
||||||
|
LoadLoRAAdapterReqInput,
|
||||||
|
LoadLoRAAdapterReqOutput,
|
||||||
|
LoRAUpdateResult,
|
||||||
OpenSessionReqInput,
|
OpenSessionReqInput,
|
||||||
OpenSessionReqOutput,
|
OpenSessionReqOutput,
|
||||||
ProfileReq,
|
ProfileReq,
|
||||||
@@ -99,6 +102,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
SlowDownReqOutput,
|
SlowDownReqOutput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
|
UnloadLoRAAdapterReqInput,
|
||||||
|
UnloadLoRAAdapterReqOutput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
@@ -311,6 +316,9 @@ class TokenizerManager:
|
|||||||
self.expert_distribution_communicator = _Communicator(
|
self.expert_distribution_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.update_lora_adapter_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
|
|
||||||
self._result_dispatcher = TypeBasedDispatcher(
|
self._result_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
@@ -377,6 +385,10 @@ class TokenizerManager:
|
|||||||
ExpertDistributionReqOutput,
|
ExpertDistributionReqOutput,
|
||||||
self.expert_distribution_communicator.handle_recv,
|
self.expert_distribution_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
LoRAUpdateResult,
|
||||||
|
self.update_lora_adapter_communicator.handle_recv,
|
||||||
|
),
|
||||||
(HealthCheckOutput, lambda x: None),
|
(HealthCheckOutput, lambda x: None),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -960,6 +972,49 @@ class TokenizerManager:
|
|||||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def load_lora_adapter(
    self,
    obj: LoadLoRAAdapterReqInput,
    _: Optional[fastapi.Request] = None,
) -> LoadLoRAAdapterReqOutput:
    """Send a LoRA load request through the LoRA update communicator.

    The request is sent while holding ``model_update_lock.writer_lock``, and
    the first element of the communicator's reply is returned. The second
    positional argument is accepted but unused.
    """
    self.auto_create_handle_loop()

    # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
    # with dp_size > 1.
    dp_size = self.server_args.dp_size
    assert dp_size == 1, "dp_size must be 1 for dynamic lora loading"

    logger.info(
        "Start load Lora adapter. Lora name=%s, path=%s",
        obj.lora_name,
        obj.lora_path,
    )

    async with self.model_update_lock.writer_lock:
        responses = await self.update_lora_adapter_communicator(obj)
        return responses[0]
|
||||||
|
|
||||||
|
async def unload_lora_adapter(
    self,
    obj: UnloadLoRAAdapterReqInput,
    _: Optional[fastapi.Request] = None,
) -> UnloadLoRAAdapterReqOutput:
    """Send a LoRA unload request through the LoRA update communicator.

    The request is sent while holding ``model_update_lock.writer_lock``, and
    the first element of the communicator's reply is returned. The second
    positional argument is accepted but unused.
    """
    self.auto_create_handle_loop()

    # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
    # with dp_size > 1.
    dp_size = self.server_args.dp_size
    assert dp_size == 1, "dp_size must be 1 for dynamic lora loading"

    logger.info(
        "Start unload Lora adapter. Lora name=%s",
        obj.lora_name,
    )

    async with self.model_update_lock.writer_lock:
        responses = await self.update_lora_adapter_communicator(obj)
        return responses[0]
|
||||||
|
|
||||||
async def get_weights_by_name(
|
async def get_weights_by_name(
|
||||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
LoadLoRAAdapterReqInput,
|
||||||
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
@@ -275,3 +277,13 @@ class TpModelWorker:
|
|||||||
recv_req.name, recv_req.truncate_size
|
recv_req.name, recv_req.truncate_size
|
||||||
)
|
)
|
||||||
return parameter
|
return parameter
|
||||||
|
|
||||||
|
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Unpack a load request and hand it to the model runner.

    Returns whatever ``model_runner.load_lora_adapter`` returns.
    """
    name = recv_req.lora_name
    path = recv_req.lora_path
    return self.model_runner.load_lora_adapter(name, path)
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Unpack an unload request and hand it to the model runner.

    Returns whatever ``model_runner.unload_lora_adapter`` returns.
    """
    name = recv_req.lora_name
    return self.model_runner.unload_lora_adapter(name)
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ import torch
|
|||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
GetWeightsByNameReqInput,
|
GetWeightsByNameReqInput,
|
||||||
InitWeightsUpdateGroupReqInput,
|
InitWeightsUpdateGroupReqInput,
|
||||||
|
LoadLoRAAdapterReqInput,
|
||||||
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
@@ -268,6 +270,12 @@ class TpModelWorkerClient:
|
|||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
return self.worker.get_weights_by_name(recv_req)
|
return self.worker.get_weights_by_name(recv_req)
|
||||||
|
|
||||||
|
def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
    """Delegate a LoRA load request to the wrapped worker and return its result."""
    worker_result = self.worker.load_lora_adapter(recv_req)
    return worker_result
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
    """Delegate a LoRA unload request to the wrapped worker and return its result."""
    worker_result = self.worker.unload_lora_adapter(recv_req)
    return worker_result
|
||||||
|
|
||||||
def __delete__(self):
|
def __delete__(self):
|
||||||
self.input_queue.put((None, None))
|
self.input_queue.put((None, None))
|
||||||
self.copy_queue.put((None, None, None))
|
self.copy_queue.put((None, None, None))
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from sglang.srt import debug_utils
|
|
||||||
from sglang.srt.configs.device_config import DeviceConfig
|
from sglang.srt.configs.device_config import DeviceConfig
|
||||||
from sglang.srt.configs.load_config import LoadConfig
|
from sglang.srt.configs.load_config import LoadConfig
|
||||||
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
|
||||||
@@ -819,8 +818,47 @@ class ModelRunner:
|
|||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
)
|
)
|
||||||
self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
|
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
|
||||||
logger.info("LoRA manager ready.")
|
if result.success:
|
||||||
|
logger.info(
|
||||||
|
f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
|
||||||
|
|
||||||
|
def load_lora_adapter(self, lora_name: str, lora_path: str):
    """Load a new lora adapter through the LoRA manager.

    The available GPU memory is logged before and after the load, and the
    LoRA manager's result is returned.
    """
    start_message = (
        f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )
    logger.info(start_message)

    update_result = self.lora_manager.load_lora_adapter(lora_name, lora_path)

    # Query memory again so the log reflects the state after loading.
    done_message = (
        f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )
    logger.info(done_message)

    return update_result
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, lora_name: str):
    """Unload a lora adapter through the LoRA manager.

    The available GPU memory is logged before and after the unload, and the
    LoRA manager's result is returned.
    """
    start_message = (
        f"LoRA adapter unloading starts: name={lora_name}. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )
    logger.info(start_message)

    update_result = self.lora_manager.unload_lora_adapter(lora_name)

    # Query memory again so the log reflects the state after unloading.
    done_message = (
        f"LoRA adapter unloading completes: name={lora_name}. "
        f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
    )
    logger.info(done_message)

    return update_result
|
||||||
|
|
||||||
def profile_max_num_token(self, total_gpu_memory: int):
|
def profile_max_num_token(self, total_gpu_memory: int):
|
||||||
available_gpu_memory = get_available_gpu_memory(
|
available_gpu_memory = get_available_gpu_memory(
|
||||||
|
|||||||
@@ -503,6 +503,7 @@ class SRTRunner:
|
|||||||
disable_overlap_schedule: bool = False,
|
disable_overlap_schedule: bool = False,
|
||||||
disable_custom_all_reduce: bool = False,
|
disable_custom_all_reduce: bool = False,
|
||||||
torchao_config: Optional[str] = None,
|
torchao_config: Optional[str] = None,
|
||||||
|
cuda_graph_max_bs: int = 4,
|
||||||
sleep_on_idle=False,
|
sleep_on_idle=False,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
@@ -539,7 +540,7 @@ class SRTRunner:
|
|||||||
tokenizer_path=tokenizer_path,
|
tokenizer_path=tokenizer_path,
|
||||||
enable_ep_moe=enable_ep_moe,
|
enable_ep_moe=enable_ep_moe,
|
||||||
disable_overlap_schedule=disable_overlap_schedule,
|
disable_overlap_schedule=disable_overlap_schedule,
|
||||||
cuda_graph_max_bs=4,
|
cuda_graph_max_bs=cuda_graph_max_bs,
|
||||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
sleep_on_idle=sleep_on_idle,
|
sleep_on_idle=sleep_on_idle,
|
||||||
**spec_kwargs,
|
**spec_kwargs,
|
||||||
@@ -552,6 +553,12 @@ class SRTRunner:
|
|||||||
else:
|
else:
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
|
|
||||||
|
def load_lora_adapter(self, lora_name: str, lora_path: str):
    """Forward a LoRA load request to the wrapped engine and return its result."""
    engine_result = self.engine.load_lora_adapter(lora_name, lora_path)
    return engine_result
|
||||||
|
|
||||||
|
def unload_lora_adapter(self, lora_name: str):
    """Forward a LoRA unload request to the wrapped engine and return its result."""
    engine_result = self.engine.unload_lora_adapter(lora_name)
    return engine_result
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
prompts: Union[
|
prompts: Union[
|
||||||
|
|||||||
616
test/srt/models/lora/test_lora_update.py
Normal file
616
test/srt/models/lora/test_lora_update.py
Normal file
@@ -0,0 +1,616 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import multiprocessing as mp
|
||||||
|
import unittest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.runners import SRTRunner
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prompts that create_batch_data pairs with each adapter.
PROMPTS = [
    "SGL is a",
    "AI is a field of computer science focused on",
    "Computer science is the study of",
    "Write a short story.",
    "What are the main components of a computer?",
]
|
||||||
|
|
||||||
|
|
||||||
|
class OperationType(Enum):
    """Kinds of step that can appear in a test case's operation sequence."""

    LOAD = "load"
    UNLOAD = "unload"
    NOOP = "noop"
    FORWARD = "forward"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Operation:
    """A single step in a test case's operation sequence."""

    # The kind of step to perform.
    type: OperationType
    # Payload of the step: an adapter name string for LOAD/UNLOAD steps, or the
    # list of (prompt, adapter) tuples from create_batch_data for FORWARD steps.
    data: Optional[Union[str, List[tuple]]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TestCase:
    """Declarative description of one dynamic LoRA update scenario."""

    # Base model identifier.
    base: str
    # Value for the max-LoRAs-per-batch setting.
    max_loras_per_batch: int
    # Every adapter referenced by the scenario.
    all_adapters: List[str]
    # Adapters listed as initial for the scenario.
    initial_adapters: List[str]
    # Ordered operations that make up the scenario.
    op_sequence: List[Operation]
    # Defaults to 32 when not given.
    max_new_tokens: int = 32
|
||||||
|
|
||||||
|
|
||||||
|
def create_batch_data(
    adapters: Union[str, list],
    prompts: Optional[List[str]] = None,
) -> List[tuple]:
    """Build the payload of a FORWARD operation.

    Args:
        adapters: A single adapter name or a list of adapter names.
        prompts: Prompts to pair with the adapters. Defaults to the
            module-level ``PROMPTS`` when omitted.

    Returns:
        A list of ``(prompt, adapter)`` tuples in prompt-major order: every
        adapter for the first prompt, then every adapter for the second, etc.
    """
    if prompts is None:
        prompts = PROMPTS
    # Accept a bare adapter name for convenience.
    if not isinstance(adapters, list):
        adapters = [adapters]
    return [(prompt, adapter) for prompt in prompts for adapter in adapters]
|
||||||
|
|
||||||
|
|
||||||
|
# Model and adapter identifiers used by the scenarios below.
_BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
_SQL_LORA = "philschmid/code-llama-3-1-8b-text-to-sql-lora"
_NUTANIX_LORA = "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
_OCR_LORA = "pbevan11/llama-3.1-8b-ocr-correction"


def _load(adapter):
    """Return a LOAD operation for ``adapter``."""
    return Operation(type=OperationType.LOAD, data=adapter)


def _unload(adapter):
    """Return an UNLOAD operation for ``adapter``."""
    return Operation(type=OperationType.UNLOAD, data=adapter)


def _forward(adapters):
    """Return a FORWARD operation pairing every prompt with ``adapters``."""
    return Operation(type=OperationType.FORWARD, data=create_batch_data(adapters))


TEST_CASES = [
    # basic test, no eviction
    TestCase(
        base=_BASE_MODEL,
        max_loras_per_batch=3,
        all_adapters=[_SQL_LORA, _NUTANIX_LORA, _OCR_LORA],
        initial_adapters=[_SQL_LORA],
        op_sequence=[
            _load(_NUTANIX_LORA),
            _load(_OCR_LORA),
            _forward([_SQL_LORA, _NUTANIX_LORA, _OCR_LORA]),
            _unload(_SQL_LORA),
            _forward([_NUTANIX_LORA, _OCR_LORA]),
            _unload(_NUTANIX_LORA),
            _forward(_OCR_LORA),
            _load(_SQL_LORA),
            _forward([_SQL_LORA, _OCR_LORA]),
        ],
    ),
    # Eviction
    TestCase(
        base=_BASE_MODEL,
        max_loras_per_batch=1,
        all_adapters=[_SQL_LORA, _NUTANIX_LORA, _OCR_LORA],
        initial_adapters=[_SQL_LORA],
        op_sequence=[
            _forward(_SQL_LORA),
            _load(_OCR_LORA),
            _unload(_SQL_LORA),
            _forward(_OCR_LORA),
            _load(_NUTANIX_LORA),
            _load(_SQL_LORA),
            _forward(_NUTANIX_LORA),
            _forward(_SQL_LORA),
            _forward(_OCR_LORA),
            _forward(_NUTANIX_LORA),
            _forward(_SQL_LORA),
            _forward(_OCR_LORA),
        ],
    ),
]
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAUpdateTestSessionMode(Enum):
    """Selects which kind of session the LoRA update tests run against."""

    # In-process engine session.
    ENGINE = "engine"
    # Standalone server session driven over HTTP.
    SERVER = "server"
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAUpdateTestSessionBase:
    """
    Base context manager for testing LoRA adapters.

    Stores the launch configuration and tracks ``expected_adapters``, the set
    of adapter names the session expects to be loaded. Subclasses provide the
    actual backend by overriding ``__enter__``/``__exit__`` and the three
    operation methods.
    """

    def __init__(
        self,
        *,
        # The unittest case used for assertions (assertTrue/assertEqual). This
        # was annotated with the module's ``TestCase`` dataclass, which has no
        # assertion methods.
        testcase: Optional[unittest.TestCase],
        model_path: str,
        lora_paths: List[str],
        max_loras_per_batch: int = 1,
        lora_backend: str = "triton",
        disable_cuda_graph: bool = False,
        cuda_graph_max_bs: int = 4,
    ):
        self.testcase = testcase
        self.model_path = model_path
        self.lora_paths = lora_paths
        self.max_loras_per_batch = max_loras_per_batch
        self.lora_backend = lora_backend
        self.disable_cuda_graph = disable_cuda_graph
        self.cuda_graph_max_bs = cuda_graph_max_bs

        # Adapters given at launch are expected to be loaded from the start.
        self.expected_adapters = set(lora_paths)
        self.handle = None  # Will be set in __enter__

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Don't suppress exceptions by default
        return False

    def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
        """
        Load a LoRA adapter by name and path.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement load_lora_adapter")

    def unload_lora_adapter(self, lora_name: str):
        """
        Unload a LoRA adapter by name.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement unload_lora_adapter")

    def forward(
        self,
        prompts: List[str],
        lora_paths: List[str],
        max_new_tokens: int = 32,
    ):
        """
        Perform a batch forward pass with the current set of loaded LoRA adapters.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError("Subclasses must implement forward")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
    """
    Context manager for testing LoRA adapters with in-process engine.

    ``self.handle`` is an ``SRTRunner`` that is entered in ``__enter__`` and
    exited in ``__exit__``.
    """

    def __enter__(self):
        # Build the in-process runner from the stored configuration.
        runner_kwargs = {
            "model_path": self.model_path,
            "model_type": "generation",
            "lora_paths": self.lora_paths,
            "lora_backend": self.lora_backend,
            "torch_dtype": torch.float16,
            "max_loras_per_batch": self.max_loras_per_batch,
            "disable_cuda_graph": self.disable_cuda_graph,
            "cuda_graph_max_bs": self.cuda_graph_max_bs,
            "disable_radix_cache": True,
        }
        self.handle = SRTRunner(**runner_kwargs)
        self.handle.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.handle is None:
            # Nothing was started; never suppress exceptions.
            return False
        # Cleanup (and the suppress decision) is delegated to SRTRunner.
        return self.handle.__exit__(exc_type, exc_val, exc_tb)

    def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
        """
        Load an adapter through the runner and check the reported adapter set.

        When ``lora_path`` is omitted, ``lora_name`` is used as the path.
        """
        path = lora_name if lora_path is None else lora_path

        # Record the expectation before issuing the request.
        self.expected_adapters.add(lora_name)

        result = self.handle.load_lora_adapter(lora_name=lora_name, lora_path=path)
        self.testcase.assertTrue(result.success)

        loaded_adapters = set(result.loaded_adapters)
        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(loaded_adapters, self.expected_adapters)

    def unload_lora_adapter(self, lora_name: str):
        """
        Unload an adapter through the runner and check the reported adapter set.

        Raises ``KeyError`` if ``lora_name`` is not currently expected.
        """
        self.expected_adapters.remove(lora_name)

        result = self.handle.unload_lora_adapter(lora_name=lora_name)
        self.testcase.assertTrue(result.success)

        loaded_adapters = set(result.loaded_adapters)
        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(loaded_adapters, self.expected_adapters)

    def forward(
        self,
        prompts: List[str],
        lora_paths: List[str],
        max_new_tokens: int = 32,
    ):
        """
        Run one batch through the runner and return its ``output_strs``.

        ``prompts`` and ``lora_paths`` are passed through to
        ``SRTRunner.batch_forward`` unchanged.
        """
        batch_result = self.handle.batch_forward(
            prompts=prompts,
            lora_paths=lora_paths,
            max_new_tokens=max_new_tokens,
        )
        output_strs = batch_result.output_strs

        print(f"output_strs: {output_strs}")
        return output_strs
|
||||||
|
|
||||||
|
|
||||||
|
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
    """
    Context manager for testing LoRA adapters with standalone server.

    ``self.handle`` is the process returned by ``popen_launch_server``; all
    operations are issued as HTTP requests against ``DEFAULT_URL_FOR_TEST``.
    """

    def __enter__(self):
        other_args = [
            "--cuda-graph-max-bs",
            str(self.cuda_graph_max_bs),
            "--lora-paths",
            *self.lora_paths,
            "--max-loras-per-batch",
            str(self.max_loras_per_batch),
            "--lora-backend",
            self.lora_backend,
            "--disable-radix-cache",
            "--random-seed",
            "42",
            # Spell the flag out in full; the previous "--max-running-request"
            # relied on the argument parser accepting an unambiguous prefix.
            "--max-running-requests",
            "1",
        ]
        if self.disable_cuda_graph:
            other_args.append("--disable-cuda-graph")

        # launch external server
        self.handle = popen_launch_server(
            self.model_path,
            DEFAULT_URL_FOR_TEST,
            DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=other_args,
        )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.handle is not None:
            kill_process_tree(self.handle.pid)
        # don't suppress exceptions
        return False

    def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
        """
        Load a LoRA adapter by name and path via ``/load_lora_adapter``.

        When ``lora_path`` is omitted, ``lora_name`` is used as the path. The
        adapter set reported by the server must equal ``expected_adapters``.
        """
        if lora_path is None:
            lora_path = lora_name

        # Record the expectation before issuing the request.
        self.expected_adapters.add(lora_name)

        response = requests.post(
            DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
            json={"lora_name": lora_name, "lora_path": lora_path},
        )
        self.testcase.assertTrue(response.ok)
        loaded_adapters = set(response.json()["loaded_adapters"])

        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(loaded_adapters, self.expected_adapters)

    def unload_lora_adapter(self, lora_name: str):
        """
        Unload a LoRA adapter by name via ``/unload_lora_adapter``.

        Raises ``KeyError`` if ``lora_name`` is not currently expected. The
        adapter set reported by the server must equal ``expected_adapters``.
        """
        self.expected_adapters.remove(lora_name)

        response = requests.post(
            DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
            json={"lora_name": lora_name},
        )
        self.testcase.assertTrue(response.ok)
        loaded_adapters = set(response.json()["loaded_adapters"])

        print(f"loaded_adapters: {loaded_adapters}")
        self.testcase.assertEqual(loaded_adapters, self.expected_adapters)

    def forward(
        self,
        prompts: List[str],
        lora_paths: List[str],
        max_new_tokens: int = 32,
    ):
        """
        Perform a batch forward pass with the current set of loaded LoRA adapters.

        Posts to ``/generate`` and returns the ``"text"`` field of each item in
        the JSON response, in response order.
        """
        response = requests.post(
            DEFAULT_URL_FOR_TEST + "/generate",
            json={
                "text": prompts,
                "lora_path": lora_paths,
                "sampling_params": {
                    # temperature 0 / top_k 1: request greedy decoding so the
                    # generated text can be compared across runs.
                    "temperature": 0,
                    "top_k": 1,
                    "max_new_tokens": max_new_tokens,
                },
            },
        )
        self.testcase.assertTrue(response.ok)
        output_strs = [r["text"] for r in response.json()]

        print(f"output_strs: {output_strs}")
        return output_strs
|
||||||
|
|
||||||
|
|
||||||
|
# Factory function to create the appropriate LoRA test session based on mode
|
||||||
|
def LoRAUpdateTestSession(
    *,
    # The unittest case used by the session for assertions. This was annotated
    # with the module's ``TestCase`` dataclass, which is not what callers pass.
    testcase: Optional[unittest.TestCase],
    mode: LoRAUpdateTestSessionMode,
    model_path: str,
    lora_paths: List[str],
    max_loras_per_batch: int = 1,
    lora_backend: str = "triton",
    disable_cuda_graph: bool = False,
    cuda_graph_max_bs: int = 4,
):
    """Create the LoRA test session matching ``mode``.

    All keyword arguments other than ``mode`` are forwarded unchanged to the
    selected session class.

    Returns:
        A ``LoRAUpdateEngineTestSession`` for ``ENGINE`` mode or a
        ``LoRAUpdateServerTestSession`` for ``SERVER`` mode.

    Raises:
        ValueError: if ``mode`` is neither ``ENGINE`` nor ``SERVER``.
    """
    common_kwargs = {
        "testcase": testcase,
        "model_path": model_path,
        "lora_paths": lora_paths,
        "max_loras_per_batch": max_loras_per_batch,
        "lora_backend": lora_backend,
        "disable_cuda_graph": disable_cuda_graph,
        "cuda_graph_max_bs": cuda_graph_max_bs,
    }

    if mode == LoRAUpdateTestSessionMode.ENGINE:
        return LoRAUpdateEngineTestSession(**common_kwargs)
    elif mode == LoRAUpdateTestSessionMode.SERVER:
        return LoRAUpdateServerTestSession(**common_kwargs)
    else:
        raise ValueError(f"Unrecognized mode: {mode!r}")
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoRADynamicUpdate(CustomTestCase):
    """
    This test case verifies that the SRT runner can dynamically load and unload LoRA adapters
    during a sequence of operations, and that the outputs of forward passes with dynamically loaded
    adapters match the outputs of forward passes with statically loaded adapters.
    """

    # Declared without ``self``; without @staticmethod an instance call would
    # bind the instance to ``lst``.
    @staticmethod
    def _repeat_each(lst, n):
        """Return a list in which every element of ``lst`` appears ``n`` times in a row."""
        return [x for x in lst for _ in range(n)]

    def _run_operation_sequence(
        self,
        mode: LoRAUpdateTestSessionMode,
        base: str,
        initial_adapters: List[str],
        max_loras_per_batch: int,
        op_sequence: List[Operation],
        max_new_tokens: int = 32,
    ) -> List[List[str]]:
        """
        Runs a sequence of operations on the SRT runner, including loading and unloading LoRA adapters,
        and performing forward passes with the current set of loaded adapters.

        Returns:
            One list of output strings per FORWARD operation, in execution order.
            LOAD and UNLOAD operations contribute nothing to the result; operation
            types without a branch below (e.g. NOOP) are skipped.
        """

        forward_outputs = []
        with LoRAUpdateTestSession(
            testcase=self,
            mode=mode,
            model_path=base,
            lora_paths=initial_adapters,
            max_loras_per_batch=max_loras_per_batch,
        ) as session:
            for op in op_sequence:
                op_type = op.type
                data = op.data
                print("-" * 100)
                print(
                    f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
                )
                if op_type == OperationType.LOAD:
                    # The adapter name doubles as its path.
                    session.load_lora_adapter(
                        lora_name=data,
                        lora_path=data,
                    )
                elif op_type == OperationType.UNLOAD:
                    session.unload_lora_adapter(
                        lora_name=data,
                    )
                elif op_type == OperationType.FORWARD:
                    # data is a list of (prompt, adapter) pairs; split it into
                    # two parallel sequences.
                    prompts, adapters = zip(*data)
                    result = session.forward(
                        prompts=list(prompts),
                        lora_paths=list(adapters),
                        max_new_tokens=max_new_tokens,
                    )
                    forward_outputs.append(result)

        return forward_outputs

    def test_dynamic_adapter_updates(self):
        """Compare dynamic load/unload outputs against a statically loaded run.

        For every test case and session mode, the full operation sequence is run
        starting from ``initial_adapters``; then only its FORWARD operations are
        replayed in a fresh session started with ``all_adapters``. The two runs
        must produce the same text for every prompt of every batch.
        """
        for case_idx, test_case in enumerate(TEST_CASES, start=1):
            for mode in [
                LoRAUpdateTestSessionMode.SERVER,
                LoRAUpdateTestSessionMode.ENGINE,
            ]:
                print("=" * 100)
                print(f"Starting test case {case_idx} in {mode.value} mode.")
                print("=" * 100)

                print(
                    f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
                )
                # Test dynamic loading of adapters
                # TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
                # we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
                dynamic_output = self._run_operation_sequence(
                    mode=mode,
                    initial_adapters=test_case.initial_adapters,
                    base=test_case.base,
                    max_loras_per_batch=test_case.max_loras_per_batch,
                    op_sequence=test_case.op_sequence,
                    max_new_tokens=test_case.max_new_tokens,
                )

                # static loading
                forward_ops = [
                    x for x in test_case.op_sequence if x.type == OperationType.FORWARD
                ]

                print("=" * 100)
                print(
                    f"\n--- Running static pass with {len(forward_ops)} operations ---"
                )
                static_output = self._run_operation_sequence(
                    mode=mode,
                    initial_adapters=test_case.all_adapters,
                    base=test_case.base,
                    max_loras_per_batch=test_case.max_loras_per_batch,
                    op_sequence=forward_ops,
                    max_new_tokens=test_case.max_new_tokens,
                )

                print(f"Dynamic output: {dynamic_output}")
                print(f"Static output: {static_output}")
                print("=" * 100)
                self.assertEqual(
                    len(dynamic_output),
                    len(static_output),
                    f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
                )
                for i, (dynamic, static) in enumerate(
                    zip(dynamic_output, static_output), start=1
                ):
                    self.assertEqual(
                        len(dynamic),
                        len(static),
                        f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
                    )
                    for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
                        # Compare ignoring leading/trailing whitespace.
                        d_out = d_out.strip()
                        s_out = s_out.strip()
                        self.assertEqual(
                            d_out,
                            s_out,
                            f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
                        )
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Request the "spawn" start method for multiprocessing before any test runs.
    try:
        mp.set_start_method("spawn")
    except RuntimeError:
        # set_start_method raises RuntimeError when a start method has already
        # been fixed; keep whatever is configured in that case.
        pass

    # Run every test in this module with warnings silenced.
    unittest.main(warnings="ignore")
|
||||||
@@ -17,6 +17,7 @@ suites = {
|
|||||||
TestFile("models/lora/test_lora_backend.py", 99),
|
TestFile("models/lora/test_lora_backend.py", 99),
|
||||||
TestFile("models/lora/test_multi_lora_backend.py", 60),
|
TestFile("models/lora/test_multi_lora_backend.py", 60),
|
||||||
TestFile("models/lora/test_lora_cuda_graph.py", 250),
|
TestFile("models/lora/test_lora_cuda_graph.py", 250),
|
||||||
|
TestFile("models/lora/test_lora_update.py", 400),
|
||||||
TestFile("models/test_embedding_models.py", 73),
|
TestFile("models/test_embedding_models.py", 73),
|
||||||
# TestFile("models/test_clip_models.py", 52),
|
# TestFile("models/test_clip_models.py", 52),
|
||||||
TestFile("models/test_encoder_embedding_models.py", 100),
|
TestFile("models/test_encoder_embedding_models.py", 100),
|
||||||
|
|||||||
Reference in New Issue
Block a user