Support limiting max loaded loras in CPU. (#8650)
This commit is contained in:
@@ -33,6 +33,8 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
|
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"* `max_loaded_loras`: If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `max_loras_per_batch`.\n",
|
||||||
|
"\n",
|
||||||
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
||||||
|
|||||||
@@ -181,6 +181,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
|
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
|
||||||
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
|
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
|
||||||
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
|
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
|
||||||
|
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
|
||||||
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
||||||
|
|
||||||
## Kernel backend
|
## Kernel backend
|
||||||
|
|||||||
@@ -186,3 +186,10 @@ class LoRARegistry:
|
|||||||
self._registry[lora_ref.lora_name] = lora_ref
|
self._registry[lora_ref.lora_name] = lora_ref
|
||||||
self._counters[lora_ref.lora_id] = ConcurrentCounter()
|
self._counters[lora_ref.lora_id] = ConcurrentCounter()
|
||||||
return lora_ref
|
return lora_ref
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_registered_loras(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns the total number of LoRA adapters currently registered.
|
||||||
|
"""
|
||||||
|
return len(self._registry)
|
||||||
|
|||||||
@@ -1097,7 +1097,7 @@ class UnloadLoRAAdapterReqInput:
|
|||||||
class LoRAUpdateResult:
|
class LoRAUpdateResult:
|
||||||
success: bool
|
success: bool
|
||||||
error_message: Optional[str] = None
|
error_message: Optional[str] = None
|
||||||
loaded_adapters: Dict[str, LoRARef] = field(default_factory=dict)
|
loaded_adapters: Optional[Dict[str, LoRARef]] = None
|
||||||
|
|
||||||
|
|
||||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||||
|
|||||||
@@ -1084,6 +1084,8 @@ class TokenizerManager:
|
|||||||
_: Optional[fastapi.Request] = None,
|
_: Optional[fastapi.Request] = None,
|
||||||
) -> LoadLoRAAdapterReqOutput:
|
) -> LoadLoRAAdapterReqOutput:
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
|
|
||||||
|
try:
|
||||||
if not self.server_args.enable_lora:
|
if not self.server_args.enable_lora:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
||||||
@@ -1101,6 +1103,17 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async with self.lora_update_lock:
|
async with self.lora_update_lock:
|
||||||
|
if (
|
||||||
|
self.server_args.max_loaded_loras is not None
|
||||||
|
and self.lora_registry.num_registered_loras
|
||||||
|
>= self.server_args.max_loaded_loras
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
|
||||||
|
f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
|
||||||
|
"Please unload some LoRA adapters before loading new ones."
|
||||||
|
)
|
||||||
|
|
||||||
# Generate new uniquely identifiable LoRARef object.
|
# Generate new uniquely identifiable LoRARef object.
|
||||||
new_adapter = LoRARef(
|
new_adapter = LoRARef(
|
||||||
lora_name=obj.lora_name,
|
lora_name=obj.lora_name,
|
||||||
@@ -1116,6 +1129,11 @@ class TokenizerManager:
|
|||||||
await self.lora_registry.register(new_adapter)
|
await self.lora_registry.register(new_adapter)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
except ValueError as e:
|
||||||
|
return LoadLoRAAdapterReqOutput(
|
||||||
|
success=False,
|
||||||
|
error_message=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
async def unload_lora_adapter(
|
async def unload_lora_adapter(
|
||||||
self,
|
self,
|
||||||
@@ -1123,6 +1141,8 @@ class TokenizerManager:
|
|||||||
_: Optional[fastapi.Request] = None,
|
_: Optional[fastapi.Request] = None,
|
||||||
) -> UnloadLoRAAdapterReqOutput:
|
) -> UnloadLoRAAdapterReqOutput:
|
||||||
self.auto_create_handle_loop()
|
self.auto_create_handle_loop()
|
||||||
|
|
||||||
|
try:
|
||||||
if not self.server_args.enable_lora:
|
if not self.server_args.enable_lora:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
"LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
|
||||||
@@ -1154,6 +1174,8 @@ class TokenizerManager:
|
|||||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
except ValueError as e:
|
||||||
|
return UnloadLoRAAdapterReqOutput(success=False, rror_message=str(e))
|
||||||
|
|
||||||
async def get_weights_by_name(
|
async def get_weights_by_name(
|
||||||
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
|
||||||
|
|||||||
@@ -149,6 +149,7 @@ class ServerArgs:
|
|||||||
max_lora_rank: Optional[int] = None
|
max_lora_rank: Optional[int] = None
|
||||||
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
||||||
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
|
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
|
||||||
|
max_loaded_loras: Optional[int] = None
|
||||||
max_loras_per_batch: int = 8
|
max_loras_per_batch: int = 8
|
||||||
lora_backend: str = "triton"
|
lora_backend: str = "triton"
|
||||||
|
|
||||||
@@ -1237,6 +1238,12 @@ class ServerArgs:
|
|||||||
default=8,
|
default=8,
|
||||||
help="Maximum number of adapters for a running batch, include base-only request.",
|
help="Maximum number of adapters for a running batch, include base-only request.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-loaded-loras",
|
||||||
|
type=int,
|
||||||
|
default=ServerArgs.max_loaded_loras,
|
||||||
|
help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-backend",
|
"--lora-backend",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -2008,6 +2015,19 @@ class ServerArgs:
|
|||||||
self.max_lora_rank and self.lora_target_modules
|
self.max_lora_rank and self.lora_target_modules
|
||||||
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
|
), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
|
||||||
|
|
||||||
|
# Validate max_loaded_loras
|
||||||
|
if self.max_loaded_loras is not None:
|
||||||
|
assert self.max_loaded_loras >= self.max_loras_per_batch, (
|
||||||
|
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
|
||||||
|
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
|
||||||
|
), (
|
||||||
|
"The number of LoRA paths should not exceed max_loaded_loras. "
|
||||||
|
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
|
||||||
|
)
|
||||||
|
|
||||||
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
|
||||||
larger_tp = max(decode_tp, prefill_tp)
|
larger_tp = max(decode_tp, prefill_tp)
|
||||||
smaller_tp = min(decode_tp, prefill_tp)
|
smaller_tp = min(decode_tp, prefill_tp)
|
||||||
|
|||||||
@@ -514,6 +514,7 @@ class SRTRunner:
|
|||||||
max_lora_rank: Optional[int] = None,
|
max_lora_rank: Optional[int] = None,
|
||||||
lora_target_modules: Optional[List[str]] = None,
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
enable_lora: Optional[bool] = None,
|
enable_lora: Optional[bool] = None,
|
||||||
|
max_loaded_loras: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.is_generation = model_type == "generation"
|
self.is_generation = model_type == "generation"
|
||||||
@@ -556,6 +557,7 @@ class SRTRunner:
|
|||||||
max_lora_rank=max_lora_rank,
|
max_lora_rank=max_lora_rank,
|
||||||
lora_target_modules=lora_target_modules,
|
lora_target_modules=lora_target_modules,
|
||||||
enable_lora=enable_lora,
|
enable_lora=enable_lora,
|
||||||
|
max_loaded_loras=max_loaded_loras,
|
||||||
**spec_kwargs,
|
**spec_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ class TestCase:
|
|||||||
max_lora_rank: Optional[int] = None
|
max_lora_rank: Optional[int] = None
|
||||||
lora_target_modules: Optional[List] = None
|
lora_target_modules: Optional[List] = None
|
||||||
max_new_tokens: int = 32
|
max_new_tokens: int = 32
|
||||||
|
max_loaded_loras: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
|
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
|
||||||
@@ -559,7 +560,43 @@ MAX_LORA_RANK_TESTS = [
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
|
MAX_LOADED_LORAS_TESTS = [
|
||||||
|
TestCase(
|
||||||
|
description="Test max_loaded_loras limit",
|
||||||
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
max_loras_per_batch=2,
|
||||||
|
max_loaded_loras=2,
|
||||||
|
all_adapters=[
|
||||||
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
],
|
||||||
|
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
||||||
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
expected_error="Maximum number of loaded LoRA adapters",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.UNLOAD,
|
||||||
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
ALL_TESTS = (
|
||||||
|
BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class LoRAUpdateTestSessionMode(Enum):
|
class LoRAUpdateTestSessionMode(Enum):
|
||||||
@@ -579,6 +616,7 @@ class LoRAUpdateTestSessionBase:
|
|||||||
model_path: str,
|
model_path: str,
|
||||||
lora_paths: list[str],
|
lora_paths: list[str],
|
||||||
max_loras_per_batch: int,
|
max_loras_per_batch: int,
|
||||||
|
max_loaded_loras: Optional[int] = None,
|
||||||
max_lora_rank: Optional[int],
|
max_lora_rank: Optional[int],
|
||||||
enable_lora: Optional[bool] = None,
|
enable_lora: Optional[bool] = None,
|
||||||
lora_target_modules: Optional[List[str]] = None,
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
@@ -592,6 +630,7 @@ class LoRAUpdateTestSessionBase:
|
|||||||
self.max_lora_rank = max_lora_rank
|
self.max_lora_rank = max_lora_rank
|
||||||
self.lora_target_modules = lora_target_modules
|
self.lora_target_modules = lora_target_modules
|
||||||
self.max_loras_per_batch = max_loras_per_batch
|
self.max_loras_per_batch = max_loras_per_batch
|
||||||
|
self.max_loaded_loras = max_loaded_loras
|
||||||
self.lora_backend = lora_backend
|
self.lora_backend = lora_backend
|
||||||
self.disable_cuda_graph = disable_cuda_graph
|
self.disable_cuda_graph = disable_cuda_graph
|
||||||
self.cuda_graph_max_bs = cuda_graph_max_bs
|
self.cuda_graph_max_bs = cuda_graph_max_bs
|
||||||
@@ -654,6 +693,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
|||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
mem_fraction_static=MEM_FRACTION_STATIC,
|
mem_fraction_static=MEM_FRACTION_STATIC,
|
||||||
max_loras_per_batch=self.max_loras_per_batch,
|
max_loras_per_batch=self.max_loras_per_batch,
|
||||||
|
max_loaded_loras=self.max_loaded_loras,
|
||||||
disable_cuda_graph=self.disable_cuda_graph,
|
disable_cuda_graph=self.disable_cuda_graph,
|
||||||
cuda_graph_max_bs=self.cuda_graph_max_bs,
|
cuda_graph_max_bs=self.cuda_graph_max_bs,
|
||||||
disable_radix_cache=True,
|
disable_radix_cache=True,
|
||||||
@@ -774,6 +814,8 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
|
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
|
||||||
if self.lora_target_modules is not None:
|
if self.lora_target_modules is not None:
|
||||||
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
|
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
|
||||||
|
if self.max_loaded_loras is not None:
|
||||||
|
other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)])
|
||||||
|
|
||||||
# launch external server
|
# launch external server
|
||||||
self.handle = popen_launch_server(
|
self.handle = popen_launch_server(
|
||||||
@@ -898,8 +940,9 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
mode: LoRAUpdateTestSessionMode,
|
mode: LoRAUpdateTestSessionMode,
|
||||||
base: str,
|
base: str,
|
||||||
initial_adapters: List[str],
|
initial_adapters: List[str],
|
||||||
max_loras_per_batch: int,
|
|
||||||
op_sequence: List[Operation],
|
op_sequence: List[Operation],
|
||||||
|
max_loras_per_batch: int,
|
||||||
|
max_loaded_loras: Optional[int] = None,
|
||||||
enable_lora: Optional[bool] = None,
|
enable_lora: Optional[bool] = None,
|
||||||
max_lora_rank: Optional[int] = None,
|
max_lora_rank: Optional[int] = None,
|
||||||
lora_target_modules: Optional[List[str]] = None,
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
@@ -917,6 +960,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
model_path=base,
|
model_path=base,
|
||||||
lora_paths=initial_adapters,
|
lora_paths=initial_adapters,
|
||||||
max_loras_per_batch=max_loras_per_batch,
|
max_loras_per_batch=max_loras_per_batch,
|
||||||
|
max_loaded_loras=max_loaded_loras,
|
||||||
max_lora_rank=max_lora_rank,
|
max_lora_rank=max_lora_rank,
|
||||||
lora_target_modules=lora_target_modules,
|
lora_target_modules=lora_target_modules,
|
||||||
enable_lora=enable_lora,
|
enable_lora=enable_lora,
|
||||||
@@ -972,6 +1016,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
enable_lora=test_case.enable_lora,
|
enable_lora=test_case.enable_lora,
|
||||||
base=test_case.base,
|
base=test_case.base,
|
||||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||||
|
max_loaded_loras=test_case.max_loaded_loras,
|
||||||
op_sequence=test_case.op_sequence,
|
op_sequence=test_case.op_sequence,
|
||||||
max_new_tokens=test_case.max_new_tokens,
|
max_new_tokens=test_case.max_new_tokens,
|
||||||
max_lora_rank=test_case.max_lora_rank,
|
max_lora_rank=test_case.max_lora_rank,
|
||||||
@@ -985,6 +1030,12 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
if x.type == OperationType.FORWARD and x.expected_error is None
|
if x.type == OperationType.FORWARD and x.expected_error is None
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if not forward_ops:
|
||||||
|
print(
|
||||||
|
f"No forward operations found in test case {case_idx}. Skipping static pass."
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
|
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
|
||||||
static_output = self._run_operation_sequence(
|
static_output = self._run_operation_sequence(
|
||||||
|
|||||||
Reference in New Issue
Block a user