Improve error handling for requests with unloaded LoRA path(s) (#7642)

This commit is contained in:
Lifu Huang
2025-07-01 20:05:34 -07:00
committed by GitHub
parent f18a8fddd4
commit 1a08358aed
3 changed files with 135 additions and 20 deletions

View File

@@ -16,7 +16,7 @@ import multiprocessing as mp
import unittest
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union
from typing import Any, List, Optional, Union
import requests
import torch
@@ -42,14 +42,16 @@ PROMPTS = [
class OperationType(Enum):
LOAD = "load"
UNLOAD = "unload"
NOOP = "noop"
FORWARD = "forward"
EXPECT_ERROR = "expect_error"
@dataclass
class Operation:
# Operation type, can be LOAD, UNLOAD, NOOP, FORWARD, or EXPECT_ERROR
type: OperationType
data: Optional[str]
# Data associated with the operation. Exact type varies depending on the operation
data: Optional[Any]
@dataclass
@@ -62,7 +64,7 @@ class TestCase:
max_new_tokens: int = 32
def create_batch_data(adapters: Union[str, list]) -> dict:
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
if not isinstance(adapters, list):
adapters = [adapters]
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
@@ -80,6 +82,26 @@ TEST_CASES = [
],
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
op_sequence=[
Operation(
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
),
Operation(
type=OperationType.LOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
@@ -102,6 +124,13 @@ TEST_CASES = [
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data(
@@ -115,6 +144,15 @@ TEST_CASES = [
type=OperationType.UNLOAD,
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
@@ -149,6 +187,22 @@ TEST_CASES = [
type=OperationType.FORWARD,
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data(
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
),
"not loaded",
),
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
"not loaded",
),
),
Operation(
type=OperationType.LOAD,
data="pbevan11/llama-3.1-8b-ocr-correction",
@@ -157,6 +211,13 @@ TEST_CASES = [
type=OperationType.UNLOAD,
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
),
Operation(
type=OperationType.EXPECT_ERROR,
data=(
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
"not loaded",
),
),
Operation(
type=OperationType.FORWARD,
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
"""
response = self.handle.batch_forward(
prompts=prompts,
lora_paths=lora_paths,
max_new_tokens=max_new_tokens,
)
output_strs = response.output_strs
try:
response = self.handle.batch_forward(
prompts=prompts,
lora_paths=lora_paths,
max_new_tokens=max_new_tokens,
)
except ValueError as e:
if expected_error:
error_message = str(e)
self.testcase.assertIn(expected_error, error_message)
print(f"Received error as expected: {error_message}")
return error_message
print(f"output_strs: {output_strs}")
return output_strs
raise e
self.testcase.assertEqual(len(response.output_strs), len(prompts))
output = response.output_strs
print(f"output_strs: {output}")
return output
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
prompts: List[str],
lora_paths: List[str],
max_new_tokens: int = 32,
expected_error: str = None,
):
"""
Perform a batch forward pass with the current set of loaded LoRA adapters.
@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
},
},
)
self.testcase.assertTrue(response.ok)
output_strs = [r["text"] for r in response.json()]
print(f"output_strs: {output_strs}")
return output_strs
if expected_error:
self.testcase.assertEqual(response.status_code, 400)
self.testcase.assertIn(expected_error, response.text)
output = response.text
print(f"Received error as expected: {response.text}")
return output
else:
self.testcase.assertTrue(response.ok)
output = [r["text"] for r in response.json()]
self.testcase.assertEqual(len(output), len(prompts))
print(f"output_strs: {output}")
return output
# Factory function to create the appropriate LoRA test session based on mode
@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
max_new_tokens=max_new_tokens,
)
forward_outputs.append(result)
elif op_type == OperationType.EXPECT_ERROR:
input_data, expected_error = data
prompts, adapters = zip(*input_data)
result = session.forward(
prompts=list(prompts),
lora_paths=list(adapters),
max_new_tokens=max_new_tokens,
expected_error=expected_error,
)
return forward_outputs
def test_dynamic_adapter_updates(self):
for case_idx, test_case in enumerate(TEST_CASES, start=1):
for mode in [
LoRAUpdateTestSessionMode.SERVER,
LoRAUpdateTestSessionMode.ENGINE,
LoRAUpdateTestSessionMode.SERVER,
]:
print("=" * 100)
print(f"Starting test case {case_idx} in {mode.value} mode.")