Improve error handling for requests with unloaded LoRA path(s) (#7642)

This commit is contained in:
Lifu Huang
2025-07-01 20:05:34 -07:00
committed by GitHub
parent f18a8fddd4
commit 1a08358aed
3 changed files with 135 additions and 20 deletions

View File

@@ -240,6 +240,12 @@ class TokenizerManager:
revision=server_args.revision,
)
# Initialize loaded LoRA adapters with the initial lora paths in the server_args.
# This list will be updated when new LoRA adapters are loaded or unloaded dynamically.
self.loaded_lora_adapters: Dict[str, str] = dict(
self.server_args.lora_paths or {}
)
# Store states
self.no_create_loop = False
self.rid_to_state: Dict[str, ReqState] = {}
@@ -549,6 +555,8 @@ class TokenizerManager:
"The server is not configured to enable custom logit processor. "
"Please set `--enable-custom-logits-processor` to enable this feature."
)
if self.server_args.lora_paths and obj.lora_path:
self._validate_lora_adapters(obj)
def _validate_input_ids_in_vocab(
self, input_ids: List[int], vocab_size: int
@@ -662,6 +670,21 @@ class TokenizerManager:
"Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
)
def _validate_lora_adapters(self, obj: GenerateReqInput):
    """Validate that every LoRA adapter requested by `obj` is currently loaded.

    Args:
        obj: The generation request. `obj.lora_path` may be a single adapter
            name or a list of adapter names (one per request in a batch).

    Raises:
        ValueError: If any requested adapter is not present in
            `self.loaded_lora_adapters`.
    """
    # Normalize to a set: lora_path may be a single str or a list of strs.
    requested_adapters = (
        set(obj.lora_path) if isinstance(obj.lora_path, list) else {obj.lora_path}
    )
    # NOTE(review): in a batch, a None entry presumably means "no adapter for
    # this request" — it must not be reported as an unloaded adapter.
    requested_adapters.discard(None)
    # loaded_lora_adapters maps adapter name -> path; guard against None so an
    # unconfigured server compares against the empty set. Materialize as a set
    # (not dict_keys) so the error message below renders cleanly.
    loaded_adapters = set(self.loaded_lora_adapters or ())
    unloaded_adapters = requested_adapters - loaded_adapters
    if unloaded_adapters:
        # Sort both lists for a deterministic, readable error message.
        raise ValueError(
            f"The following requested LoRA adapters are not loaded: {sorted(unloaded_adapters)}\n"
            f"Loaded adapters: {sorted(loaded_adapters)}."
        )
def _send_one_request(
self,
obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -988,6 +1011,7 @@ class TokenizerManager:
async with self.model_update_lock.writer_lock:
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def unload_lora_adapter(
@@ -1009,6 +1033,7 @@ class TokenizerManager:
async with self.model_update_lock.writer_lock:
result = (await self.update_lora_adapter_communicator(obj))[0]
self.loaded_lora_adapters = result.loaded_adapters
return result
async def get_weights_by_name(

View File

@@ -20,7 +20,7 @@ import logging
import os
import random
import tempfile
from typing import List, Literal, Optional
from typing import List, Literal, Optional, Union
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
from sglang.srt.reasoning_parser import ReasoningParser
@@ -131,7 +131,7 @@ class ServerArgs:
preferred_sampling_params: Optional[str] = None
# LoRA
lora_paths: Optional[List[str]] = None
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
max_loras_per_batch: int = 8
lora_backend: str = "triton"