Support GPU pinning for LoRA (#8697)

This commit is contained in:
Lifu Huang
2025-08-06 19:39:45 -07:00
committed by GitHub
parent 6ad6c8c9e6
commit 6210e2c4f0
13 changed files with 425 additions and 134 deletions

View File

@@ -231,88 +231,6 @@ BASIC_TESTS = [
),
],
),
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1,
all_adapters=[
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
],
),
]
TARGET_MODULE_TESTS = [
TestCase(
@@ -593,9 +511,135 @@ MAX_LOADED_LORAS_TESTS = [
],
),
]
EVICTION_TESTS = [
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=2,
all_adapters=[
"lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora",
"lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"lora3=pbevan11/llama-3.1-8b-ocr-correction",
],
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
op_sequence=[
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": True,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora2",
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pinned": True,
},
expected_error="starvation",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora2",
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pinned": False,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora3",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": False,
},
),
Operation(
type=OperationType.UNLOAD,
data="lora1",
),
Operation(
type=OperationType.UNLOAD,
data="lora3",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora3",
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
"pinned": True,
},
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": True,
},
expected_error="starvation",
),
Operation(
type=OperationType.LOAD,
data={
"lora_name": "lora1",
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
"pinned": False,
},
),
# pinned: lora3
# unpinned: lora1, lora2
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora3",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
]
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"lora1",
"lora2",
None,
]
),
),
],
),
]
# Aggregate of every scenario run by TestLoRADynamicUpdate.
#
# NOTE(review): the original span contained BOTH the pre-change one-line sum
# and the post-change multi-line sum (a diff/merge artifact with the +/-
# markers stripped), which is not valid Python and would double-count the
# first four suites. Keep only the final form, which includes EVICTION_TESTS.
ALL_TESTS = (
    BASIC_TESTS
    + TARGET_MODULE_TESTS
    + MAX_LORA_RANK_TESTS
    + MAX_LOADED_LORAS_TESTS
    + EVICTION_TESTS
)
@@ -714,6 +758,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
pinned: bool = False,
):
"""
Load a LoRA adapter by name and path.
@@ -724,17 +769,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
response = self.handle.load_lora_adapter(
lora_name=lora_name,
lora_path=lora_path,
pinned=pinned,
)
if expected_error:
self.testcase.assertFalse(response.success)
self.testcase.assertIn(expected_error, response.error_message)
self.testcase.assertFalse(
response.success, f"Expected failure for {lora_name}, but got success."
)
self.testcase.assertIn(
expected_error,
response.error_message,
f"Expected error message to contain '{expected_error}', but got '{response.error_message}'",
)
print(f"Received error as expected: {response.error_message}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.success)
self.testcase.assertTrue(
response.success,
f"Failed to load LoRA adapter {lora_name}: {response.error_message}",
)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def unload_lora_adapter(self, lora_name: str):
"""
@@ -745,11 +804,18 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
response = self.handle.unload_lora_adapter(
lora_name=lora_name,
)
self.testcase.assertTrue(response.success)
self.testcase.assertTrue(
response.success,
f"Failed to unload LoRA adapter {lora_name}: {response.error_message}",
)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def forward(
self,
@@ -770,13 +836,21 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
except ValueError as e:
if expected_error:
error_message = str(e)
self.testcase.assertIn(expected_error, error_message)
self.testcase.assertIn(
expected_error,
error_message,
f"Expected error message to contain '{expected_error}', but got '{error_message}'",
)
print(f"Received error as expected: {error_message}")
return error_message
raise e
self.testcase.assertEqual(len(response.output_strs), len(prompts))
self.testcase.assertEqual(
len(response.output_strs),
len(prompts),
f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}",
)
output = response.output_strs
print(f"output_strs: {output}")
@@ -837,6 +911,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
pinned: bool = False,
):
"""
Load a LoRA adapter by name and path.
@@ -846,18 +921,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path},
json={"lora_name": lora_name, "lora_path": lora_path, "pinned": pinned},
)
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
self.testcase.assertEqual(
response.status_code,
400,
f"Expected error for {lora_name}, but got success.",
)
self.testcase.assertIn(
expected_error,
response.text,
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
)
print(f"Received error as expected: {response.text}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.ok)
self.testcase.assertTrue(
response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}"
)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def unload_lora_adapter(self, lora_name: str):
"""
@@ -869,11 +958,17 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
json={"lora_name": lora_name},
)
self.testcase.assertTrue(response.ok)
self.testcase.assertTrue(
response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}"
)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
self.testcase.assertEqual(
loaded_adapters,
self.expected_adapters,
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
)
def forward(
self,
@@ -898,15 +993,29 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
},
)
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
self.testcase.assertEqual(
response.status_code,
400,
f"Expected error for forward pass, but got success: {response.text}",
)
self.testcase.assertIn(
expected_error,
response.text,
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
)
output = response.text
print(f"Received error as expected: {response.text}")
return output
else:
self.testcase.assertTrue(response.ok)
self.testcase.assertTrue(
response.ok, f"Failed to generate text: {response.text}"
)
output = [r["text"] for r in response.json()]
self.testcase.assertEqual(len(output), len(prompts))
self.testcase.assertEqual(
len(output),
len(prompts),
f"Expected {len(prompts)} outputs, but got {len(output)}",
)
print(f"output_strs: {output}")
return output
@@ -974,10 +1083,18 @@ class TestLoRADynamicUpdate(CustomTestCase):
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
)
if op_type == OperationType.LOAD:
if isinstance(data, str):
adapter_info = {
"lora_name": data,
"lora_path": data,
"pinned": False,
}
else:
adapter_info = data
result = session.load_lora_adapter(
lora_name=data,
lora_path=data,
expected_error=expected_error,
**adapter_info,
)
elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter(