Support pinning adapter via server args. (#9249)
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import json
|
||||
import multiprocessing as mp
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
@@ -89,8 +90,35 @@ BASIC_TESTS = [
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
],
|
||||
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
||||
initial_adapters=[
|
||||
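# Exercise all three supported lora-path formats: a bare path, a "name=path" string, and a dict with "lora_name", "lora_path", and "pinned" keys.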
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
{
|
||||
"lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
|
||||
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
||||
"pinned": False,
|
||||
},
|
||||
],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
]
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.UNLOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.UNLOAD,
|
||||
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
@@ -147,6 +175,10 @@ BASIC_TESTS = [
|
||||
type=OperationType.UNLOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.UNLOAD,
|
||||
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
@@ -157,18 +189,12 @@ BASIC_TESTS = [
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
]
|
||||
None,
|
||||
),
|
||||
),
|
||||
],
|
||||
@@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase:
|
||||
*,
|
||||
testcase: Optional[TestCase],
|
||||
model_path: str,
|
||||
lora_paths: list[str],
|
||||
lora_paths: List[Union[str, dict]],
|
||||
max_loras_per_batch: int,
|
||||
max_loaded_loras: Optional[int] = None,
|
||||
max_lora_rank: Optional[int],
|
||||
@@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase:
|
||||
self.cuda_graph_max_bs = cuda_graph_max_bs
|
||||
self.enable_lora = enable_lora
|
||||
|
||||
self.expected_adapters = set(lora_paths or [])
|
||||
self.expected_adapters = set()
|
||||
if self.lora_paths:
|
||||
for adapter in self.lora_paths:
|
||||
if isinstance(adapter, dict):
|
||||
lora_name = adapter["lora_name"]
|
||||
elif "=" in adapter:
|
||||
lora_name = adapter.split("=")[0]
|
||||
else:
|
||||
lora_name = adapter
|
||||
self.expected_adapters.add(lora_name)
|
||||
|
||||
self.handle = None # Will be set in __enter__
|
||||
|
||||
def __enter__(self):
|
||||
@@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
if self.enable_lora:
|
||||
other_args.append("--enable-lora")
|
||||
if self.lora_paths:
|
||||
other_args.extend(["--lora-paths"] + self.lora_paths)
|
||||
other_args.append("--lora-paths")
|
||||
for lora_path in self.lora_paths:
|
||||
if isinstance(lora_path, dict):
|
||||
lora_path = json.dumps(lora_path)
|
||||
other_args.append(lora_path)
|
||||
if self.disable_cuda_graph:
|
||||
other_args.append("--disable-cuda-graph")
|
||||
if self.max_lora_rank is not None:
|
||||
@@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
self,
|
||||
mode: LoRAUpdateTestSessionMode,
|
||||
base: str,
|
||||
initial_adapters: List[str],
|
||||
initial_adapters: List[Union[str, dict]],
|
||||
op_sequence: List[Operation],
|
||||
max_loras_per_batch: int,
|
||||
max_loaded_loras: Optional[int] = None,
|
||||
|
||||
Reference in New Issue
Block a user