Support limiting max loaded loras in CPU. (#8650)
This commit is contained in:
@@ -186,3 +186,10 @@ class LoRARegistry:
|
||||
self._registry[lora_ref.lora_name] = lora_ref
|
||||
self._counters[lora_ref.lora_id] = ConcurrentCounter()
|
||||
return lora_ref
|
||||
|
||||
@property
|
||||
def num_registered_loras(self) -> int:
|
||||
"""
|
||||
Returns the total number of LoRA adapters currently registered.
|
||||
"""
|
||||
return len(self._registry)
|
||||
|
||||
@@ -1097,7 +1097,7 @@ class UnloadLoRAAdapterReqInput:
|
||||
class LoRAUpdateResult:
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
|
||||
loaded_adapters: Optional[Dict[str, LoRARef]] = None
|
||||
|
||||
|
||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||
|
||||
@@ -1084,76 +1084,98 @@ class TokenizerManager:
|
||||
_: Optional[fastapi.Request] = None,
|
||||
) -> LoadLoRAAdapterReqOutput:
|
||||
self.auto_create_handle_loop()
|
||||
if not self.server_args.enable_lora:
|
||||
raise ValueError(
|
||||
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
||||
|
||||
try:
|
||||
if not self.server_args.enable_lora:
|
||||
raise ValueError(
|
||||
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
||||
)
|
||||
|
||||
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
||||
# with dp_size > 1.
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for dynamic lora loading"
|
||||
logger.info(
|
||||
"Start load Lora adapter. Lora name=%s, path=%s",
|
||||
obj.lora_name,
|
||||
obj.lora_path,
|
||||
)
|
||||
|
||||
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
||||
# with dp_size > 1.
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for dynamic lora loading"
|
||||
logger.info(
|
||||
"Start load Lora adapter. Lora name=%s, path=%s",
|
||||
obj.lora_name,
|
||||
obj.lora_path,
|
||||
)
|
||||
async with self.lora_update_lock:
|
||||
if (
|
||||
self.server_args.max_loaded_loras is not None
|
||||
and self.lora_registry.num_registered_loras
|
||||
>= self.server_args.max_loaded_loras
|
||||
):
|
||||
raise ValueError(
|
||||
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
|
||||
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
|
||||
"Please unload some LoRA adapters before loading new ones."
|
||||
)
|
||||
|
||||
async with self.lora_update_lock:
|
||||
# Generate new uniquely identifiable LoRARef object.
|
||||
new_adapter = LoRARef(
|
||||
lora_name=obj.lora_name,
|
||||
lora_path=obj.lora_path,
|
||||
# Generate new uniquely identifiable LoRARef object.
|
||||
new_adapter = LoRARef(
|
||||
lora_name=obj.lora_name,
|
||||
lora_path=obj.lora_path,
|
||||
)
|
||||
|
||||
# Trigger the actual loading operation at the backend processes.
|
||||
obj.lora_id = new_adapter.lora_id
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
# Register the LoRA adapter only after loading is successful.
|
||||
if result.success:
|
||||
await self.lora_registry.register(new_adapter)
|
||||
|
||||
return result
|
||||
except ValueError as e:
|
||||
return LoadLoRAAdapterReqOutput(
|
||||
success=False,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
# Trigger the actual loading operation at the backend processes.
|
||||
obj.lora_id = new_adapter.lora_id
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
# Register the LoRA adapter only after loading is successful.
|
||||
if result.success:
|
||||
await self.lora_registry.register(new_adapter)
|
||||
|
||||
return result
|
||||
|
||||
async def unload_lora_adapter(
|
||||
self,
|
||||
obj: UnloadLoRAAdapterReqInput,
|
||||
_: Optional[fastapi.Request] = None,
|
||||
) -> UnloadLoRAAdapterReqOutput:
|
||||
self.auto_create_handle_loop()
|
||||
if not self.server_args.enable_lora:
|
||||
raise ValueError(
|
||||
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
||||
|
||||
try:
|
||||
if not self.server_args.enable_lora:
|
||||
raise ValueError(
|
||||
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
||||
)
|
||||
|
||||
assert (
|
||||
obj.lora_name is not None
|
||||
), "lora_name must be provided to unload LoRA adapter"
|
||||
|
||||
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
||||
# with dp_size > 1.
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for dynamic lora loading"
|
||||
logger.info(
|
||||
"Start unload Lora adapter. Lora name=%s",
|
||||
obj.lora_name,
|
||||
)
|
||||
|
||||
assert (
|
||||
obj.lora_name is not None
|
||||
), "lora_name must be provided to unload LoRA adapter"
|
||||
async with self.lora_update_lock:
|
||||
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
|
||||
# from being started.
|
||||
lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||
obj.lora_id = lora_id
|
||||
|
||||
# TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
|
||||
# with dp_size > 1.
|
||||
assert (
|
||||
self.server_args.dp_size == 1
|
||||
), "dp_size must be 1 for dynamic lora loading"
|
||||
logger.info(
|
||||
"Start unload Lora adapter. Lora name=%s",
|
||||
obj.lora_name,
|
||||
)
|
||||
# Initiate the actual unloading operation at the backend processes only after all
|
||||
# ongoing requests using this LoRA adapter are finished.
|
||||
await self.lora_registry.wait_for_unload(lora_id)
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
async with self.lora_update_lock:
|
||||
# Unregister the LoRA adapter from the registry to stop new requests for this adapter
|
||||
# from being started.
|
||||
lora_id = await self.lora_registry.unregister(obj.lora_name)
|
||||
obj.lora_id = lora_id
|
||||
|
||||
# Initiate the actual unloading operation at the backend processes only after all
|
||||
# ongoing requests using this LoRA adapter are finished.
|
||||
await self.lora_registry.wait_for_unload(lora_id)
|
||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||
|
||||
return result
|
||||
return result
|
||||
except ValueError as e:
|
||||
return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
|
||||
|
||||
async def get_weights_by_name(
|
||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||
|
||||
@@ -149,6 +149,7 @@ class ServerArgs:
|
||||
max_lora_rank: Optional[int] = None
|
||||
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
||||
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
|
||||
max_loaded_loras: Optional[int] = None
|
||||
max_loras_per_batch: int = 8
|
||||
lora_backend: str = "triton"
|
||||
|
||||
@@ -1237,6 +1238,12 @@ class ServerArgs:
|
||||
default=8,
|
||||
help="Maximum number of adapters for a running batch, include base-only request.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-loaded-loras",
|
||||
type=int,
|
||||
default=ServerArgs.max_loaded_loras,
|
||||
help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-backend",
|
||||
type=str,
|
||||
@@ -2008,6 +2015,19 @@ class ServerArgs:
|
||||
self.max_lora_rank and self.lora_target_modules
|
||||
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
|
||||
|
||||
# Validate max_loaded_loras
|
||||
if self.max_loaded_loras is not None:
|
||||
assert self.max_loaded_loras >= self.max_loras_per_batch, (
|
||||
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
|
||||
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
|
||||
)
|
||||
assert (
|
||||
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
|
||||
), (
|
||||
"The number of LoRA paths should not exceed max_loaded_loras. "
|
||||
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
|
||||
)
|
||||
|
||||
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
||||
larger_tp = max(decode_tp, prefill_tp)
|
||||
smaller_tp = min(decode_tp, prefill_tp)
|
||||
|
||||
@@ -514,6 +514,7 @@ class SRTRunner:
|
||||
max_lora_rank: Optional[int] = None,
|
||||
lora_target_modules: Optional[List[str]] = None,
|
||||
enable_lora: Optional[bool] = None,
|
||||
max_loaded_loras: Optional[int] = None,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
@@ -556,6 +557,7 @@ class SRTRunner:
|
||||
max_lora_rank=max_lora_rank,
|
||||
lora_target_modules=lora_target_modules,
|
||||
enable_lora=enable_lora,
|
||||
max_loaded_loras=max_loaded_loras,
|
||||
**spec_kwargs,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user