[Feature] Initial support for multi-LoRA serving (#1307)
@@ -101,6 +101,10 @@ class ServerArgs:
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False

    # LoRA
    lora_paths: Optional[List[str]] = None
    max_loras_per_batch: int = 8

    def __post_init__(self):
        # Set missing default values
        if self.tokenizer_path is None:
@@ -522,6 +526,21 @@ class ServerArgs:
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

        # LoRA options
        parser.add_argument(
            "--lora-paths",
            type=str,
            nargs="*",
            default=None,
            help="The list of LoRA adapters.",
        )
        parser.add_argument(
            "--max-loras-per-batch",
            type=int,
            default=8,
            help="Maximum number of adapters for a running batch, including base-only requests.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
@@ -539,6 +558,12 @@ class ServerArgs:
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
            and (self.lora_paths is None or self.disable_cuda_graph)
            and (self.lora_paths is None or self.disable_radix_cache)
        ), "compatibility of lora and cuda graph and radix attention is in progress"


def prepare_server_args(argv: List[str]) -> ServerArgs:
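For reference, a minimal sketch of how the new flags might be exercised end to end. It assumes the file above is sglang/srt/server_args.py and that the existing --model-path, --disable-cuda-graph, and --disable-radix-cache options are unchanged; the model and adapter paths are placeholders.

# Minimal sketch: parse CLI args that use the new LoRA options.
# Assumed import path; adapter/model paths are placeholders.
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args([
    "--model-path", "meta-llama/Llama-2-7b-hf",                 # base model (placeholder)
    "--lora-paths", "path/to/adapter_a", "path/to/adapter_b",   # nargs="*" list of adapters
    "--max-loras-per-batch", "4",
    # Per the new assert, CUDA graph and radix cache must be disabled
    # whenever --lora-paths is set:
    "--disable-cuda-graph",
    "--disable-radix-cache",
])
print(server_args.lora_paths)           # ['path/to/adapter_a', 'path/to/adapter_b']
print(server_args.max_loras_per_batch)  # 4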
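Note the FIXME-guarded assert: as this commit stands, setting --lora-paths requires both --disable-cuda-graph and --disable-radix-cache, since compatibility of LoRA with CUDA graph and radix attention is still in progress.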