Improve error handling for requests with unloaded LoRA path(s) (#7642)
This commit is contained in:
@@ -240,6 +240,12 @@ class TokenizerManager:
|
|||||||
revision=server_args.revision,
|
revision=server_args.revision,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize loaded loRA adapters with the initial lora paths in the server_args.
|
||||||
|
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
|
||||||
|
self.loaded_lora_adapters: Dict[str, str] = dict(
|
||||||
|
self.server_args.lora_paths or {}
|
||||||
|
)
|
||||||
|
|
||||||
# Store states
|
# Store states
|
||||||
self.no_create_loop = False
|
self.no_create_loop = False
|
||||||
self.rid_to_state: Dict[str, ReqState] = {}
|
self.rid_to_state: Dict[str, ReqState] = {}
|
||||||
@@ -549,6 +555,8 @@ class TokenizerManager:
|
|||||||
"The server is not configured to enable custom logit processor. "
|
"The server is not configured to enable custom logit processor. "
|
||||||
"Please set `--enable-custom-logits-processor` to enable this feature."
|
"Please set `--enable-custom-logits-processor` to enable this feature."
|
||||||
)
|
)
|
||||||
|
if self.server_args.lora_paths and obj.lora_path:
|
||||||
|
self._validate_lora_adapters(obj)
|
||||||
|
|
||||||
def _validate_input_ids_in_vocab(
|
def _validate_input_ids_in_vocab(
|
||||||
self, input_ids: List[int], vocab_size: int
|
self, input_ids: List[int], vocab_size: int
|
||||||
@@ -662,6 +670,21 @@ class TokenizerManager:
|
|||||||
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
|
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _validate_lora_adapters(self, obj: GenerateReqInput):
|
||||||
|
"""Validate that the requested LoRA adapters are loaded."""
|
||||||
|
requested_adapters = (
|
||||||
|
set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
|
||||||
|
)
|
||||||
|
loaded_adapters = (
|
||||||
|
self.loaded_lora_adapters.keys() if self.loaded_lora_adapters else set()
|
||||||
|
)
|
||||||
|
unloaded_adapters = requested_adapters - loaded_adapters
|
||||||
|
if unloaded_adapters:
|
||||||
|
raise ValueError(
|
||||||
|
f"The following requested LoRA adapters are not loaded: {unloaded_adapters}\n"
|
||||||
|
f"Loaded adapters: {loaded_adapters}."
|
||||||
|
)
|
||||||
|
|
||||||
def _send_one_request(
|
def _send_one_request(
|
||||||
self,
|
self,
|
||||||
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
||||||
@@ -988,6 +1011,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.model_update_lock.writer_lock:
|
||||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||||
|
self.loaded_lora_adapters = result.loaded_adapters
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def unload_lora_adapter(
|
async def unload_lora_adapter(
|
||||||
@@ -1009,6 +1033,7 @@ class TokenizerManager:
|
|||||||
|
|
||||||
async with self.model_update_lock.writer_lock:
|
async with self.model_update_lock.writer_lock:
|
||||||
result = (await self.update_lora_adapter_communicator(obj))[0]
|
result = (await self.update_lora_adapter_communicator(obj))[0]
|
||||||
|
self.loaded_lora_adapters = result.loaded_adapters
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def get_weights_by_name(
|
async def get_weights_by_name(
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import List, Literal, Optional
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
||||||
from sglang.srt.reasoning_parser import ReasoningParser
|
from sglang.srt.reasoning_parser import ReasoningParser
|
||||||
@@ -131,7 +131,7 @@ class ServerArgs:
|
|||||||
preferred_sampling_params: Optional[str] = None
|
preferred_sampling_params: Optional[str] = None
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
lora_paths: Optional[List[str]] = None
|
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
|
||||||
max_loras_per_batch: int = 8
|
max_loras_per_batch: int = 8
|
||||||
lora_backend: str = "triton"
|
lora_backend: str = "triton"
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import multiprocessing as mp
|
|||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
@@ -42,14 +42,16 @@ PROMPTS = [
|
|||||||
class OperationType(Enum):
|
class OperationType(Enum):
|
||||||
LOAD = "load"
|
LOAD = "load"
|
||||||
UNLOAD = "unload"
|
UNLOAD = "unload"
|
||||||
NOOP = "noop"
|
|
||||||
FORWARD = "forward"
|
FORWARD = "forward"
|
||||||
|
EXPECT_ERROR = "expect_error"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Operation:
|
class Operation:
|
||||||
|
# Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
|
||||||
type: OperationType
|
type: OperationType
|
||||||
data: Optional[str]
|
# Data associated with the operation. Exact type varies depending on the operation
|
||||||
|
data: Optional[Any]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -62,7 +64,7 @@ class TestCase:
|
|||||||
max_new_tokens: int = 32
|
max_new_tokens: int = 32
|
||||||
|
|
||||||
|
|
||||||
def create_batch_data(adapters: Union[str, list]) -> dict:
|
def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
|
||||||
if not isinstance(adapters, list):
|
if not isinstance(adapters, list):
|
||||||
adapters = [adapters]
|
adapters = [adapters]
|
||||||
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
||||||
@@ -80,6 +82,26 @@ TEST_CASES = [
|
|||||||
],
|
],
|
||||||
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
||||||
op_sequence=[
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.LOAD,
|
type=OperationType.LOAD,
|
||||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
@@ -102,6 +124,13 @@ TEST_CASES = [
|
|||||||
type=OperationType.UNLOAD,
|
type=OperationType.UNLOAD,
|
||||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
),
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data(
|
data=create_batch_data(
|
||||||
@@ -115,6 +144,15 @@ TEST_CASES = [
|
|||||||
type=OperationType.UNLOAD,
|
type=OperationType.UNLOAD,
|
||||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
),
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
@@ -149,6 +187,22 @@ TEST_CASES = [
|
|||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
),
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.LOAD,
|
type=OperationType.LOAD,
|
||||||
data="pbevan11/llama-3.1-8b-ocr-correction",
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
@@ -157,6 +211,13 @@ TEST_CASES = [
|
|||||||
type=OperationType.UNLOAD,
|
type=OperationType.UNLOAD,
|
||||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
),
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.EXPECT_ERROR,
|
||||||
|
data=(
|
||||||
|
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
|
"not loaded",
|
||||||
|
),
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
@@ -332,19 +393,31 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
lora_paths: List[str],
|
lora_paths: List[str],
|
||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
|
expected_error: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
response = self.handle.batch_forward(
|
response = self.handle.batch_forward(
|
||||||
prompts=prompts,
|
prompts=prompts,
|
||||||
lora_paths=lora_paths,
|
lora_paths=lora_paths,
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
)
|
)
|
||||||
output_strs = response.output_strs
|
except ValueError as e:
|
||||||
|
if expected_error:
|
||||||
|
error_message = str(e)
|
||||||
|
self.testcase.assertIn(expected_error, error_message)
|
||||||
|
print(f"Received error as expected: {error_message}")
|
||||||
|
return error_message
|
||||||
|
|
||||||
print(f"output_strs: {output_strs}")
|
raise e
|
||||||
return output_strs
|
|
||||||
|
self.testcase.assertEqual(len(response.output_strs), len(prompts))
|
||||||
|
output = response.output_strs
|
||||||
|
print(f"output_strs: {output}")
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||||
@@ -426,6 +499,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
lora_paths: List[str],
|
lora_paths: List[str],
|
||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
|
expected_error: str = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||||
@@ -442,11 +516,18 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if expected_error:
|
||||||
|
self.testcase.assertEqual(response.status_code, 400)
|
||||||
|
self.testcase.assertIn(expected_error, response.text)
|
||||||
|
output = response.text
|
||||||
|
print(f"Received error as expected: {response.text}")
|
||||||
|
return output
|
||||||
|
else:
|
||||||
self.testcase.assertTrue(response.ok)
|
self.testcase.assertTrue(response.ok)
|
||||||
output_strs = [r["text"] for r in response.json()]
|
output = [r["text"] for r in response.json()]
|
||||||
|
self.testcase.assertEqual(len(output), len(prompts))
|
||||||
print(f"output_strs: {output_strs}")
|
print(f"output_strs: {output}")
|
||||||
return output_strs
|
return output
|
||||||
|
|
||||||
|
|
||||||
# Factory function to create the appropriate LoRA test session based on mode
|
# Factory function to create the appropriate LoRA test session based on mode
|
||||||
@@ -535,14 +616,23 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
)
|
)
|
||||||
forward_outputs.append(result)
|
forward_outputs.append(result)
|
||||||
|
elif op_type == OperationType.EXPECT_ERROR:
|
||||||
|
input_data, expected_error = data
|
||||||
|
prompts, adapters = zip(*input_data)
|
||||||
|
result = session.forward(
|
||||||
|
prompts=list(prompts),
|
||||||
|
lora_paths=list(adapters),
|
||||||
|
max_new_tokens=max_new_tokens,
|
||||||
|
expected_error=expected_error,
|
||||||
|
)
|
||||||
|
|
||||||
return forward_outputs
|
return forward_outputs
|
||||||
|
|
||||||
def test_dynamic_adapter_updates(self):
|
def test_dynamic_adapter_updates(self):
|
||||||
for case_idx, test_case in enumerate(TEST_CASES, start=1):
|
for case_idx, test_case in enumerate(TEST_CASES, start=1):
|
||||||
for mode in [
|
for mode in [
|
||||||
LoRAUpdateTestSessionMode.SERVER,
|
|
||||||
LoRAUpdateTestSessionMode.ENGINE,
|
LoRAUpdateTestSessionMode.ENGINE,
|
||||||
|
LoRAUpdateTestSessionMode.SERVER,
|
||||||
]:
|
]:
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
print(f"Starting test case {case_idx} in {mode.value} mode.")
|
print(f"Starting test case {case_idx} in {mode.value} mode.")
|
||||||
|
|||||||
Reference in New Issue
Block a user