Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)

Authored by Lifu Huang on 2025-07-13 18:36:01 -07:00, committed by GitHub
parent b5dd5e8741, commit e2ed9d049a
10 changed files with 840 additions and 227 deletions
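The core harness change in this commit: the standalone EXPECT_ERROR operation type is removed, and failure expectations move into an optional expected_error field on Operation, so a FORWARD or LOAD op can assert either success or a specific error in place. A minimal sketch using the names defined in the test file below:

    Operation(
        type=OperationType.FORWARD,
        data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
        expected_error="not loaded",  # adapter not yet loaded, so the forward must fail
    )

Two new test groups, TARGET_MODULE_TESTS and MAX_LORA_RANK_TESTS, exercise the shape handling itself: loading an adapter whose target modules or rank exceed the configured (or inferred) pool shape must be rejected with an "updating LoRA shapes" error.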

models/lora/test_lora_update.py

@@ -16,7 +16,7 @@ import multiprocessing as mp
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Union
from typing import Any, Iterable, List, Optional, Union
import requests
import torch
@@ -27,6 +27,7 @@ from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
is_in_ci,
popen_launch_server,
)
@@ -45,24 +46,28 @@ class OperationType(Enum):
LOAD = "load"
UNLOAD = "unload"
FORWARD = "forward"
EXPECT_ERROR = "expect_error"
@dataclass
class Operation:
# Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
# Operation type, can be LOAD, UNLOAD, FORWARD
type: OperationType
# Data associated with the operation. Exact type varies depending on the operation
data: Optional[Any]
# If the operation is expected to fail, this is the error message to expect
expected_error: Optional[str] = None
@dataclass
class TestCase:
description: str
base: str
max_loras_per_batch: int
all_adapters: List[str]
initial_adapters: List[str]
op_sequence: List[Operation]
max_lora_rank: Optional[int] = None
lora_target_modules: Optional[List[str]] = None
max_new_tokens: int = 32
@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
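# Illustration (assuming the list form of the argument): with prompts p1, p2,
# create_batch_data(["a", None]) returns
# [(p1, "a"), (p1, None), (p2, "a"), (p2, None)],
# i.e. one (prompt, lora_path) pair per prompt/adapter combination.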
TEST_CASES = [
# basic test, no eviction
BASIC_TESTS = [
TestCase(
description="dynamic lora update with initial lora_paths",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
all_adapters=[
@@ -89,20 +94,16 @@ TEST_CASES = [
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
@@ -127,11 +128,9 @@ TEST_CASES = [
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
@@ -147,13 +146,11 @@ TEST_CASES = [
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
@@ -174,8 +171,8 @@ TEST_CASES = [
),
],
),
# Eviction
TestCase(
description="dynamic lora update with evictions",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=1,
all_adapters=[
@@ -190,20 +187,16 @@ TEST_CASES = [
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
@@ -214,11 +207,9 @@ TEST_CASES = [
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
@@ -263,6 +254,253 @@ TEST_CASES = [
],
),
]
TARGET_MODULE_TESTS = [
TestCase(
description="Test explicitly specified lora-target-modules.",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
lora_target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="Test inferred lora-target-modules - start with larger adapter",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="Test inferred lora-target-modules - start with smaller adapter",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=64,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
],
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
None,
]
),
),
],
),
]
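# Rough sketch of the target-module rule these cases exercise (semantics
# inferred from the expected errors above, not taken from the server code):
# without an explicit --lora-target-modules, the supported set is the union
# over the adapters present at startup, and a later load must fit inside it.
#
#   inferred = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"}  # algoprog adapter
#   incoming = inferred | {"up_proj", "down_proj"}                    # Nutanix adapter
#   incoming <= inferred  # False -> load rejected with "updating LoRA shapes"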
MAX_LORA_RANK_TESTS = [
TestCase(
description="Test explicitly specified max-lora-rank.",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
max_lora_rank=32,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
],
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"pbevan11/llama-3.1-8b-ocr-correction",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
expected_error="not loaded",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"pbevan11/llama-3.1-8b-ocr-correction",
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
None,
]
),
),
],
),
TestCase(
description="test implicitly inferred max-lora-rank",
base="meta-llama/Llama-3.1-8B-Instruct",
max_loras_per_batch=3,
all_adapters=[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
],
initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
),
Operation(
type=OperationType.LOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
expected_error="updating LoRA shapes",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
expected_error="not loaded",
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
[
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
"pbevan11/llama-3.1-8b-ocr-correction",
None,
]
),
),
],
),
]
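# Analogous sketch for the rank rule (again assumed semantics): the LoRA
# memory pool is sized by the explicit --max-lora-rank, or else by the
# largest rank among the initial adapters; loading a higher-rank adapter
# is rejected rather than resized.
#
#   max_lora_rank = 32  # explicit above, or inferred from pbevan11 (r = 32)
#   adapter_rank = 256  # philschmid/code-llama-3-1-8b-text-to-sql-lora
#   adapter_rank <= max_lora_rank  # False -> rejected with "updating LoRA shapes"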
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
class LoRAUpdateTestSessionMode(Enum):
@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase:
testcase: Optional[TestCase],
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int = 1,
max_loras_per_batch: int,
max_lora_rank: Optional[int],
lora_target_modules: Optional[List[str]] = None,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase:
self.testcase = testcase
self.model_path = model_path
self.lora_paths = lora_paths
self.max_lora_rank = max_lora_rank
self.lora_target_modules = lora_target_modules
self.max_loras_per_batch = max_loras_per_batch
self.lora_backend = lora_backend
self.disable_cuda_graph = disable_cuda_graph
@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase:
# Don't suppress exceptions by default
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
"""
Load a LoRA adapter by name and path.
"""
@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase:
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: Optional[str] = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
model_path=self.model_path,
model_type="generation",
lora_paths=self.lora_paths,
max_lora_rank=self.max_lora_rank,
lora_target_modules=self.lora_target_modules,
lora_backend=self.lora_backend,
torch_dtype=torch.float16,
mem_fraction_static=MEM_FRACTION_STATIC,
@@ -357,24 +607,32 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
# don't suppress exceptions
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
"""
Load a LoRA adapter by name and path.
"""
if lora_path is None:
lora_path = lora_name
self.expected_adapters.add(lora_name)
response = self.handle.load_lora_adapter(
lora_name=lora_name,
lora_path=lora_path,
)
self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
if expected_error:
self.testcase.assertFalse(response.success)
self.testcase.assertIn(expected_error, response.error_message)
print(f"Received error as expected: {response.error_message}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.success)
loaded_adapters = set(response.loaded_adapters)
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def unload_lora_adapter(self, lora_name: str):
"""
@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
expected_error: Optional[str] = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
]
if self.disable_cuda_graph:
other_args.append("--disable-cuda-graph")
if self.max_lora_rank is not None:
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
if self.lora_target_modules is not None:
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
# launch external server
self.handle = popen_launch_server(
@@ -464,24 +726,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
# don't suppress exceptions
return False
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
def load_lora_adapter(
self,
lora_name: str,
lora_path: Optional[str] = None,
expected_error: Optional[str] = None,
):
"""
Load a LoRA adapter by name and path.
"""
if lora_path is None:
lora_path = lora_name
self.expected_adapters.add(lora_name)
response = requests.post(
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
json={"lora_name": lora_name, "lora_path": lora_path},
)
self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
print(f"Received error as expected: {response.text}")
else:
self.expected_adapters.add(lora_name)
self.testcase.assertTrue(response.ok)
loaded_adapters = set(response.json()["loaded_adapters"])
print(f"loaded_adapters: {loaded_adapters}")
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
def unload_lora_adapter(self, lora_name: str):
"""
@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
expected_error: Optional[str] = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
# Factory function to create the appropriate LoRA test session based on mode
def LoRAUpdateTestSession(
*,
testcase: Optional[TestCase],
mode: LoRAUpdateTestSessionMode,
model_path: str,
lora_paths: list[str],
max_loras_per_batch: int = 1,
lora_backend: str = "triton",
disable_cuda_graph: bool = False,
cuda_graph_max_bs: int = 4,
**kwargs: Any,
):
common_kwargs = {
"testcase": testcase,
"model_path": model_path,
"lora_paths": lora_paths,
"max_loras_per_batch": max_loras_per_batch,
"lora_backend": lora_backend,
"disable_cuda_graph": disable_cuda_graph,
"cuda_graph_max_bs": cuda_graph_max_bs,
}
if mode == LoRAUpdateTestSessionMode.ENGINE:
return LoRAUpdateEngineTestSession(**common_kwargs)
return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
elif mode == LoRAUpdateTestSessionMode.SERVER:
return LoRAUpdateServerTestSession(**common_kwargs)
return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
else:
raise ValueError(f"Unrecognized mode: {mode!r}")
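# Hedged usage sketch of the factory (the context-manager protocol is implied
# by the __exit__ methods above; the argument values are illustrative):
#
#   with LoRAUpdateTestSession(
#       testcase=self,
#       mode=LoRAUpdateTestSessionMode.ENGINE,
#       model_path="meta-llama/Llama-3.1-8B-Instruct",
#       lora_paths=["some-initial-adapter"],
#       max_loras_per_batch=3,
#       max_lora_rank=64,
#   ) as session:
#       session.load_lora_adapter(lora_name="another-adapter")
#       session.forward(prompts=["..."], lora_paths=["another-adapter"])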
@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase):
initial_adapters: List[str],
max_loras_per_batch: int,
op_sequence: List[Operation],
max_lora_rank: Optional[int] = None,
lora_target_modules: Optional[List[str]] = None,
max_new_tokens: int = 32,
) -> List[tuple]:
"""
@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase):
model_path=base,
lora_paths=initial_adapters,
max_loras_per_batch=max_loras_per_batch,
max_lora_rank=max_lora_rank,
lora_target_modules=lora_target_modules,
) as session:
for op in op_sequence:
op_type = op.type
data = op.data
expected_error = op.expected_error
print("-" * 100)
print(
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
result = session.load_lora_adapter(
lora_name=data,
lora_path=data,
expected_error=expected_error,
)
elif op_type == OperationType.UNLOAD:
result = session.unload_lora_adapter(
@@ -615,91 +875,105 @@ class TestLoRADynamicUpdate(CustomTestCase):
)
elif op_type == OperationType.FORWARD:
prompts, adapters = zip(*data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
)
forward_outputs.append(result)
elif op_type == OperationType.EXPECT_ERROR:
input_data, expected_error = data
prompts, adapters = zip(*input_data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
expected_error=expected_error,
)
if not expected_error:
forward_outputs.append(result)
return forward_outputs
def test_dynamic_adapter_updates(self):
for case_idx, test_case in enumerate(TEST_CASES, start=1):
for mode in [
LoRAUpdateTestSessionMode.ENGINE,
LoRAUpdateTestSessionMode.SERVER,
]:
print("=" * 100)
print(f"Starting test case {case_idx} in {mode.value} mode.")
print("=" * 100)
def _run_dynamic_adapter_updates(
self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
):
for case_idx, test_case in enumerate(test_cases, start=1):
print("=" * 100)
print(
f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
)
print("=" * 100)
print(
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
)
# Test dynamic loading of adapters
# TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
# we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
dynamic_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.initial_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens,
)
print(
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
)
# Test dynamic loading of adapters
dynamic_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.initial_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=test_case.op_sequence,
max_new_tokens=test_case.max_new_tokens,
max_lora_rank=test_case.max_lora_rank,
lora_target_modules=test_case.lora_target_modules,
)
# static loading
forward_ops = [
x for x in test_case.op_sequence if x.type == OperationType.FORWARD
]
# static loading
forward_ops = [
x
for x in test_case.op_sequence
if x.type == OperationType.FORWARD and x.expected_error is None
]
print("=" * 100)
print(
f"\n--- Running static pass with {len(forward_ops)} operations ---"
)
static_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.all_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=forward_ops,
max_new_tokens=test_case.max_new_tokens,
)
print("=" * 100)
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
static_output = self._run_operation_sequence(
mode=mode,
initial_adapters=test_case.all_adapters,
base=test_case.base,
max_loras_per_batch=test_case.max_loras_per_batch,
op_sequence=forward_ops,
max_new_tokens=test_case.max_new_tokens,
)
print(f"Dynamic output: {dynamic_output}")
print(f"Static output: {static_output}")
print("=" * 100)
print(f"Dynamic output: {dynamic_output}")
print(f"Static output: {static_output}")
print("=" * 100)
self.assertEqual(
len(dynamic_output),
len(static_output),
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
)
for i, (dynamic, static) in enumerate(
zip(dynamic_output, static_output), start=1
):
self.assertEqual(
len(dynamic_output),
len(static_output),
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
len(dynamic),
len(static),
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
)
for i, (dynamic, static) in enumerate(
zip(dynamic_output, static_output), start=1
):
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
d_out = d_out.strip()
s_out = s_out.strip()
self.assertEqual(
len(dynamic),
len(static),
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
d_out,
s_out,
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
)
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
d_out = d_out.strip()
s_out = s_out.strip()
self.assertEqual(
d_out,
s_out,
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
)
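# The comparison above is the crux of the test: the dynamic pass loads
# adapters mid-session, while the static pass starts with every adapter
# preloaded and replays only the forwards expected to succeed; the two
# passes must produce identical outputs, batch by batch and prompt by prompt.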
def test_dynamic_lora_update_engine(self):
"""
Test dynamic LoRA updates in engine mode.
"""
test_cases = ALL_TESTS
self._run_dynamic_adapter_updates(
mode=LoRAUpdateTestSessionMode.ENGINE,
test_cases=test_cases,
)
def test_dynamic_lora_update_server(self):
"""
Test dynamic LoRA updates in server mode.
"""
# In CI, we only run the first test case to save time, as the engine test should be mostly sufficient for ensuring correctness.
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
self._run_dynamic_adapter_updates(
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
)
if __name__ == "__main__":
    unittest.main()

test/srt/run_suite.py

@@ -17,7 +17,7 @@ suites = {
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),
TestFile("models/lora/test_lora_update.py", 400),
TestFile("models/lora/test_lora_update.py", 700),
TestFile("models/test_embedding_models.py", 73),
# TestFile("models/test_clip_models.py", 52),
TestFile("models/test_encoder_embedding_models.py", 100),