Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)
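The change below retires the separate EXPECT_ERROR operation type: every Operation now carries an optional expected_error field, and the session methods assert the failure themselves when it is set. A minimal, self-contained sketch of that pattern, reusing the names from the test file (run_op is a hypothetical stand-in for the real session methods):

    from dataclasses import dataclass
    from enum import Enum
    from typing import Any, Optional


    class OperationType(Enum):
        LOAD = "load"
        UNLOAD = "unload"
        FORWARD = "forward"


    @dataclass
    class Operation:
        type: OperationType
        data: Optional[Any]
        # If the operation is expected to fail, this is the error message to expect
        expected_error: Optional[str] = None


    def run_op(op: Operation) -> tuple[bool, str]:
        # Hypothetical executor: pretend adapters named "bad" fail to load.
        if op.type is OperationType.LOAD and "bad" in str(op.data):
            return False, f"{op.data} was not loaded"
        return True, "ok"


    ops = [
        Operation(OperationType.LOAD, "good-adapter"),
        Operation(OperationType.LOAD, "bad-adapter", expected_error="not loaded"),
    ]
    for op in ops:
        success, message = run_op(op)
        if op.expected_error:
            # Failure is the expected outcome; assert on the message instead.
            assert not success and op.expected_error in message
        else:
            assert success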
test/srt/models/lora/test_lora_update.py
@@ -16,7 +16,7 @@ import multiprocessing as mp
 import unittest
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, List, Optional, Union
+from typing import Any, Iterable, List, Optional, Union
 
 import requests
 import torch
@@ -27,6 +27,7 @@ from sglang.test.test_utils import (
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
     DEFAULT_URL_FOR_TEST,
     CustomTestCase,
+    is_in_ci,
     popen_launch_server,
 )
 
@@ -45,24 +46,28 @@ class OperationType(Enum):
     LOAD = "load"
     UNLOAD = "unload"
     FORWARD = "forward"
-    EXPECT_ERROR = "expect_error"
 
 
 @dataclass
 class Operation:
-    # Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
+    # Operation type, can be LOAD, UNLOAD, FORWARD
     type: OperationType
     # Data associated with the operation. Exact type varies depending on the operation
     data: Optional[Any]
+    # If the operation is expected to fail, this is the error message to expect
+    expected_error: Optional[str] = None
 
 
 @dataclass
 class TestCase:
     description: str
     base: str
     max_loras_per_batch: int
     all_adapters: List[str]
     initial_adapters: List[str]
     op_sequence: List[Operation]
+    max_lora_rank: Optional[int] = None
+    lora_target_modules: Optional[List] = None
     max_new_tokens: int = 32
 
 
@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
     return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
 
 
-TEST_CASES = [
-    # basic test, no eviction
+BASIC_TESTS = [
     TestCase(
         description="dynamic lora update with initial lora_paths",
         base="meta-llama/Llama-3.1-8B-Instruct",
         max_loras_per_batch=3,
         all_adapters=[
@@ -89,20 +94,16 @@ TEST_CASES = [
             data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data(
-                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
-                ),
-                "not loaded",
+            type=OperationType.FORWARD,
+            data=create_batch_data(
+                "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
             ),
+            expected_error="not loaded",
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
-                "not loaded",
-            ),
+            type=OperationType.FORWARD,
+            data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
+            expected_error="not loaded",
         ),
         Operation(
             type=OperationType.LOAD,
@@ -127,11 +128,9 @@ TEST_CASES = [
             data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
-                "not loaded",
-            ),
+            type=OperationType.FORWARD,
+            data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+            expected_error="not loaded",
         ),
         Operation(
             type=OperationType.FORWARD,
@@ -147,13 +146,11 @@ TEST_CASES = [
             data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data(
-                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
-                ),
-                "not loaded",
+            type=OperationType.FORWARD,
+            data=create_batch_data(
+                "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
             ),
+            expected_error="not loaded",
         ),
         Operation(
             type=OperationType.FORWARD,
@@ -174,8 +171,8 @@ TEST_CASES = [
             ),
         ],
     ),
-    # Eviction
     TestCase(
+        description="dynamic lora update with evictions",
         base="meta-llama/Llama-3.1-8B-Instruct",
         max_loras_per_batch=1,
         all_adapters=[
@@ -190,20 +187,16 @@ TEST_CASES = [
             data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data(
-                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
-                ),
-                "not loaded",
+            type=OperationType.FORWARD,
+            data=create_batch_data(
+                "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
             ),
+            expected_error="not loaded",
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
-                "not loaded",
-            ),
+            type=OperationType.FORWARD,
+            data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
+            expected_error="not loaded",
         ),
         Operation(
             type=OperationType.LOAD,
@@ -214,11 +207,9 @@ TEST_CASES = [
             data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
         ),
         Operation(
-            type=OperationType.EXPECT_ERROR,
-            data=(
-                create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
-                "not loaded",
-            ),
+            type=OperationType.FORWARD,
+            data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+            expected_error="not loaded",
         ),
         Operation(
             type=OperationType.FORWARD,
@@ -263,6 +254,253 @@ TEST_CASES = [
         ],
     ),
 ]
+TARGET_MODULE_TESTS = [
+    TestCase(
+        description="Test explicitly specified lora-target-modules.",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=3,
+        lora_target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+        max_lora_rank=64,
+        all_adapters=[
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # target_modules = q, k, v, o, gate, up, down
+            "algoprog/fact-generation-llama-3.1-8b-instruct-lora",  # target_modules = q, k, v, o, gate
+        ],
+        initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
+        op_sequence=[
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora"
+                ),
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                ),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    [
+                        "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                        None,
+                    ]
+                ),
+            ),
+        ],
+    ),
+    TestCase(
+        description="Test inferred lora-target-modules - start with larger adapter",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=3,
+        max_lora_rank=64,
+        all_adapters=[
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # target_modules = q, k, v, o, gate, up, down
+            "algoprog/fact-generation-llama-3.1-8b-instruct-lora",  # target_modules = q, k, v, o, gate
+        ],
+        initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
+        op_sequence=[
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                ),
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora"
+                ),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    [
+                        "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                        None,
+                    ]
+                ),
+            ),
+        ],
+    ),
+    TestCase(
+        description="Test inferred lora-target-modules - start with smaller adapter",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=3,
+        max_lora_rank=64,
+        all_adapters=[
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # target_modules = q, k, v, o, gate, up, down
+            "algoprog/fact-generation-llama-3.1-8b-instruct-lora",  # target_modules = q, k, v, o, gate
+        ],
+        initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
+        op_sequence=[
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "algoprog/fact-generation-llama-3.1-8b-instruct-lora"
+                ),
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                ),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                expected_error="updating LoRA shapes",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    [
+                        "algoprog/fact-generation-llama-3.1-8b-instruct-lora",
+                        None,
+                    ]
+                ),
+            ),
+        ],
+    ),
+]
+MAX_LORA_RANK_TESTS = [
+    TestCase(
+        description="Test explicitly specified max-lora-rank.",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=3,
+        max_lora_rank=32,
+        all_adapters=[
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # r = 4
+            "pbevan11/llama-3.1-8b-ocr-correction",  # r = 32
+            "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # r = 256
+        ],
+        initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
+        op_sequence=[
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
+                ),
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="pbevan11/llama-3.1-8b-ocr-correction",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    [
+                        "pbevan11/llama-3.1-8b-ocr-correction",
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                        None,
+                    ]
+                ),
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
+                expected_error="updating LoRA shapes",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "philschmid/code-llama-3-1-8b-text-to-sql-lora",
+                ),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    [
+                        "pbevan11/llama-3.1-8b-ocr-correction",
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                        None,
+                    ]
+                ),
+            ),
+        ],
+    ),
+    TestCase(
+        description="test implicitly inferred max-lora-rank",
+        base="meta-llama/Llama-3.1-8B-Instruct",
+        max_loras_per_batch=3,
+        all_adapters=[
+            "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",  # r = 4
+            "pbevan11/llama-3.1-8b-ocr-correction",  # r = 32
+            "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # r = 256
+        ],
+        initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
+        op_sequence=[
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
+                expected_error="updating LoRA shapes",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
+                expected_error="not loaded",
+            ),
+            Operation(
+                type=OperationType.LOAD,
+                data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                ),
+            ),
+            Operation(
+                type=OperationType.FORWARD,
+                data=create_batch_data(
+                    [
+                        "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
+                        "pbevan11/llama-3.1-8b-ocr-correction",
+                        None,
+                    ]
+                ),
+            ),
+        ],
+    ),
+]
+ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
 
 
 class LoRAUpdateTestSessionMode(Enum):
@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase:
         testcase: Optional[TestCase],
         model_path: str,
         lora_paths: list[str],
-        max_loras_per_batch: int = 1,
+        max_loras_per_batch: int,
+        max_lora_rank: Optional[int],
+        lora_target_modules: Optional[List[str]] = None,
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         cuda_graph_max_bs: int = 4,
@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase:
         self.testcase = testcase
         self.model_path = model_path
         self.lora_paths = lora_paths
+        self.max_lora_rank = max_lora_rank
+        self.lora_target_modules = lora_target_modules
         self.max_loras_per_batch = max_loras_per_batch
         self.lora_backend = lora_backend
         self.disable_cuda_graph = disable_cuda_graph
@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase:
         # Don't suppress exceptions by default
         return False
 
-    def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
+    def load_lora_adapter(
+        self,
+        lora_name: str,
+        lora_path: Optional[str] = None,
+        expected_error: Optional[str] = None,
+    ):
         """
         Load a LoRA adapter by name and path.
         """
@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase:
         prompts: List[str],
         lora_paths: List[str],
         max_new_tokens: int = 32,
+        expected_error: Optional[str] = None,
     ):
         """
         Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
             model_path=self.model_path,
             model_type="generation",
             lora_paths=self.lora_paths,
+            max_lora_rank=self.max_lora_rank,
+            lora_target_modules=self.lora_target_modules,
             lora_backend=self.lora_backend,
             torch_dtype=torch.float16,
             mem_fraction_static=MEM_FRACTION_STATIC,
@@ -357,24 +607,32 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
         # don't suppress exceptions
         return False
 
-    def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
+    def load_lora_adapter(
+        self,
+        lora_name: str,
+        lora_path: Optional[str] = None,
+        expected_error: Optional[str] = None,
+    ):
         """
         Load a LoRA adapter by name and path.
         """
         if lora_path is None:
             lora_path = lora_name
 
-        self.expected_adapters.add(lora_name)
-
         response = self.handle.load_lora_adapter(
             lora_name=lora_name,
             lora_path=lora_path,
         )
-        self.testcase.assertTrue(response.success)
-        loaded_adapters = set(response.loaded_adapters)
-
-        print(f"loaded_adapters: {loaded_adapters}")
-        self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
+        if expected_error:
+            self.testcase.assertFalse(response.success)
+            self.testcase.assertIn(expected_error, response.error_message)
+            print(f"Received error as expected: {response.error_message}")
+        else:
+            self.expected_adapters.add(lora_name)
+            self.testcase.assertTrue(response.success)
+            loaded_adapters = set(response.loaded_adapters)
+            print(f"loaded_adapters: {loaded_adapters}")
+            self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
 
     def unload_lora_adapter(self, lora_name: str):
         """
@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
         prompts: List[str],
         lora_paths: List[str],
         max_new_tokens: int = 32,
-        expected_error: str = None,
+        expected_error: Optional[str] = None,
     ):
         """
         Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
         ]
         if self.disable_cuda_graph:
             other_args.append("--disable-cuda-graph")
+        if self.max_lora_rank is not None:
+            other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
+        if self.lora_target_modules is not None:
+            other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
 
         # launch external server
         self.handle = popen_launch_server(
@@ -464,24 +726,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
         # don't suppress exceptions
         return False
 
-    def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
+    def load_lora_adapter(
+        self,
+        lora_name: str,
+        lora_path: Optional[str] = None,
+        expected_error: Optional[str] = None,
+    ):
         """
         Load a LoRA adapter by name and path.
         """
         if lora_path is None:
             lora_path = lora_name
 
-        self.expected_adapters.add(lora_name)
-
         response = requests.post(
             DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
             json={"lora_name": lora_name, "lora_path": lora_path},
         )
-        self.testcase.assertTrue(response.ok)
-        loaded_adapters = set(response.json()["loaded_adapters"])
-
-        print(f"loaded_adapters: {loaded_adapters}")
-        self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
+        if expected_error:
+            self.testcase.assertEqual(response.status_code, 400)
+            self.testcase.assertIn(expected_error, response.text)
+            print(f"Received error as expected: {response.text}")
+        else:
+            self.expected_adapters.add(lora_name)
+            self.testcase.assertTrue(response.ok)
+            loaded_adapters = set(response.json()["loaded_adapters"])
+            print(f"loaded_adapters: {loaded_adapters}")
+            self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
 
     def unload_lora_adapter(self, lora_name: str):
         """
@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
         prompts: List[str],
         lora_paths: List[str],
         max_new_tokens: int = 32,
-        expected_error: str = None,
+        expected_error: Optional[str] = None,
     ):
         """
         Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
 
 # Factory function to create the appropriate LoRA test session based on mode
 def LoRAUpdateTestSession(
     *,
     testcase: Optional[TestCase],
     mode: LoRAUpdateTestSessionMode,
-    model_path: str,
-    lora_paths: list[str],
-    max_loras_per_batch: int = 1,
-    lora_backend: str = "triton",
-    disable_cuda_graph: bool = False,
-    cuda_graph_max_bs: int = 4,
+    **kwargs: Any,
 ):
-    common_kwargs = {
-        "testcase": testcase,
-        "model_path": model_path,
-        "lora_paths": lora_paths,
-        "max_loras_per_batch": max_loras_per_batch,
-        "lora_backend": lora_backend,
-        "disable_cuda_graph": disable_cuda_graph,
-        "cuda_graph_max_bs": cuda_graph_max_bs,
-    }
-
     if mode == LoRAUpdateTestSessionMode.ENGINE:
-        return LoRAUpdateEngineTestSession(**common_kwargs)
+        return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
     elif mode == LoRAUpdateTestSessionMode.SERVER:
-        return LoRAUpdateServerTestSession(**common_kwargs)
+        return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
     else:
         raise ValueError(f"Unrecognized mode: {mode!r}")
 
@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase):
         initial_adapters: List[str],
         max_loras_per_batch: int,
         op_sequence: List[Operation],
+        max_lora_rank: Optional[int] = None,
+        lora_target_modules: Optional[List[str]] = None,
         max_new_tokens: int = 32,
     ) -> List[tuple]:
         """
@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase):
             model_path=base,
             lora_paths=initial_adapters,
             max_loras_per_batch=max_loras_per_batch,
+            max_lora_rank=max_lora_rank,
+            lora_target_modules=lora_target_modules,
         ) as session:
             for op in op_sequence:
                 op_type = op.type
                 data = op.data
+                expected_error = op.expected_error
                 print("-" * 100)
                 print(
                     f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
                     result = session.load_lora_adapter(
                         lora_name=data,
                         lora_path=data,
+                        expected_error=expected_error,
                     )
                 elif op_type == OperationType.UNLOAD:
                     result = session.unload_lora_adapter(
@@ -615,91 +875,105 @@ class TestLoRADynamicUpdate(CustomTestCase):
                     )
                 elif op_type == OperationType.FORWARD:
                     prompts, adapters = zip(*data)
                     result = session.forward(
                         prompts=list(prompts),
                         lora_paths=list(adapters),
                         max_new_tokens=max_new_tokens,
-                    )
-                    forward_outputs.append(result)
-                elif op_type == OperationType.EXPECT_ERROR:
-                    input_data, expected_error = data
-                    prompts, adapters = zip(*input_data)
-                    result = session.forward(
-                        prompts=list(prompts),
-                        lora_paths=list(adapters),
-                        max_new_tokens=max_new_tokens,
                         expected_error=expected_error,
                     )
+                    if not expected_error:
+                        forward_outputs.append(result)
 
         return forward_outputs
 
-    def test_dynamic_adapter_updates(self):
-        for case_idx, test_case in enumerate(TEST_CASES, start=1):
-            for mode in [
-                LoRAUpdateTestSessionMode.ENGINE,
-                LoRAUpdateTestSessionMode.SERVER,
-            ]:
-                print("=" * 100)
-                print(f"Starting test case {case_idx} in {mode.value} mode.")
-                print("=" * 100)
+    def _run_dynamic_adapter_updates(
+        self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
+    ):
+        for case_idx, test_case in enumerate(test_cases, start=1):
+            print("=" * 100)
+            print(
+                f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
+            )
+            print("=" * 100)
 
-                print(
-                    f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
-                )
-                # Test dynamic loading of adapters
-                # TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
-                # we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
-                dynamic_output = self._run_operation_sequence(
-                    mode=mode,
-                    initial_adapters=test_case.initial_adapters,
-                    base=test_case.base,
-                    max_loras_per_batch=test_case.max_loras_per_batch,
-                    op_sequence=test_case.op_sequence,
-                    max_new_tokens=test_case.max_new_tokens,
-                )
+            print(
+                f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
+            )
+            # Test dynamic loading of adapters
+            dynamic_output = self._run_operation_sequence(
+                mode=mode,
+                initial_adapters=test_case.initial_adapters,
+                base=test_case.base,
+                max_loras_per_batch=test_case.max_loras_per_batch,
+                op_sequence=test_case.op_sequence,
+                max_new_tokens=test_case.max_new_tokens,
+                max_lora_rank=test_case.max_lora_rank,
+                lora_target_modules=test_case.lora_target_modules,
+            )
 
-                # static loading
-                forward_ops = [
-                    x for x in test_case.op_sequence if x.type == OperationType.FORWARD
-                ]
+            # static loading
+            forward_ops = [
+                x
+                for x in test_case.op_sequence
+                if x.type == OperationType.FORWARD and x.expected_error is None
+            ]
 
-                print("=" * 100)
-                print(
-                    f"\n--- Running static pass with {len(forward_ops)} operations ---"
-                )
-                static_output = self._run_operation_sequence(
-                    mode=mode,
-                    initial_adapters=test_case.all_adapters,
-                    base=test_case.base,
-                    max_loras_per_batch=test_case.max_loras_per_batch,
-                    op_sequence=forward_ops,
-                    max_new_tokens=test_case.max_new_tokens,
-                )
+            print("=" * 100)
+            print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
+            static_output = self._run_operation_sequence(
+                mode=mode,
+                initial_adapters=test_case.all_adapters,
+                base=test_case.base,
+                max_loras_per_batch=test_case.max_loras_per_batch,
+                op_sequence=forward_ops,
+                max_new_tokens=test_case.max_new_tokens,
+            )
 
-                print(f"Dynamic output: {dynamic_output}")
-                print(f"Static output: {static_output}")
-                print("=" * 100)
-                self.assertEqual(
-                    len(dynamic_output),
-                    len(static_output),
-                    f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
-                )
-                for i, (dynamic, static) in enumerate(
-                    zip(dynamic_output, static_output), start=1
-                ):
-                    self.assertEqual(
-                        len(dynamic),
-                        len(static),
-                        f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
-                    )
-                    for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
-                        d_out = d_out.strip()
-                        s_out = s_out.strip()
-                        self.assertEqual(
-                            d_out,
-                            s_out,
-                            f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
-                        )
+            print(f"Dynamic output: {dynamic_output}")
+            print(f"Static output: {static_output}")
+            print("=" * 100)
+            self.assertEqual(
+                len(dynamic_output),
+                len(static_output),
+                f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
+            )
+            for i, (dynamic, static) in enumerate(
+                zip(dynamic_output, static_output), start=1
+            ):
+                self.assertEqual(
+                    len(dynamic),
+                    len(static),
+                    f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
+                )
+                for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
+                    d_out = d_out.strip()
+                    s_out = s_out.strip()
+                    self.assertEqual(
+                        d_out,
+                        s_out,
+                        f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
+                    )
+
+    def test_dynamic_lora_update_engine(self):
+        """
+        Test dynamic LoRA updates in engine mode.
+        """
+        test_cases = ALL_TESTS
+        self._run_dynamic_adapter_updates(
+            mode=LoRAUpdateTestSessionMode.ENGINE,
+            test_cases=test_cases,
+        )
+
+    def test_dynamic_lora_update_server(self):
+        """
+        Test dynamic LoRA updates in server mode.
+        """
+        # In CI, we only run the first test case to save time, as the engine test should be mostly sufficient for ensuring correctness.
+        test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
+
+        self._run_dynamic_adapter_updates(
+            mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
+        )
 
 
 if __name__ == "__main__":
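In server mode, the refactored load_lora_adapter above asserts an HTTP 400 response whose body contains the expected message, instead of routing errors through a separate operation type. A hedged client-side sketch of that flow against a running server (BASE_URL and try_load are illustrative stand-ins; the test itself uses DEFAULT_URL_FOR_TEST and the session classes above):

    from typing import Optional

    import requests

    BASE_URL = "http://127.0.0.1:30000"  # assumption: a local sglang server with LoRA enabled

    def try_load(lora_name: str, expected_error: Optional[str] = None) -> None:
        response = requests.post(
            BASE_URL + "/load_lora_adapter",
            json={"lora_name": lora_name, "lora_path": lora_name},
        )
        if expected_error:
            # e.g. an r=256 adapter sent to a server launched with --max-lora-rank 32
            # should be rejected with a message mentioning "updating LoRA shapes"
            assert response.status_code == 400
            assert expected_error in response.text
        else:
            assert response.ok
            assert lora_name in set(response.json()["loaded_adapters"])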
test/srt/run_suite.py
@@ -17,7 +17,7 @@ suites = {
     TestFile("models/lora/test_lora_backend.py", 99),
     TestFile("models/lora/test_multi_lora_backend.py", 60),
     TestFile("models/lora/test_lora_cuda_graph.py", 250),
-    TestFile("models/lora/test_lora_update.py", 400),
+    TestFile("models/lora/test_lora_update.py", 700),
     TestFile("models/test_embedding_models.py", 73),
     # TestFile("models/test_clip_models.py", 52),
     TestFile("models/test_encoder_embedding_models.py", 100),
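Taken together, the invariant the refactored test enforces is: a run that loads adapters dynamically must produce exactly the same forward outputs as a run where all adapters were preloaded at startup, comparing only the forward passes that were expected to succeed. A compact sketch under the definitions sketched earlier (run_sequence is a hypothetical stand-in for _run_operation_sequence):

    def check_dynamic_matches_static(test_case, run_sequence):
        # Dynamic pass: start from initial_adapters and replay the full op sequence.
        dynamic = run_sequence(test_case.initial_adapters, test_case.op_sequence)
        # Static pass: preload all adapters and replay only the successful forwards.
        forward_ops = [
            op
            for op in test_case.op_sequence
            if op.type is OperationType.FORWARD and op.expected_error is None
        ]
        static = run_sequence(test_case.all_adapters, forward_ops)
        assert len(dynamic) == len(static)
        for dynamic_batch, static_batch in zip(dynamic, static):
            assert [d.strip() for d in dynamic_batch] == [s.strip() for s in static_batch]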