Support limiting max loaded loras in CPU. (#8650)
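
In short: this adds a max_loaded_loras option that caps how many LoRA adapters may be resident at once. Loading past the cap is rejected until another adapter is unloaded; the tests below exercise both the engine and server paths. A minimal usage sketch, assuming sglang's offline Engine API (the arguments mirror what the test session passes; anything beyond this diff is an assumption):

    import sglang as sgl

    # Cap resident adapters at two; a third dynamic load is expected to
    # fail until one adapter is unloaded (see the test sequence below).
    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",
        lora_paths=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
        max_loras_per_batch=2,
        max_loaded_loras=2,
    )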
@@ -70,6 +70,7 @@ class TestCase:
     max_lora_rank: Optional[int] = None
     lora_target_modules: Optional[List] = None
     max_new_tokens: int = 32
+    max_loaded_loras: Optional[int] = None


 def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
@@ -559,7 +560,43 @@ MAX_LORA_RANK_TESTS = [
         ],
     ),
 ]
-ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
+MAX_LOADED_LORAS_TESTS = [
+    TestCase(
+        description="Test max_loaded_loras limit",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=2,
+        max_loaded_loras=2,
+        all_adapters=[
+            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            "pbevan11/llama-3.1-8b-ocr-correction",
+        ],
+        initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
+        op_sequence=[
+            Operation(
+                type=OperationType.LOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="pbevan11/llama-3.1-8b-ocr-correction",
+                expected_error="Maximum number of loaded LoRA adapters",
+            ),
+            Operation(
+                type=OperationType.UNLOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="pbevan11/llama-3.1-8b-ocr-correction",
+            ),
+        ],
+    ),
+]
+
+ALL_TESTS = (
+    BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
+)


 class LoRAUpdateTestSessionMode(Enum):
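
The sequence above loads a second adapter up to the cap, expects the third load to fail with the given error prefix, and retries successfully after an unload. A sketch of the kind of guard that produces that error, with hypothetical names (only the message prefix "Maximum number of loaded LoRA adapters" comes from this diff):

    from typing import Dict, Optional

    class LoRARegistry:
        """Illustrative stand-in for sglang's LoRA bookkeeping."""

        def __init__(self, max_loaded_loras: Optional[int] = None):
            self.max_loaded_loras = max_loaded_loras
            self.loaded: Dict[str, str] = {}  # adapter name -> path

        def load(self, name: str, path: str) -> None:
            # Reject loads once the cap is reached, matching the
            # expected_error prefix asserted by the test case.
            if (
                self.max_loaded_loras is not None
                and len(self.loaded) >= self.max_loaded_loras
            ):
                raise ValueError(
                    f"Maximum number of loaded LoRA adapters "
                    f"({self.max_loaded_loras}) reached."
                )
            self.loaded[name] = path

        def unload(self, name: str) -> None:
            self.loaded.pop(name, None)

Mirroring the op_sequence above:

    registry = LoRARegistry(max_loaded_loras=2)
    registry.load("sql", "philschmid/code-llama-3-1-8b-text-to-sql-lora")
    registry.load("base4", "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16")
    # registry.load("ocr", ...) would raise here; unload first:
    registry.unload("base4")
    registry.load("ocr", "pbevan11/llama-3.1-8b-ocr-correction")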
@@ -579,6 +616,7 @@ class LoRAUpdateTestSessionBase:
         model_path: str,
         lora_paths: list[str],
         max_loras_per_batch: int,
+        max_loaded_loras: Optional[int] = None,
         max_lora_rank: Optional[int],
         enable_lora: Optional[bool] = None,
         lora_target_modules: Optional[List[str]] = None,
@@ -592,6 +630,7 @@ class LoRAUpdateTestSessionBase:
         self.max_lora_rank = max_lora_rank
         self.lora_target_modules = lora_target_modules
         self.max_loras_per_batch = max_loras_per_batch
+        self.max_loaded_loras = max_loaded_loras
         self.lora_backend = lora_backend
         self.disable_cuda_graph = disable_cuda_graph
         self.cuda_graph_max_bs = cuda_graph_max_bs
@@ -654,6 +693,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
             torch_dtype=torch.float16,
             mem_fraction_static=MEM_FRACTION_STATIC,
             max_loras_per_batch=self.max_loras_per_batch,
+            max_loaded_loras=self.max_loaded_loras,
             disable_cuda_graph=self.disable_cuda_graph,
             cuda_graph_max_bs=self.cuda_graph_max_bs,
             disable_radix_cache=True,
@@ -774,6 +814,8 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
             other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
         if self.lora_target_modules is not None:
             other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
+        if self.max_loaded_loras is not None:
+            other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)])

         # launch external server
         self.handle = popen_launch_server(
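
For the server path, the new cap is forwarded as --max-loaded-loras at launch. A hedged sketch of driving the dynamic-load path once such a server is up; the /load_lora_adapter route and payload follow sglang's dynamic LoRA update API as this test file uses it, but are assumptions as far as this diff goes:

    import requests

    base_url = "http://127.0.0.1:30000"  # assumed local test server

    resp = requests.post(
        f"{base_url}/load_lora_adapter",
        json={
            "lora_name": "ocr",
            "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
        },
    )
    # With two adapters already resident and --max-loaded-loras 2, this is
    # expected to fail with a message starting
    # "Maximum number of loaded LoRA adapters".
    print(resp.status_code, resp.text)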
@@ -898,8 +940,9 @@ class TestLoRADynamicUpdate(CustomTestCase):
         mode: LoRAUpdateTestSessionMode,
         base: str,
         initial_adapters: List[str],
+        max_loras_per_batch: int,
         op_sequence: List[Operation],
-        max_loras_per_batch: int,
+        max_loaded_loras: Optional[int] = None,
         enable_lora: Optional[bool] = None,
         max_lora_rank: Optional[int] = None,
         lora_target_modules: Optional[List[str]] = None,
@@ -917,6 +960,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
             model_path=base,
             lora_paths=initial_adapters,
             max_loras_per_batch=max_loras_per_batch,
+            max_loaded_loras=max_loaded_loras,
             max_lora_rank=max_lora_rank,
             lora_target_modules=lora_target_modules,
             enable_lora=enable_lora,
@@ -972,6 +1016,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
             enable_lora=test_case.enable_lora,
             base=test_case.base,
             max_loras_per_batch=test_case.max_loras_per_batch,
+            max_loaded_loras=test_case.max_loaded_loras,
             op_sequence=test_case.op_sequence,
             max_new_tokens=test_case.max_new_tokens,
             max_lora_rank=test_case.max_lora_rank,
@@ -985,6 +1030,12 @@ class TestLoRADynamicUpdate(CustomTestCase):
             if x.type == OperationType.FORWARD and x.expected_error is None
         ]

+        if not forward_ops:
+            print(
+                f"No forward operations found in test case {case_idx}. Skipping static pass."
+            )
+            continue
+
         print("=" * 100)
         print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
         static_output = self._run_operation_sequence(