[Feature] Initial support for multi-LoRA serving (#1307)

This commit is contained in:
Ying Sheng
2024-09-12 16:46:14 -07:00
committed by GitHub
parent c33d82a211
commit 712216928f
21 changed files with 1435 additions and 22 deletions

View File

@@ -101,6 +101,10 @@ class ServerArgs:
enable_mla: bool = False
triton_attention_reduce_in_fp32: bool = False
# LoRA
lora_paths: Optional[List[str]] = None
max_loras_per_batch: int = 8
def __post_init__(self):
# Set missing default values
if self.tokenizer_path is None:
@@ -522,6 +526,21 @@ class ServerArgs:
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
# LoRA options
parser.add_argument(
"--lora-paths",
type=str,
nargs="*",
default=None,
help="The list of LoRA adapters.",
)
# Cap on the number of adapters in one running batch; the CLI default
# matches the `max_loras_per_batch` dataclass field (8).
parser.add_argument(
"--max-loras-per-batch",
type=int,
default=8,
help="Maximum number of adapters for a running batch, including base-only requests.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
@@ -539,6 +558,12 @@ class ServerArgs:
assert not (
self.dp_size > 1 and self.node_rank is not None
), "multi-node data parallel is not supported"
assert (
self.max_loras_per_batch > 0
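# FIXME: LoRA currently requires both CUDA graph and radix cache to be disabled; lift these restrictions once they are compatible.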
and (self.lora_paths is None or self.disable_cuda_graph)
and (self.lora_paths is None or self.disable_radix_cache)
), "compatibility of lora and cuda graph and radix attention is in progress"
def prepare_server_args(argv: List[str]) -> ServerArgs: