Support limiting the maximum number of LoRAs loaded on CPU. (#8650)

This commit is contained in:
Lifu Huang
2025-08-03 00:02:23 -07:00
committed by GitHub
parent a437aa9987
commit 8675bdf246
8 changed files with 163 additions and 58 deletions

View File

@@ -70,6 +70,7 @@ class TestCase:
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List] = None
max_new_tokens: int = 32
max_loaded_loras: Optional[int] = None
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
@@ -559,7 +560,43 @@ MAX_LORA_RANK_TESTS = [
],
),
]
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
# Regression tests for the max_loaded_loras limit: with at most 2 adapters
# resident, loading a third must be rejected until one is unloaded.
MAX_LOADED_LORAS_TESTS = [
    TestCase(
        description="Test max_loaded_loras limit",
        base="meta-llama/Llama-3.1-8B-Instruct",
        max_loras_per_batch=2,
        max_loaded_loras=2,  # cap on concurrently loaded adapters under test
        all_adapters=[
            "philschmid/code-llama-3-1-8b-text-to-sql-lora",
            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            "pbevan11/llama-3.1-8b-ocr-correction",
        ],
        # One adapter is pre-loaded at startup, leaving room for exactly one more.
        initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
        op_sequence=[
            # Second load fills the pool to its max_loaded_loras=2 capacity.
            Operation(
                type=OperationType.LOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            # A third load exceeds the cap and must fail with the capacity error.
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
                expected_error="Maximum number of loaded LoRA adapters",
            ),
            # Unloading one adapter frees a slot ...
            Operation(
                type=OperationType.UNLOAD,
                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
            ),
            # ... so the previously rejected adapter now loads successfully.
            Operation(
                type=OperationType.LOAD,
                data="pbevan11/llama-3.1-8b-ocr-correction",
            ),
        ],
    ),
]
# Combined suite: every test-case group defined above, run by the same driver.
ALL_TESTS = (
    BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
)
class LoRAUpdateTestSessionMode(Enum):
@@ -579,6 +616,7 @@ class LoRAUpdateTestSessionBase:
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int,
max_loaded_loras: Optional[int] = None,
max_lora_rank: Optional[int],
enable_lora: Optional[bool] = None,
lora_target_modules: Optional[List[str]] = None,
@@ -592,6 +630,7 @@ class LoRAUpdateTestSessionBase:
self.max_lora_rank = max_lora_rank
self.lora_target_modules = lora_target_modules
self.max_loras_per_batch = max_loras_per_batch
self.max_loaded_loras = max_loaded_loras
self.lora_backend = lora_backend
self.disable_cuda_graph = disable_cuda_graph
self.cuda_graph_max_bs = cuda_graph_max_bs
@@ -654,6 +693,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
torch_dtype=torch.float16,
mem_fraction_static=MEM_FRACTION_STATIC,
max_loras_per_batch=self.max_loras_per_batch,
max_loaded_loras=self.max_loaded_loras,
disable_cuda_graph=self.disable_cuda_graph,
cuda_graph_max_bs=self.cuda_graph_max_bs,
disable_radix_cache=True,
@@ -774,6 +814,8 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
if self.lora_target_modules is not None:
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
if self.max_loaded_loras is not None:
other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)])
# launch external server
self.handle = popen_launch_server(
@@ -898,8 +940,9 @@ class TestLoRADynamicUpdate(CustomTestCase):
mode: LoRAUpdateTestSessionMode,
base: str,
initial_adapters: List[str],
max_loras_per_batch: int,
op_sequence: List[Operation],
max_loras_per_batch: int,
max_loaded_loras: Optional[int] = None,
enable_lora: Optional[bool] = None,
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
@@ -917,6 +960,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
model_path=base,
lora_paths=initial_adapters,
max_loras_per_batch=max_loras_per_batch,
max_loaded_loras=max_loaded_loras,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
enable_lora=enable_lora,
@@ -972,6 +1016,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
enable_lora=test_case.enable_lora,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
max_loaded_loras=test_case.max_loaded_loras,
op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens,
max_lora_rank=test_case.max_lora_rank,
@@ -985,6 +1030,12 @@ class TestLoRADynamicUpdate(CustomTestCase):
if x.type == OperationType.FORWARD and x.expected_error is None
]
if not forward_ops:
print(
f"No forward operations found in test case {case_idx}. Skipping static pass."
)
continue
print("=" * 100)
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
static_output = self._run_operation_sequence(