Improve error handling for requests with unloaded LoRA path(s) (#7642)
This commit is contained in:
@@ -16,7 +16,7 @@ import multiprocessing as mp
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -42,14 +42,16 @@ PROMPTS = [
|
||||
class OperationType(Enum):
    """Kinds of steps that can appear in a test case's operation sequence."""

    # Load the adapter named by the operation's data.
    LOAD = "load"
    # Unload the adapter named by the operation's data.
    UNLOAD = "unload"
    # NOTE(review): no operation visible in this view uses NOOP -- confirm
    # whether it is still needed before removing it.
    NOOP = "noop"
    # Run a batch forward pass and record its outputs.
    FORWARD = "forward"
    # Run a batch forward pass that is expected to fail with a given message.
    EXPECT_ERROR = "expect_error"
|
||||
|
||||
|
||||
@dataclass
class Operation:
    """A single step in a test case's operation sequence."""

    # Operation type; see OperationType for the available kinds.
    type: OperationType
    # Data associated with the operation. Exact type varies depending on the
    # operation, as used by the test cases in this file:
    #   * LOAD / UNLOAD: the adapter name (str).
    #   * FORWARD: a list of (prompt, adapter) pairs.
    #   * EXPECT_ERROR: a tuple of (list of (prompt, adapter) pairs,
    #     substring expected in the error message).
    data: Optional[Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -62,7 +64,7 @@ class TestCase:
|
||||
max_new_tokens: int = 32
|
||||
|
||||
|
||||
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
    """Pair every prompt in PROMPTS with every given adapter.

    Args:
        adapters: A single adapter name, or a list of adapter names.

    Returns:
        A list of (prompt, adapter) tuples in prompt-major order: all
        adapters for the first prompt, then all adapters for the second,
        and so on.
    """
    if not isinstance(adapters, list):
        # Normalize a lone adapter into a one-element list.
        adapters = [adapters]
    return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
||||
@@ -80,6 +82,26 @@ TEST_CASES = [
|
||||
],
|
||||
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
@@ -102,6 +124,13 @@ TEST_CASES = [
|
||||
type=OperationType.UNLOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
@@ -115,6 +144,15 @@ TEST_CASES = [
|
||||
type=OperationType.UNLOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
@@ -149,6 +187,22 @@ TEST_CASES = [
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||
@@ -157,6 +211,13 @@ TEST_CASES = [
|
||||
type=OperationType.UNLOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
"not loaded",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
prompts: List[str],
|
||||
lora_paths: List[str],
|
||||
max_new_tokens: int = 32,
|
||||
expected_error: str = None,
|
||||
):
|
||||
"""
|
||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||
"""
|
||||
response = self.handle.batch_forward(
|
||||
prompts=prompts,
|
||||
lora_paths=lora_paths,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
output_strs = response.output_strs
|
||||
try:
|
||||
response = self.handle.batch_forward(
|
||||
prompts=prompts,
|
||||
lora_paths=lora_paths,
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
except ValueError as e:
|
||||
if expected_error:
|
||||
error_message = str(e)
|
||||
self.testcase.assertIn(expected_error, error_message)
|
||||
print(f"Received error as expected: {error_message}")
|
||||
return error_message
|
||||
|
||||
print(f"output_strs: {output_strs}")
|
||||
return output_strs
|
||||
raise e
|
||||
|
||||
self.testcase.assertEqual(len(response.output_strs), len(prompts))
|
||||
output = response.output_strs
|
||||
print(f"output_strs: {output}")
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
prompts: List[str],
|
||||
lora_paths: List[str],
|
||||
max_new_tokens: int = 32,
|
||||
expected_error: str = None,
|
||||
):
|
||||
"""
|
||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||
@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
},
|
||||
},
|
||||
)
|
||||
self.testcase.assertTrue(response.ok)
|
||||
output_strs = [r["text"] for r in response.json()]
|
||||
|
||||
print(f"output_strs: {output_strs}")
|
||||
return output_strs
|
||||
if expected_error:
|
||||
self.testcase.assertEqual(response.status_code, 400)
|
||||
self.testcase.assertIn(expected_error, response.text)
|
||||
output = response.text
|
||||
print(f"Received error as expected: {response.text}")
|
||||
return output
|
||||
else:
|
||||
self.testcase.assertTrue(response.ok)
|
||||
output = [r["text"] for r in response.json()]
|
||||
self.testcase.assertEqual(len(output), len(prompts))
|
||||
print(f"output_strs: {output}")
|
||||
return output
|
||||
|
||||
|
||||
# Factory function to create the appropriate LoRA test session based on mode
|
||||
@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
forward_outputs.append(result)
|
||||
elif op_type == OperationType.EXPECT_ERROR:
|
||||
input_data, expected_error = data
|
||||
prompts, adapters = zip(*input_data)
|
||||
result = session.forward(
|
||||
prompts=list(prompts),
|
||||
lora_paths=list(adapters),
|
||||
max_new_tokens=max_new_tokens,
|
||||
expected_error=expected_error,
|
||||
)
|
||||
|
||||
return forward_outputs
|
||||
|
||||
def test_dynamic_adapter_updates(self):
|
||||
for case_idx, test_case in enumerate(TEST_CASES, start=1):
|
||||
for mode in [
|
||||
LoRAUpdateTestSessionMode.SERVER,
|
||||
LoRAUpdateTestSessionMode.ENGINE,
|
||||
LoRAUpdateTestSessionMode.SERVER,
|
||||
]:
|
||||
print("=" * 100)
|
||||
print(f"Starting test case {case_idx} in {mode.value} mode.")
|
||||
|
||||
Reference in New Issue
Block a user