[FEATURE] Add OpenAI-Compatible LoRA Adapter Selection (#11570)
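For context, here is a minimal client-side sketch of the new selection syntax. It assumes a server launched with --enable-lora and an adapter preloaded under the name 'my-adapter' (via --lora-paths or the /load_lora_adapter endpoint); the base URL, model path, and adapter name are placeholders, not values taken from this change.

from openai import OpenAI

# Placeholder endpoint and credentials; point these at your own deployment.
client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

# "base-model:adapter-name" routes the request through the named LoRA adapter.
# Omitting ":adapter-name" keeps the request on the base model, as before.
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct:my-adapter",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)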
@@ -204,7 +204,10 @@ class BatchResponse(BaseModel):
 class CompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/completions/create
-    model: str = DEFAULT_MODEL_NAME
+    model: str = Field(
+        default=DEFAULT_MODEL_NAME,
+        description="Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.",
+    )
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
     echo: bool = False
@@ -441,7 +444,10 @@ class ChatCompletionRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
-    model: str = DEFAULT_MODEL_NAME
+    model: str = Field(
+        default=DEFAULT_MODEL_NAME,
+        description="Model name. Supports LoRA adapters via 'base-model:adapter-name' syntax.",
+    )
     frequency_penalty: float = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: bool = False
@@ -1099,7 +1105,7 @@ class ResponsesResponse(BaseModel):
                 Union[
                     ResponseOutputItem, ResponseReasoningItem, ResponseFunctionToolCall
                 ]
             ]
         ],
     ) -> bool:
         if not items:
             return False
@@ -4,7 +4,7 @@ import json
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

 import orjson
 from fastapi import HTTPException, Request
@@ -35,6 +35,52 @@ class OpenAIServingBase(ABC):
             else None
         )

+    def _parse_model_parameter(self, model: str) -> Tuple[str, Optional[str]]:
+        """Parse 'base-model:adapter-name' syntax to extract LoRA adapter.
+
+        Returns (base_model, adapter_name) or (model, None) if no colon present.
+        """
+        if ":" not in model:
+            return model, None
+
+        # Split on first colon only to handle model paths with multiple colons
+        parts = model.split(":", 1)
+        base_model = parts[0].strip()
+        adapter_name = parts[1].strip() or None
+
+        return base_model, adapter_name
+
+    def _resolve_lora_path(
+        self,
+        request_model: str,
+        explicit_lora_path: Optional[Union[str, List[Optional[str]]]],
+    ) -> Optional[Union[str, List[Optional[str]]]]:
+        """Resolve LoRA adapter with priority: model parameter > explicit lora_path.
+
+        Returns adapter name or None. Supports both single values and lists (batches).
+        """
+        _, adapter_from_model = self._parse_model_parameter(request_model)
+
+        # Model parameter adapter takes precedence
+        if adapter_from_model is not None:
+            return adapter_from_model
+
+        # Fall back to explicit lora_path
+        return explicit_lora_path
+
+    def _validate_lora_enabled(self, adapter_name: str) -> None:
+        """Check that LoRA is enabled before attempting to use an adapter.
+
+        Raises ValueError with actionable guidance if --enable-lora flag is missing.
+        Adapter existence is validated later by TokenizerManager.lora_registry.
+        """
+        if not self.tokenizer_manager.server_args.enable_lora:
+            raise ValueError(
+                f"LoRA adapter '{adapter_name}' was requested, but LoRA is not enabled. "
+                "Please launch the server with --enable-lora flag and preload adapters "
+                "using --lora-paths or /load_lora_adapter endpoint."
+            )
+
     async def handle_request(
         self, request: OpenAIServingRequest, raw_request: Request
     ) -> Union[Any, StreamingResponse, ErrorResponse]:
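To make the resolution order concrete, here is a standalone sketch of the behavior the helpers above implement, with the model and adapter names invented for illustration: the model parameter is split on the first colon only, and an adapter named there wins over an explicit lora_path.

from typing import List, Optional, Tuple, Union

def parse_model_parameter(model: str) -> Tuple[str, Optional[str]]:
    # Mirror of _parse_model_parameter above, as a free function for illustration.
    if ":" not in model:
        return model, None
    parts = model.split(":", 1)  # split on the first colon only
    return parts[0].strip(), parts[1].strip() or None

def resolve_lora_path(
    request_model: str,
    explicit_lora_path: Optional[Union[str, List[Optional[str]]]],
) -> Optional[Union[str, List[Optional[str]]]]:
    # Mirror of _resolve_lora_path above: model parameter > explicit lora_path.
    _, adapter_from_model = parse_model_parameter(request_model)
    return adapter_from_model if adapter_from_model is not None else explicit_lora_path

assert parse_model_parameter("llama-3-8b") == ("llama-3-8b", None)
assert parse_model_parameter("llama-3-8b:sql-lora") == ("llama-3-8b", "sql-lora")
assert parse_model_parameter("llama-3-8b:") == ("llama-3-8b", None)        # empty adapter -> None
assert parse_model_parameter("base:adapter:v2") == ("base", "adapter:v2")  # later colons kept
assert resolve_lora_path("llama-3-8b:sql-lora", "other-lora") == "sql-lora"
assert resolve_lora_path("llama-3-8b", "other-lora") == "other-lora"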
@@ -164,6 +164,17 @@ class OpenAIServingChat(OpenAIServingBase):
         # Extract custom labels from raw request headers
         custom_labels = self.extract_custom_labels(raw_request)

+        # Resolve LoRA adapter from model parameter or explicit lora_path
+        lora_path = self._resolve_lora_path(request.model, request.lora_path)
+        if lora_path:
+            first_adapter = (
+                lora_path
+                if isinstance(lora_path, str)
+                else next((a for a in lora_path if a), None)
+            )
+            if first_adapter:
+                self._validate_lora_enabled(first_adapter)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             image_data=processed_messages.image_data,
@@ -176,7 +187,7 @@ class OpenAIServingChat(OpenAIServingBase):
             stream=request.stream,
             return_text_in_logprobs=True,
             modalities=processed_messages.modalities,
-            lora_path=request.lora_path,
+            lora_path=lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
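The first_adapter expression above only picks a representative adapter name for the enable-LoRA check; adapter existence itself is validated later by the lora_registry. A small sketch of how that selection behaves for single and batched requests (adapter names are hypothetical):

from typing import List, Optional, Union

def first_adapter(lora_path: Union[str, List[Optional[str]]]) -> Optional[str]:
    # A plain string is already the adapter name; for a batch (list), take the
    # first truthy entry, or None when no request in the batch uses an adapter.
    if isinstance(lora_path, str):
        return lora_path
    return next((a for a in lora_path if a), None)

assert first_adapter("sql-lora") == "sql-lora"
assert first_adapter([None, "sql-lora", None]) == "sql-lora"  # mixed batch
assert first_adapter([None, None]) is None                    # no adapter used, no check needed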
@@ -93,6 +93,17 @@ class OpenAIServingCompletion(OpenAIServingBase):
         # Extract custom labels from raw request headers
         custom_labels = self.extract_custom_labels(raw_request)

+        # Resolve LoRA adapter from model parameter or explicit lora_path
+        lora_path = self._resolve_lora_path(request.model, request.lora_path)
+        if lora_path:
+            first_adapter = (
+                lora_path
+                if isinstance(lora_path, str)
+                else next((a for a in lora_path if a), None)
+            )
+            if first_adapter:
+                self._validate_lora_enabled(first_adapter)
+
         adapted_request = GenerateReqInput(
             **prompt_kwargs,
             sampling_params=sampling_params,
@@ -101,7 +112,7 @@ class OpenAIServingCompletion(OpenAIServingBase):
             logprob_start_len=logprob_start_len,
             return_text_in_logprobs=True,
             stream=request.stream,
-            lora_path=request.lora_path,
+            lora_path=lora_path,
             bootstrap_host=request.bootstrap_host,
             bootstrap_port=request.bootstrap_port,
             bootstrap_room=request.bootstrap_room,
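The completions endpoint accepts the same syntax; a minimal sketch mirroring the chat example above, again with placeholder URL, model path, and adapter name. No explicit lora_path field is needed, and if both are supplied the adapter named in model takes precedence.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

# Placeholder model path and adapter name; the adapter is parsed out of the
# model field by the server, so no separate lora_path value is required.
response = client.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct:my-adapter",
    prompt="Write a haiku about autumn.",
    max_tokens=64,
)
print(response.choices[0].text)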