Support GPU pinning for LoRA (#8697)
This commit is contained in:
@@ -231,88 +231,6 @@ BASIC_TESTS = [
|
||||
),
|
||||
],
|
||||
),
|
||||
TestCase(
|
||||
description="dynamic lora update with evictions",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=1,
|
||||
all_adapters=[
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
],
|
||||
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.UNLOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
TARGET_MODULE_TESTS = [
|
||||
TestCase(
|
||||
@@ -593,9 +511,135 @@ MAX_LOADED_LORAS_TESTS = [
|
||||
],
|
||||
),
|
||||
]
|
||||
EVICTION_TESTS = [
|
||||
TestCase(
|
||||
description="dynamic lora update with evictions",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=2,
|
||||
all_adapters=[
|
||||
"lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"lora3=pbevan11/llama-3.1-8b-ocr-correction",
|
||||
],
|
||||
enable_lora=True,
|
||||
max_lora_rank=256,
|
||||
lora_target_modules=["all"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora1",
|
||||
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"pinned": True,
|
||||
},
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora2",
|
||||
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"pinned": True,
|
||||
},
|
||||
expected_error="starvation",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora2",
|
||||
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"pinned": False,
|
||||
},
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora3",
|
||||
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
||||
"pinned": False,
|
||||
},
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.UNLOAD,
|
||||
data="lora1",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.UNLOAD,
|
||||
data="lora3",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora3",
|
||||
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
||||
"pinned": True,
|
||||
},
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora1",
|
||||
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"pinned": True,
|
||||
},
|
||||
expected_error="starvation",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data={
|
||||
"lora_name": "lora1",
|
||||
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"pinned": False,
|
||||
},
|
||||
),
|
||||
# Adapter state at this point -- pinned: lora3
|
||||
# unpinned: lora1, lora2 (both loaded with "pinned": False)
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"lora1",
|
||||
"lora2",
|
||||
]
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"lora1",
|
||||
"lora3",
|
||||
]
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"lora1",
|
||||
"lora2",
|
||||
]
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"lora1",
|
||||
"lora2",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
ALL_TESTS = (
|
||||
BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS + MAX_LOADED_LORAS_TESTS
|
||||
BASIC_TESTS
|
||||
+ TARGET_MODULE_TESTS
|
||||
+ MAX_LORA_RANK_TESTS
|
||||
+ MAX_LOADED_LORAS_TESTS
|
||||
+ EVICTION_TESTS
|
||||
)
|
||||
|
||||
|
||||
@@ -714,6 +758,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
lora_name: str,
|
||||
lora_path: Optional[str] = None,
|
||||
expected_error: Optional[str] = None,
|
||||
pinned: bool = False,
|
||||
):
|
||||
"""
|
||||
Load a LoRA adapter by name and path.
|
||||
@@ -724,17 +769,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
response = self.handle.load_lora_adapter(
|
||||
lora_name=lora_name,
|
||||
lora_path=lora_path,
|
||||
pinned=pinned,
|
||||
)
|
||||
if expected_error:
|
||||
self.testcase.assertFalse(response.success)
|
||||
self.testcase.assertIn(expected_error, response.error_message)
|
||||
self.testcase.assertFalse(
|
||||
response.success, f"Expected failure for {lora_name}, but got success."
|
||||
)
|
||||
self.testcase.assertIn(
|
||||
expected_error,
|
||||
response.error_message,
|
||||
f"Expected error message to contain '{expected_error}', but got '{response.error_message}'",
|
||||
)
|
||||
print(f"Received error as expected: {response.error_message}")
|
||||
else:
|
||||
self.expected_adapters.add(lora_name)
|
||||
self.testcase.assertTrue(response.success)
|
||||
self.testcase.assertTrue(
|
||||
response.success,
|
||||
f"Failed to load LoRA adapter {lora_name}: {response.error_message}",
|
||||
)
|
||||
loaded_adapters = set(response.loaded_adapters)
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
self.testcase.assertEqual(
|
||||
loaded_adapters,
|
||||
self.expected_adapters,
|
||||
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
"""
|
||||
@@ -745,11 +804,18 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
response = self.handle.unload_lora_adapter(
|
||||
lora_name=lora_name,
|
||||
)
|
||||
self.testcase.assertTrue(response.success)
|
||||
self.testcase.assertTrue(
|
||||
response.success,
|
||||
f"Failed to unload LoRA adapter {lora_name}: {response.error_message}",
|
||||
)
|
||||
loaded_adapters = set(response.loaded_adapters)
|
||||
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
self.testcase.assertEqual(
|
||||
loaded_adapters,
|
||||
self.expected_adapters,
|
||||
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -770,13 +836,21 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
except ValueError as e:
|
||||
if expected_error:
|
||||
error_message = str(e)
|
||||
self.testcase.assertIn(expected_error, error_message)
|
||||
self.testcase.assertIn(
|
||||
expected_error,
|
||||
error_message,
|
||||
f"Expected error message to contain '{expected_error}', but got '{error_message}'",
|
||||
)
|
||||
print(f"Received error as expected: {error_message}")
|
||||
return error_message
|
||||
|
||||
raise e
|
||||
|
||||
self.testcase.assertEqual(len(response.output_strs), len(prompts))
|
||||
self.testcase.assertEqual(
|
||||
len(response.output_strs),
|
||||
len(prompts),
|
||||
f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}",
|
||||
)
|
||||
output = response.output_strs
|
||||
print(f"output_strs: {output}")
|
||||
|
||||
@@ -837,6 +911,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
lora_name: str,
|
||||
lora_path: Optional[str] = None,
|
||||
expected_error: Optional[str] = None,
|
||||
pinned: bool = False,
|
||||
):
|
||||
"""
|
||||
Load a LoRA adapter by name and path.
|
||||
@@ -846,18 +921,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
|
||||
response = requests.post(
|
||||
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
|
||||
json={"lora_name": lora_name, "lora_path": lora_path},
|
||||
json={"lora_name": lora_name, "lora_path": lora_path, "pinned": pinned},
|
||||
)
|
||||
if expected_error:
|
||||
self.testcase.assertEqual(response.status_code, 400)
|
||||
self.testcase.assertIn(expected_error, response.text)
|
||||
self.testcase.assertEqual(
|
||||
response.status_code,
|
||||
400,
|
||||
f"Expected error for {lora_name}, but got success.",
|
||||
)
|
||||
self.testcase.assertIn(
|
||||
expected_error,
|
||||
response.text,
|
||||
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
|
||||
)
|
||||
print(f"Received error as expected: {response.text}")
|
||||
else:
|
||||
self.expected_adapters.add(lora_name)
|
||||
self.testcase.assertTrue(response.ok)
|
||||
self.testcase.assertTrue(
|
||||
response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}"
|
||||
)
|
||||
loaded_adapters = set(response.json()["loaded_adapters"])
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
self.testcase.assertEqual(
|
||||
loaded_adapters,
|
||||
self.expected_adapters,
|
||||
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
"""
|
||||
@@ -869,11 +958,17 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
|
||||
json={"lora_name": lora_name},
|
||||
)
|
||||
self.testcase.assertTrue(response.ok)
|
||||
self.testcase.assertTrue(
|
||||
response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}"
|
||||
)
|
||||
loaded_adapters = set(response.json()["loaded_adapters"])
|
||||
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
self.testcase.assertEqual(
|
||||
loaded_adapters,
|
||||
self.expected_adapters,
|
||||
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -898,15 +993,29 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
},
|
||||
)
|
||||
if expected_error:
|
||||
self.testcase.assertEqual(response.status_code, 400)
|
||||
self.testcase.assertIn(expected_error, response.text)
|
||||
self.testcase.assertEqual(
|
||||
response.status_code,
|
||||
400,
|
||||
f"Expected error for forward pass, but got success: {response.text}",
|
||||
)
|
||||
self.testcase.assertIn(
|
||||
expected_error,
|
||||
response.text,
|
||||
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
|
||||
)
|
||||
output = response.text
|
||||
print(f"Received error as expected: {response.text}")
|
||||
return output
|
||||
else:
|
||||
self.testcase.assertTrue(response.ok)
|
||||
self.testcase.assertTrue(
|
||||
response.ok, f"Failed to generate text: {response.text}"
|
||||
)
|
||||
output = [r["text"] for r in response.json()]
|
||||
self.testcase.assertEqual(len(output), len(prompts))
|
||||
self.testcase.assertEqual(
|
||||
len(output),
|
||||
len(prompts),
|
||||
f"Expected {len(prompts)} outputs, but got {len(output)}",
|
||||
)
|
||||
print(f"output_strs: {output}")
|
||||
return output
|
||||
|
||||
@@ -974,10 +1083,18 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
|
||||
)
|
||||
if op_type == OperationType.LOAD:
|
||||
if isinstance(data, str):
|
||||
adapter_info = {
|
||||
"lora_name": data,
|
||||
"lora_path": data,
|
||||
"pinned": False,
|
||||
}
|
||||
else:
|
||||
adapter_info = data
|
||||
|
||||
result = session.load_lora_adapter(
|
||||
lora_name=data,
|
||||
lora_path=data,
|
||||
expected_error=expected_error,
|
||||
**adapter_info,
|
||||
)
|
||||
elif op_type == OperationType.UNLOAD:
|
||||
result = session.unload_lora_adapter(
|
||||
|
||||
Reference in New Issue
Block a user