Support pinning adapter via server args. (#9249)

Lifu Huang
2025-08-20 16:25:01 -07:00
committed by GitHub
parent 24eaebeb4b
commit b0980af89f
8 changed files with 162 additions and 55 deletions


@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
+import json
 import multiprocessing as mp
 import unittest
 from dataclasses import dataclass
@@ -89,8 +90,35 @@ BASIC_TESTS = [
         "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
         "pbevan11/llama-3.1-8b-ocr-correction",
     ],
-    initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
+    initial_adapters=[
+        # Testing 3 supported lora-path formats.
+        "philschmid/code-llama-3-1-8b-text-to-sql-lora",
+        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+        {
+            "lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
+            "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
+            "pinned": False,
+        },
+    ],
     op_sequence=[
+        Operation(
+            type=OperationType.FORWARD,
+            data=create_batch_data(
+                [
+                    "philschmid/code-llama-3-1-8b-text-to-sql-lora",
+                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                    "pbevan11/llama-3.1-8b-ocr-correction",
+                ]
+            ),
+        ),
+        Operation(
+            type=OperationType.UNLOAD,
+            data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+        ),
+        Operation(
+            type=OperationType.UNLOAD,
+            data="pbevan11/llama-3.1-8b-ocr-correction",
+        ),
         Operation(
             type=OperationType.FORWARD,
             data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
@@ -147,6 +175,10 @@ BASIC_TESTS = [
         Operation(
             type=OperationType.UNLOAD,
             data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
         ),
+        Operation(
+            type=OperationType.UNLOAD,
+            data="pbevan11/llama-3.1-8b-ocr-correction",
+        ),
         Operation(
             type=OperationType.FORWARD,
             data=create_batch_data(
@@ -157,18 +189,12 @@ BASIC_TESTS = [
         Operation(
             type=OperationType.FORWARD,
             data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
-        ),
-        Operation(
-            type=OperationType.LOAD,
-            data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
+            expected_error="not loaded",
         ),
         Operation(
             type=OperationType.FORWARD,
             data=create_batch_data(
-                [
-                    "philschmid/code-llama-3-1-8b-text-to-sql-lora",
-                    "pbevan11/llama-3.1-8b-ocr-correction",
-                ]
+                None,
             ),
         ),
     ],
@@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase:
         *,
         testcase: Optional[TestCase],
         model_path: str,
-        lora_paths: list[str],
+        lora_paths: List[Union[str, dict]],
         max_loras_per_batch: int,
         max_loaded_loras: Optional[int] = None,
         max_lora_rank: Optional[int],
@@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase:
         self.cuda_graph_max_bs = cuda_graph_max_bs
         self.enable_lora = enable_lora
-        self.expected_adapters = set(lora_paths or [])
+        self.expected_adapters = set()
+        if self.lora_paths:
+            for adapter in self.lora_paths:
+                if isinstance(adapter, dict):
+                    lora_name = adapter["lora_name"]
+                elif "=" in adapter:
+                    lora_name = adapter.split("=")[0]
+                else:
+                    lora_name = adapter
+                self.expected_adapters.add(lora_name)
+
         self.handle = None  # Will be set in __enter__

     def __enter__(self):
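The name-extraction rule added above, restated as a standalone, runnable sketch (the helper name extract_lora_name is ours, not the repo's):

from typing import Union

def extract_lora_name(adapter: Union[str, dict]) -> str:
    # Dict/JSON form: the name is given explicitly.
    if isinstance(adapter, dict):
        return adapter["lora_name"]
    # "name=path" form: everything before the first "=" is the name.
    if "=" in adapter:
        return adapter.split("=")[0]
    # Bare path form: the path doubles as the name.
    return adapter

assert extract_lora_name("my-adapter=org/repo") == "my-adapter"
assert extract_lora_name("org/repo") == "org/repo"
assert extract_lora_name({"lora_name": "x", "lora_path": "org/repo"}) == "x"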
@@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
         if self.enable_lora:
             other_args.append("--enable-lora")
         if self.lora_paths:
-            other_args.extend(["--lora-paths"] + self.lora_paths)
+            other_args.append("--lora-paths")
+            for lora_path in self.lora_paths:
+                if isinstance(lora_path, dict):
+                    lora_path = json.dumps(lora_path)
+                other_args.append(lora_path)
         if self.disable_cuda_graph:
             other_args.append("--disable-cuda-graph")
         if self.max_lora_rank is not None:
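To see what this serialization produces, a hedged usage sketch (the lora_name value "ocr" is hypothetical; note json.dumps emits the lowercase JSON boolean):

import json

lora_paths = [
    "philschmid/code-llama-3-1-8b-text-to-sql-lora",
    {
        "lora_name": "ocr",  # hypothetical name
        "lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
        "pinned": True,
    },
]
other_args = ["--lora-paths"]
for lora_path in lora_paths:
    if isinstance(lora_path, dict):
        lora_path = json.dumps(lora_path)  # dict specs are passed as JSON strings
    other_args.append(lora_path)
print(other_args)
# ['--lora-paths',
#  'philschmid/code-llama-3-1-8b-text-to-sql-lora',
#  '{"lora_name": "ocr", "lora_path": "pbevan11/llama-3.1-8b-ocr-correction", "pinned": true}']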
@@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
         self,
         mode: LoRAUpdateTestSessionMode,
         base: str,
-        initial_adapters: List[str],
+        initial_adapters: List[Union[str, dict]],
         op_sequence: List[Operation],
         max_loras_per_batch: int,
         max_loaded_loras: Optional[int] = None,