1305 lines
45 KiB
Python
1305 lines
45 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
import json
|
|
import multiprocessing as mp
|
|
import unittest
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any, Iterable, List, Optional, Union
|
|
|
|
import requests
|
|
import torch
|
|
|
|
from sglang.srt.utils import kill_process_tree
|
|
from sglang.test.runners import SRTRunner
|
|
from sglang.test.test_utils import (
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
DEFAULT_URL_FOR_TEST,
|
|
CustomTestCase,
|
|
is_in_ci,
|
|
popen_launch_server,
|
|
)
|
|
|
|
PROMPTS = [
|
|
"SGL is a",
|
|
"AI is a field of computer science focused on",
|
|
"Computer science is the study of",
|
|
"Write a short story.",
|
|
"What are the main components of a computer?",
|
|
]
|
|
|
|
MEM_FRACTION_STATIC = 0.8
|
|
|
|
|
|
class OperationType(Enum):
|
|
LOAD = "load"
|
|
UNLOAD = "unload"
|
|
FORWARD = "forward"
|
|
|
|
|
|
@dataclass
|
|
class Operation:
|
|
# Operation type, can be LOAD, UNLOAD, FORWARD
|
|
type: OperationType
|
|
# Data associated with the operation. Exact type varies depending on the operation
|
|
data: Optional[Any]
|
|
# If the operation is expected to fail, this is the error message to expect
|
|
expected_error: Optional[str] = None
|
|
|
|
|
|
@dataclass
|
|
class TestCase:
|
|
description: str
|
|
base: str
|
|
max_loras_per_batch: int
|
|
all_adapters: List[str]
|
|
op_sequence: List[Operation]
|
|
initial_adapters: Optional[List[str]] = None
|
|
enable_lora: Optional[bool] = None
|
|
max_lora_rank: Optional[int] = None
|
|
lora_target_modules: Optional[List] = None
|
|
max_new_tokens: int = 32
|
|
max_loaded_loras: Optional[int] = None
|
|
|
|
|
|
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
|
|
if not isinstance(adapters, list):
|
|
adapters = [adapters]
|
|
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
|
|
|
|
|
BASIC_TESTS = [
|
|
TestCase(
|
|
description="dynamic lora update with initial lora_paths",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=3,
|
|
all_adapters=[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
],
|
|
initial_adapters=[
|
|
# Testing 3 supported lora-path formats.
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
{
|
|
"lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
|
|
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
|
"pinned": False,
|
|
},
|
|
],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
None,
|
|
),
|
|
),
|
|
],
|
|
),
|
|
TestCase(
|
|
description="dynamic lora update without initial lora_paths",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
enable_lora=True,
|
|
max_lora_rank=256,
|
|
lora_target_modules=["all"],
|
|
max_loras_per_batch=4,
|
|
all_adapters=[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
None,
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(None),
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
]
|
|
TARGET_MODULE_TESTS = [
|
|
TestCase(
|
|
description="Test explicitly specified lora-target-modules.",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=3,
|
|
lora_target_modules=[
|
|
"q_proj",
|
|
"k_proj",
|
|
"v_proj",
|
|
"o_proj",
|
|
"gate_proj",
|
|
"up_proj",
|
|
"down_proj",
|
|
],
|
|
max_lora_rank=64,
|
|
all_adapters=[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
|
],
|
|
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
TestCase(
|
|
description="Test inferred lora-target-modules - start with larger adapter",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=3,
|
|
max_lora_rank=64,
|
|
all_adapters=[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
|
],
|
|
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
TestCase(
|
|
description="Test inferred lora-target-modules - start with smaller adapter",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=3,
|
|
max_lora_rank=64,
|
|
all_adapters=[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
|
],
|
|
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
expected_error="incompatible",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
]
|
|
MAX_LORA_RANK_TESTS = [
|
|
TestCase(
|
|
description="Test explicitly specified max-lora-rank.",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=3,
|
|
max_lora_rank=32,
|
|
all_adapters=[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
|
|
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
|
|
],
|
|
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
expected_error="incompatible",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
TestCase(
|
|
description="test implicitly inferred max-lora-rank",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=3,
|
|
all_adapters=[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
|
|
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
|
|
],
|
|
initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
expected_error="incompatible",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
|
expected_error="not loaded",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
]
|
|
MAX_LOADED_LORAS_TESTS = [
|
|
TestCase(
|
|
description="Test max_loaded_loras limit",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=2,
|
|
max_loaded_loras=2,
|
|
all_adapters=[
|
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
],
|
|
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
expected_error="Maximum number of loaded LoRA adapters",
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
|
),
|
|
],
|
|
),
|
|
]
|
|
EVICTION_TESTS = [
|
|
TestCase(
|
|
description="dynamic lora update with evictions",
|
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
|
max_loras_per_batch=2,
|
|
all_adapters=[
|
|
"lora1=philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"lora2=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"lora3=pbevan11/llama-3.1-8b-ocr-correction",
|
|
],
|
|
enable_lora=True,
|
|
max_lora_rank=256,
|
|
lora_target_modules=["all"],
|
|
op_sequence=[
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora1",
|
|
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"pinned": True,
|
|
},
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora2",
|
|
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pinned": True,
|
|
},
|
|
expected_error="starvation",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora2",
|
|
"lora_path": "Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
|
"pinned": False,
|
|
},
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora3",
|
|
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
|
"pinned": False,
|
|
},
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="lora1",
|
|
),
|
|
Operation(
|
|
type=OperationType.UNLOAD,
|
|
data="lora3",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora3",
|
|
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
|
"pinned": True,
|
|
},
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora1",
|
|
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"pinned": True,
|
|
},
|
|
expected_error="starvation",
|
|
),
|
|
Operation(
|
|
type=OperationType.LOAD,
|
|
data={
|
|
"lora_name": "lora1",
|
|
"lora_path": "philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
"pinned": False,
|
|
},
|
|
),
|
|
# pinned: lora3
|
|
# unpinned: lora1, lora2
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"lora1",
|
|
"lora2",
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"lora1",
|
|
"lora3",
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"lora1",
|
|
"lora2",
|
|
]
|
|
),
|
|
),
|
|
Operation(
|
|
type=OperationType.FORWARD,
|
|
data=create_batch_data(
|
|
[
|
|
"lora1",
|
|
"lora2",
|
|
None,
|
|
]
|
|
),
|
|
),
|
|
],
|
|
),
|
|
]
|
|
|
|
ALL_TESTS = (
|
|
BASIC_TESTS
|
|
+ TARGET_MODULE_TESTS
|
|
+ MAX_LORA_RANK_TESTS
|
|
+ MAX_LOADED_LORAS_TESTS
|
|
+ EVICTION_TESTS
|
|
)
|
|
|
|
|
|
class LoRAUpdateTestSessionMode(Enum):
|
|
ENGINE = "engine"
|
|
SERVER = "server"
|
|
|
|
|
|
class LoRAUpdateTestSessionBase:
|
|
"""
|
|
Base context manager for testing LoRA adapters.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
*,
|
|
testcase: Optional[TestCase],
|
|
model_path: str,
|
|
lora_paths: List[Union[str, dict]],
|
|
max_loras_per_batch: int,
|
|
max_loaded_loras: Optional[int] = None,
|
|
max_lora_rank: Optional[int],
|
|
enable_lora: Optional[bool] = None,
|
|
lora_target_modules: Optional[List[str]] = None,
|
|
lora_backend: str = "triton",
|
|
disable_cuda_graph: bool = False,
|
|
cuda_graph_max_bs: int = 4,
|
|
):
|
|
self.testcase = testcase
|
|
self.model_path = model_path
|
|
self.lora_paths = lora_paths
|
|
self.max_lora_rank = max_lora_rank
|
|
self.lora_target_modules = lora_target_modules
|
|
self.max_loras_per_batch = max_loras_per_batch
|
|
self.max_loaded_loras = max_loaded_loras
|
|
self.lora_backend = lora_backend
|
|
self.disable_cuda_graph = disable_cuda_graph
|
|
self.cuda_graph_max_bs = cuda_graph_max_bs
|
|
self.enable_lora = enable_lora
|
|
|
|
self.expected_adapters = set()
|
|
if self.lora_paths:
|
|
for adapter in self.lora_paths:
|
|
if isinstance(adapter, dict):
|
|
lora_name = adapter["lora_name"]
|
|
elif "=" in adapter:
|
|
lora_name = adapter.split("=")[0]
|
|
else:
|
|
lora_name = adapter
|
|
self.expected_adapters.add(lora_name)
|
|
|
|
self.handle = None # Will be set in __enter__
|
|
|
|
def __enter__(self):
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
# Don't suppress exceptions by default
|
|
return False
|
|
|
|
def load_lora_adapter(
|
|
self,
|
|
lora_name: str,
|
|
lora_path: Optional[str] = None,
|
|
expected_error: Optional[str] = None,
|
|
):
|
|
"""
|
|
Load a LoRA adapter by name and path.
|
|
"""
|
|
raise NotImplementedError("Subclasses must implement load_lora_adapter")
|
|
|
|
def unload_lora_adapter(self, lora_name: str):
|
|
"""
|
|
Unload a LoRA adapter by name.
|
|
"""
|
|
raise NotImplementedError("Subclasses must implement unload_lora_adapter")
|
|
|
|
def forward(
|
|
self,
|
|
prompts: List[str],
|
|
lora_paths: List[str],
|
|
max_new_tokens: int = 32,
|
|
expected_error: Optional[str] = None,
|
|
):
|
|
"""
|
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
|
"""
|
|
raise NotImplementedError("Subclasses must implement forward")
|
|
|
|
|
|
class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
|
"""
|
|
Context manager for testing LoRA adapters with in-process engine.
|
|
"""
|
|
|
|
def __enter__(self):
|
|
# in-process runner
|
|
self.handle = SRTRunner(
|
|
model_path=self.model_path,
|
|
model_type="generation",
|
|
lora_paths=self.lora_paths,
|
|
max_lora_rank=self.max_lora_rank,
|
|
lora_target_modules=self.lora_target_modules,
|
|
lora_backend=self.lora_backend,
|
|
torch_dtype=torch.float16,
|
|
mem_fraction_static=MEM_FRACTION_STATIC,
|
|
max_loras_per_batch=self.max_loras_per_batch,
|
|
max_loaded_loras=self.max_loaded_loras,
|
|
disable_cuda_graph=self.disable_cuda_graph,
|
|
cuda_graph_max_bs=self.cuda_graph_max_bs,
|
|
enable_lora=self.enable_lora,
|
|
)
|
|
self.handle.__enter__()
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
if self.handle is not None:
|
|
# delegate cleanup to SRTRunner
|
|
return self.handle.__exit__(exc_type, exc_val, exc_tb)
|
|
# don't suppress exceptions
|
|
return False
|
|
|
|
def load_lora_adapter(
|
|
self,
|
|
lora_name: str,
|
|
lora_path: Optional[str] = None,
|
|
expected_error: Optional[str] = None,
|
|
pinned: bool = False,
|
|
):
|
|
"""
|
|
Load a LoRA adapter by name and path.
|
|
"""
|
|
if lora_path is None:
|
|
lora_path = lora_name
|
|
|
|
response = self.handle.load_lora_adapter(
|
|
lora_name=lora_name,
|
|
lora_path=lora_path,
|
|
pinned=pinned,
|
|
)
|
|
if expected_error:
|
|
self.testcase.assertFalse(
|
|
response.success, f"Expected failure for {lora_name}, but got success."
|
|
)
|
|
self.testcase.assertIn(
|
|
expected_error,
|
|
response.error_message,
|
|
f"Expected error message to contain '{expected_error}', but got '{response.error_message}'",
|
|
)
|
|
print(f"Received error as expected: {response.error_message}")
|
|
else:
|
|
self.expected_adapters.add(lora_name)
|
|
self.testcase.assertTrue(
|
|
response.success,
|
|
f"Failed to load LoRA adapter {lora_name}: {response.error_message}",
|
|
)
|
|
loaded_adapters = set(response.loaded_adapters)
|
|
print(f"loaded_adapters: {loaded_adapters}")
|
|
self.testcase.assertEqual(
|
|
loaded_adapters,
|
|
self.expected_adapters,
|
|
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
|
)
|
|
|
|
def unload_lora_adapter(self, lora_name: str):
|
|
"""
|
|
Unload a LoRA adapter by name.
|
|
"""
|
|
self.expected_adapters.remove(lora_name)
|
|
|
|
response = self.handle.unload_lora_adapter(
|
|
lora_name=lora_name,
|
|
)
|
|
self.testcase.assertTrue(
|
|
response.success,
|
|
f"Failed to unload LoRA adapter {lora_name}: {response.error_message}",
|
|
)
|
|
loaded_adapters = set(response.loaded_adapters)
|
|
|
|
print(f"loaded_adapters: {loaded_adapters}")
|
|
self.testcase.assertEqual(
|
|
loaded_adapters,
|
|
self.expected_adapters,
|
|
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
prompts: List[str],
|
|
lora_paths: List[str],
|
|
max_new_tokens: int = 32,
|
|
expected_error: Optional[str] = None,
|
|
):
|
|
"""
|
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
|
"""
|
|
try:
|
|
response = self.handle.batch_forward(
|
|
prompts=prompts,
|
|
lora_paths=lora_paths,
|
|
max_new_tokens=max_new_tokens,
|
|
)
|
|
except ValueError as e:
|
|
if expected_error:
|
|
error_message = str(e)
|
|
self.testcase.assertIn(
|
|
expected_error,
|
|
error_message,
|
|
f"Expected error message to contain '{expected_error}', but got '{error_message}'",
|
|
)
|
|
print(f"Received error as expected: {error_message}")
|
|
return error_message
|
|
|
|
raise e
|
|
|
|
self.testcase.assertEqual(
|
|
len(response.output_strs),
|
|
len(prompts),
|
|
f"Expected {len(prompts)} outputs, but got {len(response.output_strs)}",
|
|
)
|
|
output = response.output_strs
|
|
print(f"output_strs: {output}")
|
|
|
|
return output
|
|
|
|
|
|
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|
"""
|
|
Context manager for testing LoRA adapters with standalone server.
|
|
"""
|
|
|
|
def __enter__(self):
|
|
other_args = [
|
|
"--cuda-graph-max-bs",
|
|
str(self.cuda_graph_max_bs),
|
|
"--max-loras-per-batch",
|
|
str(self.max_loras_per_batch),
|
|
"--lora-backend",
|
|
self.lora_backend,
|
|
"--random-seed",
|
|
"42",
|
|
"--max-running-request",
|
|
"1",
|
|
"--mem-fraction-static",
|
|
str(MEM_FRACTION_STATIC),
|
|
]
|
|
if self.enable_lora:
|
|
other_args.append("--enable-lora")
|
|
if self.lora_paths:
|
|
other_args.append("--lora-paths")
|
|
for lora_path in self.lora_paths:
|
|
if isinstance(lora_path, dict):
|
|
lora_path = json.dumps(lora_path)
|
|
other_args.append(lora_path)
|
|
if self.disable_cuda_graph:
|
|
other_args.append("--disable-cuda-graph")
|
|
if self.max_lora_rank is not None:
|
|
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
|
|
if self.lora_target_modules is not None:
|
|
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
|
|
if self.max_loaded_loras is not None:
|
|
other_args.extend(["--max-loaded-loras", str(self.max_loaded_loras)])
|
|
|
|
# launch external server
|
|
self.handle = popen_launch_server(
|
|
self.model_path,
|
|
DEFAULT_URL_FOR_TEST,
|
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
|
other_args=other_args,
|
|
)
|
|
return self
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
if self.handle is not None:
|
|
kill_process_tree(self.handle.pid)
|
|
# don't suppress exceptions
|
|
return False
|
|
|
|
def load_lora_adapter(
|
|
self,
|
|
lora_name: str,
|
|
lora_path: Optional[str] = None,
|
|
expected_error: Optional[str] = None,
|
|
pinned: bool = False,
|
|
):
|
|
"""
|
|
Load a LoRA adapter by name and path.
|
|
"""
|
|
if lora_path is None:
|
|
lora_path = lora_name
|
|
|
|
response = requests.post(
|
|
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
|
|
json={"lora_name": lora_name, "lora_path": lora_path, "pinned": pinned},
|
|
)
|
|
if expected_error:
|
|
self.testcase.assertEqual(
|
|
response.status_code,
|
|
400,
|
|
f"Expected error for {lora_name}, but got success.",
|
|
)
|
|
self.testcase.assertIn(
|
|
expected_error,
|
|
response.text,
|
|
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
|
|
)
|
|
print(f"Received error as expected: {response.text}")
|
|
else:
|
|
self.expected_adapters.add(lora_name)
|
|
self.testcase.assertTrue(
|
|
response.ok, f"Failed to load LoRA adapter {lora_name}: {response.text}"
|
|
)
|
|
loaded_adapters = set(response.json()["loaded_adapters"])
|
|
print(f"loaded_adapters: {loaded_adapters}")
|
|
self.testcase.assertEqual(
|
|
loaded_adapters,
|
|
self.expected_adapters,
|
|
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
|
)
|
|
|
|
def unload_lora_adapter(self, lora_name: str):
|
|
"""
|
|
Unload a LoRA adapter by name.
|
|
"""
|
|
self.expected_adapters.remove(lora_name)
|
|
|
|
response = requests.post(
|
|
DEFAULT_URL_FOR_TEST + "/unload_lora_adapter",
|
|
json={"lora_name": lora_name},
|
|
)
|
|
self.testcase.assertTrue(
|
|
response.ok, f"Failed to unload LoRA adapter {lora_name}: {response.text}"
|
|
)
|
|
loaded_adapters = set(response.json()["loaded_adapters"])
|
|
|
|
print(f"loaded_adapters: {loaded_adapters}")
|
|
self.testcase.assertEqual(
|
|
loaded_adapters,
|
|
self.expected_adapters,
|
|
f"Expected loaded adapters to be {self.expected_adapters}, but got {loaded_adapters}",
|
|
)
|
|
|
|
def forward(
|
|
self,
|
|
prompts: List[str],
|
|
lora_paths: List[str],
|
|
max_new_tokens: int = 32,
|
|
expected_error: Optional[str] = None,
|
|
):
|
|
"""
|
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
|
"""
|
|
response = requests.post(
|
|
DEFAULT_URL_FOR_TEST + "/generate",
|
|
json={
|
|
"text": prompts,
|
|
"lora_path": lora_paths,
|
|
"sampling_params": {
|
|
"temperature": 0,
|
|
"top_k": 1,
|
|
"max_new_tokens": max_new_tokens,
|
|
},
|
|
},
|
|
)
|
|
if expected_error:
|
|
self.testcase.assertEqual(
|
|
response.status_code,
|
|
400,
|
|
f"Expected error for forward pass, but got success: {response.text}",
|
|
)
|
|
self.testcase.assertIn(
|
|
expected_error,
|
|
response.text,
|
|
f"Expected error message to contain '{expected_error}', but got '{response.text}'",
|
|
)
|
|
output = response.text
|
|
print(f"Received error as expected: {response.text}")
|
|
return output
|
|
else:
|
|
self.testcase.assertTrue(
|
|
response.ok, f"Failed to generate text: {response.text}"
|
|
)
|
|
output = [r["text"] for r in response.json()]
|
|
self.testcase.assertEqual(
|
|
len(output),
|
|
len(prompts),
|
|
f"Expected {len(prompts)} outputs, but got {len(output)}",
|
|
)
|
|
print(f"output_strs: {output}")
|
|
return output
|
|
|
|
|
|
# Factory function to create the appropriate LoRA test session based on mode
|
|
def LoRAUpdateTestSession(
|
|
testcase: Optional[TestCase],
|
|
mode: LoRAUpdateTestSessionMode,
|
|
**kwargs: Any,
|
|
):
|
|
if mode == LoRAUpdateTestSessionMode.ENGINE:
|
|
return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
|
|
elif mode == LoRAUpdateTestSessionMode.SERVER:
|
|
return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
|
|
else:
|
|
raise ValueError(f"Unrecognized mode: {mode!r}")
|
|
|
|
|
|
class TestLoRADynamicUpdate(CustomTestCase):
|
|
"""
|
|
This test case verifies that the SRT runner can dynamically load and unload LoRA adapters
|
|
during a sequence of operations, and that the outputs of forward passes with dynamically loaded
|
|
adapters match the outputs of forward passes with statically loaded adapters.
|
|
"""
|
|
|
|
def _repeat_each(lst, n):
|
|
return [x for x in lst for _ in range(n)]
|
|
|
|
def _run_operation_sequence(
|
|
self,
|
|
mode: LoRAUpdateTestSessionMode,
|
|
base: str,
|
|
initial_adapters: List[Union[str, dict]],
|
|
op_sequence: List[Operation],
|
|
max_loras_per_batch: int,
|
|
max_loaded_loras: Optional[int] = None,
|
|
enable_lora: Optional[bool] = None,
|
|
max_lora_rank: Optional[int] = None,
|
|
lora_target_modules: Optional[List[str]] = None,
|
|
max_new_tokens: int = 32,
|
|
) -> List[tuple]:
|
|
"""
|
|
Runs a sequence of operations on the SRT runner, including loading and unloading LoRA adapters,
|
|
and performing forward passes with the current set of loaded adapters.
|
|
"""
|
|
|
|
forward_outputs = []
|
|
with LoRAUpdateTestSession(
|
|
testcase=self,
|
|
mode=mode,
|
|
model_path=base,
|
|
lora_paths=initial_adapters,
|
|
max_loras_per_batch=max_loras_per_batch,
|
|
max_loaded_loras=max_loaded_loras,
|
|
max_lora_rank=max_lora_rank,
|
|
lora_target_modules=lora_target_modules,
|
|
enable_lora=enable_lora,
|
|
) as session:
|
|
for op in op_sequence:
|
|
op_type = op.type
|
|
data = op.data
|
|
expected_error = op.expected_error
|
|
print("-" * 100)
|
|
print(
|
|
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
|
|
)
|
|
if op_type == OperationType.LOAD:
|
|
if isinstance(data, str):
|
|
adapter_info = {
|
|
"lora_name": data,
|
|
"lora_path": data,
|
|
"pinned": False,
|
|
}
|
|
else:
|
|
adapter_info = data
|
|
|
|
result = session.load_lora_adapter(
|
|
expected_error=expected_error,
|
|
**adapter_info,
|
|
)
|
|
elif op_type == OperationType.UNLOAD:
|
|
result = session.unload_lora_adapter(
|
|
lora_name=data,
|
|
)
|
|
elif op_type == OperationType.FORWARD:
|
|
prompts, adapters = zip(*data)
|
|
result = session.forward(
|
|
prompts=list(prompts),
|
|
lora_paths=list(adapters),
|
|
max_new_tokens=max_new_tokens,
|
|
expected_error=expected_error,
|
|
)
|
|
if not expected_error:
|
|
forward_outputs.append(result)
|
|
|
|
return forward_outputs
|
|
|
|
def _run_dynamic_adapter_updates(
|
|
self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
|
|
):
|
|
for case_idx, test_case in enumerate(test_cases, start=1):
|
|
print("=" * 100)
|
|
print(
|
|
f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
|
|
)
|
|
print("=" * 100)
|
|
|
|
print(
|
|
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
|
|
)
|
|
# Test dynamic loading of adapters
|
|
dynamic_output = self._run_operation_sequence(
|
|
mode=mode,
|
|
initial_adapters=test_case.initial_adapters,
|
|
enable_lora=test_case.enable_lora,
|
|
base=test_case.base,
|
|
max_loras_per_batch=test_case.max_loras_per_batch,
|
|
max_loaded_loras=test_case.max_loaded_loras,
|
|
op_sequence=test_case.op_sequence,
|
|
max_new_tokens=test_case.max_new_tokens,
|
|
max_lora_rank=test_case.max_lora_rank,
|
|
lora_target_modules=test_case.lora_target_modules,
|
|
)
|
|
|
|
# static loading
|
|
forward_ops = [
|
|
x
|
|
for x in test_case.op_sequence
|
|
if x.type == OperationType.FORWARD and x.expected_error is None
|
|
]
|
|
|
|
if not forward_ops:
|
|
print(
|
|
f"No forward operations found in test case {case_idx}. Skipping static pass."
|
|
)
|
|
continue
|
|
|
|
print("=" * 100)
|
|
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
|
|
static_output = self._run_operation_sequence(
|
|
mode=mode,
|
|
initial_adapters=test_case.all_adapters,
|
|
enable_lora=test_case.enable_lora,
|
|
base=test_case.base,
|
|
max_loras_per_batch=test_case.max_loras_per_batch,
|
|
op_sequence=forward_ops,
|
|
max_new_tokens=test_case.max_new_tokens,
|
|
)
|
|
|
|
print(f"Dynamic output: {dynamic_output}")
|
|
print(f"Static output: {static_output}")
|
|
print("=" * 100)
|
|
self.assertEqual(
|
|
len(dynamic_output),
|
|
len(static_output),
|
|
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
|
|
)
|
|
for i, (dynamic, static) in enumerate(
|
|
zip(dynamic_output, static_output), start=1
|
|
):
|
|
self.assertEqual(
|
|
len(dynamic),
|
|
len(static),
|
|
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
|
|
)
|
|
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
|
|
d_out = d_out.strip()
|
|
s_out = s_out.strip()
|
|
self.assertEqual(
|
|
d_out,
|
|
s_out,
|
|
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
|
|
)
|
|
|
|
def test_dynamic_lora_update_engine(self):
|
|
"""
|
|
Test dynamic LoRA updates in engine mode.
|
|
"""
|
|
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
|
|
self._run_dynamic_adapter_updates(
|
|
mode=LoRAUpdateTestSessionMode.ENGINE,
|
|
test_cases=test_cases,
|
|
)
|
|
|
|
def test_dynamic_lora_update_server(self):
|
|
"""
|
|
Test dynamic LoRA updates in server mode.
|
|
"""
|
|
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
|
|
self._run_dynamic_adapter_updates(
|
|
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
try:
|
|
mp.set_start_method("spawn")
|
|
except RuntimeError:
|
|
pass
|
|
|
|
unittest.main(warnings="ignore")
|