Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)
@@ -33,6 +33,10 @@
|
||||
"\n",
|
||||
"* `lora_backend`: The backend for running GEMM kernels for LoRA modules. It can be one of `triton` or `flashinfer`, and is set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backends built upon Cutlass or CUDA kernels will be added.\n",
|
||||
"\n",
|
||||
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
||||
"\n",
|
||||
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
|
||||
"\n",
|
||||
"* `tp_size`: SGLang supports LoRA serving together with tensor parallelism. `tp_size` controls the number of GPUs used for tensor parallelism. More details on the tensor sharding strategy can be found in the [S-LoRA](https://arxiv.org/pdf/2311.03285) paper.\n",
|
||||
"\n",
|
||||
"From the client side, the user provides a list of strings as the input batch and a list of adapter names indicating which adapter each input sequence should use, as sketched in the example below."
|
||||
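For illustration, here is a minimal client-side sketch of the request format described above (it mirrors the cells later in this notebook and assumes a server launched as in those cells is reachable on `port`, with adapters registered under the names `lora0` and `lora1`):

```python
import requests

# Each prompt in "text" is paired positionally with an adapter name in "lora_path".
response = requests.post(
    f"http://127.0.0.1:{port}/generate",
    json={
        "text": [
            "List 3 countries and their capitals.",
            "AI is a field of computer science focused on",
        ],
        # Adapter names registered via --lora-paths or /load_lora_adapter;
        # a None entry runs that request against the base model only.
        "lora_path": ["lora0", "lora1"],
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
)
print([item["text"] for item in response.json()])
```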
@@ -176,6 +180,241 @@
|
||||
"terminate_process(server_process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Dynamic LoRA loading"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Basic Usage\n",
|
||||
"\n",
|
||||
"Instead of specifying all adapters at server startup via `--lora-paths`, you can also load and unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` APIs.\n",
|
||||
"\n",
|
||||
"(Please note that we currently still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature; this limitation will be lifted soon.)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" \"\"\"\n",
|
||||
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
|
||||
" --cuda-graph-max-bs 2 \\\n",
|
||||
" --max-loras-per-batch 2 --lora-backend triton \\\n",
|
||||
" --disable-radix-cache\n",
|
||||
" \"\"\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||
"wait_for_server(url)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(\n",
|
||||
" url + \"/load_lora_adapter\",\n",
|
||||
" json={\n",
|
||||
" \"lora_name\": \"lora1\",\n",
|
||||
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"if response.status_code == 200:\n",
|
||||
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
|
||||
"else:\n",
|
||||
" print(\"Failed to load LoRA adapter.\", response.json())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(\n",
|
||||
" url + \"/generate\",\n",
|
||||
" json={\n",
|
||||
" \"text\": [\n",
|
||||
" \"List 3 countries and their capitals.\",\n",
|
||||
" \"List 3 countries and their capitals.\",\n",
|
||||
" ],\n",
|
||||
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
|
||||
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(\n",
|
||||
" url + \"/unload_lora_adapter\",\n",
|
||||
" json={\n",
|
||||
" \"lora_name\": \"lora0\",\n",
|
||||
" },\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(\n",
|
||||
" url + \"/load_lora_adapter\",\n",
|
||||
" json={\n",
|
||||
" \"lora_name\": \"lora2\",\n",
|
||||
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"if response.status_code == 200:\n",
|
||||
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
|
||||
"else:\n",
|
||||
" print(\"Failed to load LoRA adapter.\", response.json())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(\n",
|
||||
" url + \"/generate\",\n",
|
||||
" json={\n",
|
||||
" \"text\": [\n",
|
||||
" \"List 3 countries and their capitals.\",\n",
|
||||
" \"List 3 countries and their capitals.\",\n",
|
||||
" ],\n",
|
||||
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
|
||||
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"terminate_process(server_process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Advanced: hosting adapters of different shapes\n",
|
||||
"\n",
|
||||
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
|
||||
"\n",
|
||||
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
|
||||
]
|
||||
},
|
||||
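{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before choosing values for `--max-lora-rank` and `--lora-target-modules`, it can help to inspect the shape of each adapter you plan to serve. The next cell is an illustrative sketch only: it assumes the `peft` package is installed (it is not otherwise used in this notebook) and uses `LoraConfig.from_pretrained` to read each adapter's `adapter_config.json` from the Hub."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (assumes `peft` is installed): print each adapter's rank and target modules,\n",
"# which is what --max-lora-rank and --lora-target-modules need to cover.\n",
"from peft import LoraConfig\n",
"\n",
"for repo in [\n",
"    \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
"    \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n",
"]:\n",
"    cfg = LoraConfig.from_pretrained(repo)\n",
"    print(f\"{repo}: rank={cfg.r}, target_modules={cfg.target_modules}\")"
]
},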
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
|
||||
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# The `--lora-target-modules` param below is technically not needed, as the server will infer it from lora0, which already has all the target modules specified.\n",
|
||||
"# We are adding it here just to demonstrate usage.\n",
|
||||
"server_process, port = launch_server_cmd(\n",
|
||||
" f\"\"\"\n",
|
||||
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||
" --lora-paths lora0={lora0} \\\n",
|
||||
" --cuda-graph-max-bs 2 \\\n",
|
||||
" --max-loras-per-batch 2 --lora-backend triton \\\n",
|
||||
" --disable-radix-cache \\\n",
|
||||
" --max-lora-rank 64 \\\n",
|
||||
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
|
||||
" \"\"\"\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||
"wait_for_server(url)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"response = requests.post(\n",
|
||||
" url + \"/load_lora_adapter\",\n",
|
||||
" json={\n",
|
||||
" \"lora_name\": \"lora1\",\n",
|
||||
" \"lora_path\": lora1,\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"if response.status_code == 200:\n",
|
||||
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
|
||||
"else:\n",
|
||||
" print(\"Failed to load LoRA adapter.\", response.json())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||
"json_data = {\n",
|
||||
" \"text\": [\n",
|
||||
" \"List 3 countries and their capitals.\",\n",
|
||||
" \"AI is a field of computer science focused on\",\n",
|
||||
" ],\n",
|
||||
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||
" # The first input uses lora0, and the second input uses lora1\n",
|
||||
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
|
||||
"}\n",
|
||||
"response = requests.post(\n",
|
||||
" url + \"/generate\",\n",
|
||||
" json=json_data,\n",
|
||||
")\n",
|
||||
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
|
||||
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"terminate_process(server_process)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
||||
@@ -167,6 +167,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
||||
| `--lora-paths` | The list of LoRA adapters. You can provide either plain paths or renamed paths in the format `{name}={path}`. | None |
|
||||
| `--max-loras-per-batch` | Maximum number of adapters in a running batch, including base-only requests. | 8 |
|
||||
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
||||
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
|
||||
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
|
||||
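The same LoRA settings are available when embedding SGLang through the offline `Engine` API. Below is a minimal sketch; the keyword arguments are assumed to mirror the `ServerArgs` fields behind the CLI flags above, so verify them against your SGLang version.

```python
import sglang as sgl

# Assumed to mirror the CLI flags documented above (see ServerArgs).
llm = sgl.Engine(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    lora_paths=["lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora"],
    max_loras_per_batch=2,
    lora_backend="triton",
    # Reserve headroom for adapters that will be loaded dynamically after startup.
    max_lora_rank=64,
    lora_target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

out = llm.generate(
    ["List 3 countries and their capitals."],
    sampling_params={"max_new_tokens": 32},
    lora_path=["lora0"],
)
print(out[0]["text"])
```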
|
||||
## Kernel backend
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# and "Punica: Multi-Tenant LoRA Serving"
|
||||
|
||||
import logging
|
||||
from typing import Dict, Set, Tuple
|
||||
from typing import Dict, Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -53,6 +53,8 @@ class LoRAManager:
|
||||
lora_backend: str = "triton",
|
||||
tp_size: int = 1,
|
||||
tp_rank: int = 0,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
target_modules: Optional[Iterable[str]] = None,
|
||||
):
|
||||
self.base_model: torch.nn.Module = base_model
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
@@ -62,6 +64,10 @@ class LoRAManager:
|
||||
self.device: torch.device = next(self.base_model.parameters()).device
|
||||
self.tp_size: int = tp_size
|
||||
self.tp_rank: int = tp_rank
|
||||
self.max_lora_rank: Optional[int] = max_lora_rank
|
||||
self.target_modules: Optional[Set[str]] = (
|
||||
set(target_modules) if target_modules else None
|
||||
)
|
||||
|
||||
# LoRA backend for running sgemm kernels
|
||||
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||
@@ -153,7 +159,9 @@ class LoRAManager:
|
||||
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
|
||||
|
||||
try:
|
||||
self.configs[lora_name] = LoRAConfig(lora_path)
|
||||
new_adapter = LoRAConfig(lora_path)
|
||||
self.validate_new_adapter(lora_name, new_adapter)
|
||||
self.configs[lora_name] = new_adapter
|
||||
except Exception as e:
|
||||
success = False
|
||||
error_message = (
|
||||
@@ -168,6 +176,21 @@ class LoRAManager:
|
||||
error_message=error_message,
|
||||
)
|
||||
|
||||
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
|
||||
"""
|
||||
Validate whether an adapter can be loaded into the current LoRA memory pool and raise an error if it is incompatible.
|
||||
"""
|
||||
|
||||
incompatible = self.memory_pool and not self.memory_pool.can_support(
|
||||
lora_config
|
||||
)
|
||||
if incompatible:
|
||||
raise ValueError(
|
||||
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
|
||||
"We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
|
||||
"you can specify the expected shapes via --max-lora-rank and --lora-target-modules."
|
||||
)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
|
||||
"""
|
||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
||||
@@ -214,7 +237,7 @@ class LoRAManager:
|
||||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||
if lora_path is not None:
|
||||
lora = self.loras[lora_path]
|
||||
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
|
||||
lora_ranks[weight_indices[i]] = lora.config.r
|
||||
scalings[weight_indices[i]] = lora.scaling
|
||||
|
||||
# Use pinned memory to avoid synchronizations during host-to-device transfer
|
||||
@@ -319,7 +342,7 @@ class LoRAManager:
|
||||
)
|
||||
else:
|
||||
weight_name = get_weight_name(
|
||||
module_name, self.lora_weight_names, LoRAType.LORA_A
|
||||
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
|
||||
)
|
||||
module.set_lora_info(
|
||||
self.memory_pool.get_tensor(
|
||||
@@ -351,58 +374,67 @@ class LoRAManager:
|
||||
i: {} for i in range(self.base_hf_config.num_hidden_layers)
|
||||
}
|
||||
|
||||
# Initialize memory pool
|
||||
self.memory_pool = LoRAMemoryPool(
|
||||
self.base_hf_config,
|
||||
self.max_loras_per_batch,
|
||||
self.dtype,
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
)
|
||||
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
|
||||
# It is initialized lazily when the first LoRA adapter is loaded.
|
||||
self.memory_pool: Optional[LoRAMemoryPool] = None
|
||||
|
||||
def update_state_from_configs(self):
|
||||
"""
|
||||
Update the internal state of the LoRAManager based on the current `self.configs`. This method
|
||||
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
|
||||
|
||||
This includes:
|
||||
- Initializing LoRA adapters if they are not already loaded.
|
||||
- Collecting all LoRA weight names based on the currently loaded adapters.
|
||||
- Lazily monkey-patching the base model to use LoRA layers where applicable.
|
||||
- Preparing the GPU buffer pool for active LoRA weights.
|
||||
"""
|
||||
|
||||
# Target module names in huggingface lora configs.
|
||||
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
||||
hf_target_module_names: Set[str] = set()
|
||||
for config in self.configs.values():
|
||||
hf_target_module_names.update(config.target_modules)
|
||||
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
||||
|
||||
# Loads / unloads LoRA adapters based on the latest configs.
|
||||
self.update_lora_adapters()
|
||||
# Apply the latest LoRA configurations to the internal state for inferencing.
|
||||
self.apply_lora_configs()
|
||||
|
||||
# Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
|
||||
#
|
||||
# Please note that the following update operations are "monotonic" by design, meaning that we update
|
||||
# multiple places to support the new weight names when the first adapter targeting such weight names
|
||||
# is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
|
||||
# even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
|
||||
# list of LoRA weight names is expected to be extremely finite and stable.
|
||||
self.update_lora_weight_names(hf_target_module_names)
|
||||
self.update_lora_modules(hf_target_module_names)
|
||||
self.update_memory_buffers(max_lora_dim)
|
||||
def apply_lora_configs(self):
|
||||
"""
|
||||
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
|
||||
|
||||
def update_lora_weight_names(self, hf_target_names: Set[str]):
|
||||
Notes:
|
||||
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
|
||||
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
|
||||
LoRA backend deprecation and (2) CUDA graph recapture support. We aim to complete this work in
|
||||
early CY25H2.
|
||||
"""
|
||||
|
||||
if self.memory_pool is None:
|
||||
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
|
||||
if self.target_modules is None:
|
||||
self.target_modules = set()
|
||||
for config in self.configs.values():
|
||||
self.target_modules.update(config.target_modules)
|
||||
|
||||
if self.max_lora_rank is None:
|
||||
self.max_lora_rank = max(
|
||||
[x.hf_config["r"] for x in self.configs.values()],
|
||||
default=0,
|
||||
)
|
||||
|
||||
self.update_lora_weight_names()
|
||||
self.update_lora_modules()
|
||||
self.update_memory_buffers()
|
||||
else:
|
||||
# No-op if the memory pool can support the current LoRA configurations.
|
||||
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
|
||||
# module is changed once FlashInfer backend is deprecated.
|
||||
assert self.memory_pool.can_support(self.configs.values()), (
|
||||
"LoRA memory pool cannot support the current LoRA configuration. "
|
||||
"This should never happen as we should have validated adapter compatibility. "
|
||||
"Please create a GitHub issue to report this."
|
||||
)
|
||||
|
||||
def update_lora_weight_names(self):
|
||||
"""
|
||||
Add new LoRA weight names if needed based on the current `self.configs`.
|
||||
"""
|
||||
|
||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||
for module in hf_target_names:
|
||||
lora_A, lora_B = get_normalized_lora_weight_names(module)
|
||||
self.lora_weight_names[0].update(lora_A)
|
||||
self.lora_weight_names[1].update(lora_B)
|
||||
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
|
||||
self.lora_weight_names[0].update(lora_A)
|
||||
self.lora_weight_names[1].update(lora_B)
|
||||
|
||||
def update_lora_adapters(self):
|
||||
"""
|
||||
@@ -434,21 +466,23 @@ class LoRAManager:
|
||||
# Additional checks for flashinfer backend
|
||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
||||
if self.lora_backend == "flashinfer":
|
||||
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
|
||||
lora_dims = set(x.r for x in self.configs.values())
|
||||
scalings = set(x.scaling for x in self.loras.values())
|
||||
assert (
|
||||
len(lora_dims) == 1 and len(scalings) == 1
|
||||
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
|
||||
|
||||
def update_memory_buffers(self, max_lora_dim: int):
|
||||
"""
|
||||
Update the LoRA memory pool buffers based on the current LoRA configurations and update
|
||||
LoRA modules to use the new buffers. This method should be called after the LoRA configurations
|
||||
are set or updated.
|
||||
"""
|
||||
|
||||
self.memory_pool.init_buffers(
|
||||
self.lora_weight_names, self.base_model, max_lora_dim
|
||||
def update_memory_buffers(self):
|
||||
"""(Re)initialize the LoRA memory pool based on the current configurations."""
|
||||
self.memory_pool = LoRAMemoryPool(
|
||||
base_hf_config=self.base_hf_config,
|
||||
max_loras_per_batch=self.max_loras_per_batch,
|
||||
dtype=self.dtype,
|
||||
tp_size=self.tp_size,
|
||||
tp_rank=self.tp_rank,
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
lora_weight_names=self.lora_weight_names,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
|
||||
def set_lora_module(self, module_name, module):
|
||||
@@ -456,11 +490,11 @@ class LoRAManager:
|
||||
replace_submodule(self.base_model, module_name, lora_module)
|
||||
return lora_module
|
||||
|
||||
def update_lora_modules(self, hf_target_names: Set[str]):
|
||||
def update_lora_modules(self):
|
||||
# Target module names of customized layers defined in python/sglang/srt/layers
|
||||
# e.g., {"qkv_proj", "o_proj"}
|
||||
customized_target_names = get_customized_names_from_hf_names(
|
||||
hf_target_names, self.base_model
|
||||
self.target_modules, self.base_model
|
||||
)
|
||||
|
||||
for module_name, module in self.base_model.named_modules():
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Callable, Dict, List, Optional, Set, Tuple
|
||||
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
|
||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
||||
from sglang.srt.lora.lora import LoRAAdapter
|
||||
from sglang.srt.lora.lora_config import LoRAConfig
|
||||
from sglang.srt.lora.utils import (
|
||||
ROW_PARALLELISM_LINEAR_LORA_NAMES,
|
||||
LoRAType,
|
||||
get_hidden_dim,
|
||||
get_normalized_lora_weight_names,
|
||||
get_stacked_multiply,
|
||||
get_weight_name,
|
||||
)
|
||||
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
|
||||
dtype: torch.dtype,
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
max_lora_rank: int,
|
||||
lora_weight_names: Tuple[Set[str], Set[str]],
|
||||
base_model: torch.nn.Module,
|
||||
):
|
||||
self.base_hf_config: AutoConfig = base_hf_config
|
||||
self.num_layer: int = base_hf_config.num_hidden_layers
|
||||
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
|
||||
self.dtype: torch.dtype = dtype
|
||||
self.tp_size: int = tp_size
|
||||
self.tp_rank: int = tp_rank
|
||||
self.max_lora_rank: int = max_lora_rank
|
||||
|
||||
# lora weight names for LoRA A and B respectively.
|
||||
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
|
||||
|
||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||
# A_buffer contains num_layer number of row-major tensors with shape
|
||||
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
|
||||
# Here we don't initialize to None since None is a valid uid
|
||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||
|
||||
self.init_buffers(base_model)
|
||||
|
||||
def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
|
||||
"""
|
||||
Check if the memory pool can support the given LoRA adapters.
|
||||
"""
|
||||
|
||||
def _can_support(config: LoRAConfig) -> bool:
|
||||
"""
|
||||
Check if the memory pool can support a single LoRA adapter.
|
||||
"""
|
||||
if config.r > self.max_lora_rank:
|
||||
return False
|
||||
weights_a, weights_b = get_normalized_lora_weight_names(
|
||||
config.target_modules
|
||||
)
|
||||
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
|
||||
self.lora_weight_names[1]
|
||||
)
|
||||
|
||||
if isinstance(config, LoRAConfig):
|
||||
return _can_support(config)
|
||||
else:
|
||||
return all(_can_support(x) for x in config)
|
||||
|
||||
def get_lora_A_shape(
|
||||
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
||||
) -> Tuple[int]:
|
||||
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
|
||||
max_lora_dim,
|
||||
)
|
||||
|
||||
def init_buffers(
|
||||
self,
|
||||
lora_weight_names: Tuple[Set[str]],
|
||||
base_model: torch.nn.Module,
|
||||
max_lora_dim: int,
|
||||
):
|
||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
||||
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
|
||||
def init_buffers(self, base_model: torch.nn.Module):
|
||||
device = next(base_model.parameters()).device
|
||||
|
||||
def update_buffer(
|
||||
def init_buffer(
|
||||
buffer: Dict[str, List[torch.Tensor]],
|
||||
lora_weight_names: Set[str],
|
||||
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
||||
):
|
||||
new_weight_names = lora_weight_names - buffer.keys()
|
||||
for module_name in new_weight_names:
|
||||
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
|
||||
for module_name in lora_weight_names:
|
||||
lora_shape = get_lora_shape_fn(
|
||||
module_name, base_model, self.max_lora_rank
|
||||
)
|
||||
buffer[module_name] = [
|
||||
torch.empty(
|
||||
lora_shape,
|
||||
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
|
||||
for _ in range(self.num_layer)
|
||||
]
|
||||
|
||||
update_buffer(
|
||||
init_buffer(
|
||||
self.A_buffer,
|
||||
lora_weight_names[0],
|
||||
self.lora_weight_names[0],
|
||||
self.get_lora_A_shape,
|
||||
)
|
||||
|
||||
update_buffer(
|
||||
init_buffer(
|
||||
self.B_buffer,
|
||||
lora_weight_names[1],
|
||||
self.lora_weight_names[1],
|
||||
self.get_lora_B_shape,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Set, Tuple
|
||||
from typing import Iterable, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -106,9 +106,11 @@ def get_hidden_dim(
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
def get_normalized_lora_weight_names(
|
||||
target_modules: Iterable[str],
|
||||
) -> Tuple[set[str], set[str]]:
|
||||
"""
|
||||
Mapping a target module name to names of the normalized LoRA weights.
|
||||
Map a list of target module names to the names of the normalized LoRA weights.
|
||||
The returned tuple contains (names for LoRA A, names for LoRA B).
|
||||
"""
|
||||
params_mapping = {
|
||||
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||
}
|
||||
stacked = params_mapping.get(name, ([name], [name]))
|
||||
return stacked
|
||||
|
||||
result = (set(), set())
|
||||
for name in target_modules:
|
||||
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
|
||||
result[0].update(lora_a)
|
||||
result[1].update(lora_b)
|
||||
return result
|
||||
|
||||
|
||||
def get_stacked_multiply(module_name: str) -> int:
|
||||
|
||||
@@ -891,6 +891,8 @@ class ModelRunner:
|
||||
lora_backend=self.server_args.lora_backend,
|
||||
tp_size=self.tp_size,
|
||||
tp_rank=self.tp_rank,
|
||||
max_lora_rank=self.server_args.max_lora_rank,
|
||||
target_modules=self.server_args.lora_target_modules,
|
||||
)
|
||||
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
|
||||
if result.success:
|
||||
|
||||
@@ -134,6 +134,8 @@ class ServerArgs:
|
||||
preferred_sampling_params: Optional[str] = None
|
||||
|
||||
# LoRA
|
||||
max_lora_rank: Optional[int] = None
|
||||
lora_target_modules: Optional[List[str]] = None
|
||||
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
|
||||
max_loras_per_batch: int = 8
|
||||
lora_backend: str = "triton"
|
||||
@@ -1129,6 +1131,28 @@ class ServerArgs:
|
||||
)
|
||||
|
||||
# LoRA
|
||||
parser.add_argument(
|
||||
"--max-lora-rank",
|
||||
default=ServerArgs.max_lora_rank,
|
||||
type=int,
|
||||
help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-target-modules",
|
||||
type=str,
|
||||
choices=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
nargs="*",
|
||||
default=None,
|
||||
help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-paths",
|
||||
type=str,
|
||||
|
||||
@@ -505,6 +505,8 @@ class SRTRunner:
|
||||
torchao_config: Optional[str] = None,
|
||||
cuda_graph_max_bs: int = 4,
|
||||
sleep_on_idle=False,
|
||||
max_lora_rank: Optional[int] = None,
|
||||
lora_target_modules: Optional[List[str]] = None,
|
||||
):
|
||||
self.model_type = model_type
|
||||
self.is_generation = model_type == "generation"
|
||||
@@ -543,6 +545,8 @@ class SRTRunner:
|
||||
cuda_graph_max_bs=cuda_graph_max_bs,
|
||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||
sleep_on_idle=sleep_on_idle,
|
||||
max_lora_rank=max_lora_rank,
|
||||
lora_target_modules=lora_target_modules,
|
||||
**spec_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ import multiprocessing as mp
|
||||
import unittest
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, List, Optional, Union
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
|
||||
import requests
|
||||
import torch
|
||||
@@ -27,6 +27,7 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
is_in_ci,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
@@ -45,24 +46,28 @@ class OperationType(Enum):
|
||||
LOAD = "load"
|
||||
UNLOAD = "unload"
|
||||
FORWARD = "forward"
|
||||
EXPECT_ERROR = "expect_error"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Operation:
|
||||
# Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
|
||||
# Operation type, can be LOAD, UNLOAD, FORWARD
|
||||
type: OperationType
|
||||
# Data associated with the operation. Exact type varies depending on the operation
|
||||
data: Optional[Any]
|
||||
# If the operation is expected to fail, this is the error message to expect
|
||||
expected_error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestCase:
|
||||
description: str
|
||||
base: str
|
||||
max_loras_per_batch: int
|
||||
all_adapters: List[str]
|
||||
initial_adapters: List[str]
|
||||
op_sequence: List[Operation]
|
||||
max_lora_rank: Optional[int] = None
|
||||
lora_target_modules: Optional[List] = None
|
||||
max_new_tokens: int = 32
|
||||
|
||||
|
||||
@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
|
||||
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
||||
|
||||
|
||||
TEST_CASES = [
|
||||
# basic test, no eviction
|
||||
BASIC_TESTS = [
|
||||
TestCase(
|
||||
description="dynamic lora update with initial lora_paths",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=3,
|
||||
all_adapters=[
|
||||
@@ -89,20 +94,16 @@ TEST_CASES = [
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
"not loaded",
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
"not loaded",
|
||||
),
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
@@ -127,11 +128,9 @@ TEST_CASES = [
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
"not loaded",
|
||||
),
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
@@ -147,13 +146,11 @@ TEST_CASES = [
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
"not loaded",
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
@@ -174,8 +171,8 @@ TEST_CASES = [
|
||||
),
|
||||
],
|
||||
),
|
||||
# Eviction
|
||||
TestCase(
|
||||
description="dynamic lora update with evictions",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=1,
|
||||
all_adapters=[
|
||||
@@ -190,20 +187,16 @@ TEST_CASES = [
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
"not loaded",
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
"not loaded",
|
||||
),
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
@@ -214,11 +207,9 @@ TEST_CASES = [
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.EXPECT_ERROR,
|
||||
data=(
|
||||
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
"not loaded",
|
||||
),
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
@@ -263,6 +254,253 @@ TEST_CASES = [
|
||||
],
|
||||
),
|
||||
]
|
||||
TARGET_MODULE_TESTS = [
|
||||
TestCase(
|
||||
description="Test explicitly specified lora-target-modules.",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=3,
|
||||
lora_target_modules=[
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
"o_proj",
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
"down_proj",
|
||||
],
|
||||
max_lora_rank=64,
|
||||
all_adapters=[
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
||||
],
|
||||
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
TestCase(
|
||||
description="Test inferred lora-target-modules - start with larger adapter",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=3,
|
||||
max_lora_rank=64,
|
||||
all_adapters=[
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
||||
],
|
||||
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
TestCase(
|
||||
description="Test inferred lora-target-modules - start with smaller adapter",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=3,
|
||||
max_lora_rank=64,
|
||||
all_adapters=[
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
||||
],
|
||||
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
expected_error="updating LoRA shapes",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
MAX_LORA_RANK_TESTS = [
|
||||
TestCase(
|
||||
description="Test explicitly specified max-lora-rank.",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=3,
|
||||
max_lora_rank=32,
|
||||
all_adapters=[
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
|
||||
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
|
||||
],
|
||||
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
expected_error="updating LoRA shapes",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
TestCase(
|
||||
description="Test implicitly inferred max-lora-rank",
|
||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||
max_loras_per_batch=3,
|
||||
all_adapters=[
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
|
||||
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
|
||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
|
||||
],
|
||||
initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
|
||||
op_sequence=[
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||
expected_error="updating LoRA shapes",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||
expected_error="not loaded",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.LOAD,
|
||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
),
|
||||
),
|
||||
Operation(
|
||||
type=OperationType.FORWARD,
|
||||
data=create_batch_data(
|
||||
[
|
||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||
None,
|
||||
]
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
|
||||
|
||||
|
||||
class LoRAUpdateTestSessionMode(Enum):
|
||||
@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase:
|
||||
testcase: Optional[TestCase],
|
||||
model_path: str,
|
||||
lora_paths: list[str],
|
||||
max_loras_per_batch: int = 1,
|
||||
max_loras_per_batch: int,
|
||||
max_lora_rank: Optional[int],
|
||||
lora_target_modules: Optional[List[str]] = None,
|
||||
lora_backend: str = "triton",
|
||||
disable_cuda_graph: bool = False,
|
||||
cuda_graph_max_bs: int = 4,
|
||||
@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase:
|
||||
self.testcase = testcase
|
||||
self.model_path = model_path
|
||||
self.lora_paths = lora_paths
|
||||
self.max_lora_rank = max_lora_rank
|
||||
self.lora_target_modules = lora_target_modules
|
||||
self.max_loras_per_batch = max_loras_per_batch
|
||||
self.lora_backend = lora_backend
|
||||
self.disable_cuda_graph = disable_cuda_graph
|
||||
@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase:
|
||||
# Don't suppress exceptions by default
|
||||
return False
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
|
||||
def load_lora_adapter(
|
||||
self,
|
||||
lora_name: str,
|
||||
lora_path: Optional[str] = None,
|
||||
expected_error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load a LoRA adapter by name and path.
|
||||
"""
|
||||
@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase:
|
||||
prompts: List[str],
|
||||
lora_paths: List[str],
|
||||
max_new_tokens: int = 32,
|
||||
expected_error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||
@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
model_path=self.model_path,
|
||||
model_type="generation",
|
||||
lora_paths=self.lora_paths,
|
||||
max_lora_rank=self.max_lora_rank,
|
||||
lora_target_modules=self.lora_target_modules,
|
||||
lora_backend=self.lora_backend,
|
||||
torch_dtype=torch.float16,
|
||||
mem_fraction_static=MEM_FRACTION_STATIC,
|
||||
@@ -357,24 +607,32 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
# don't suppress exceptions
|
||||
return False
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
|
||||
def load_lora_adapter(
|
||||
self,
|
||||
lora_name: str,
|
||||
lora_path: Optional[str] = None,
|
||||
expected_error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load a LoRA adapter by name and path.
|
||||
"""
|
||||
if lora_path is None:
|
||||
lora_path = lora_name
|
||||
|
||||
self.expected_adapters.add(lora_name)
|
||||
|
||||
response = self.handle.load_lora_adapter(
|
||||
lora_name=lora_name,
|
||||
lora_path=lora_path,
|
||||
)
|
||||
self.testcase.assertTrue(response.success)
|
||||
loaded_adapters = set(response.loaded_adapters)
|
||||
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
if expected_error:
|
||||
self.testcase.assertFalse(response.success)
|
||||
self.testcase.assertIn(expected_error, response.error_message)
|
||||
print(f"Received error as expected: {response.error_message}")
|
||||
else:
|
||||
self.expected_adapters.add(lora_name)
|
||||
self.testcase.assertTrue(response.success)
|
||||
loaded_adapters = set(response.loaded_adapters)
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
"""
|
||||
@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
||||
prompts: List[str],
|
||||
lora_paths: List[str],
|
||||
max_new_tokens: int = 32,
|
||||
expected_error: str = None,
|
||||
expected_error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||
@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
]
|
||||
if self.disable_cuda_graph:
|
||||
other_args.append("--disable-cuda-graph")
|
||||
if self.max_lora_rank is not None:
|
||||
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
|
||||
if self.lora_target_modules is not None:
|
||||
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
|
||||
|
||||
# launch external server
|
||||
self.handle = popen_launch_server(
|
||||
@@ -464,24 +726,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
# don't suppress exceptions
|
||||
return False
|
||||
|
||||
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
|
||||
def load_lora_adapter(
|
||||
self,
|
||||
lora_name: str,
|
||||
lora_path: Optional[str] = None,
|
||||
expected_error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Load a LoRA adapter by name and path.
|
||||
"""
|
||||
if lora_path is None:
|
||||
lora_path = lora_name
|
||||
|
||||
self.expected_adapters.add(lora_name)
|
||||
|
||||
response = requests.post(
|
||||
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
|
||||
json={"lora_name": lora_name, "lora_path": lora_path},
|
||||
)
|
||||
self.testcase.assertTrue(response.ok)
|
||||
loaded_adapters = set(response.json()["loaded_adapters"])
|
||||
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
if expected_error:
|
||||
self.testcase.assertEqual(response.status_code, 400)
|
||||
self.testcase.assertIn(expected_error, response.text)
|
||||
print(f"Received error as expected: {response.text}")
|
||||
else:
|
||||
self.expected_adapters.add(lora_name)
|
||||
self.testcase.assertTrue(response.ok)
|
||||
loaded_adapters = set(response.json()["loaded_adapters"])
|
||||
print(f"loaded_adapters: {loaded_adapters}")
|
||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||
|
||||
def unload_lora_adapter(self, lora_name: str):
|
||||
"""
|
||||
@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
prompts: List[str],
|
||||
lora_paths: List[str],
|
||||
max_new_tokens: int = 32,
|
||||
expected_error: str = None,
|
||||
expected_error: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||
@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
||||
|
||||
# Factory function to create the appropriate LoRA test session based on mode
|
||||
def LoRAUpdateTestSession(
|
||||
*,
|
||||
testcase: Optional[TestCase],
|
||||
mode: LoRAUpdateTestSessionMode,
|
||||
model_path: str,
|
||||
lora_paths: list[str],
|
||||
max_loras_per_batch: int = 1,
|
||||
lora_backend: str = "triton",
|
||||
disable_cuda_graph: bool = False,
|
||||
cuda_graph_max_bs: int = 4,
|
||||
**kwargs: Any,
|
||||
):
|
||||
common_kwargs = {
|
||||
"testcase": testcase,
|
||||
"model_path": model_path,
|
||||
"lora_paths": lora_paths,
|
||||
"max_loras_per_batch": max_loras_per_batch,
|
||||
"lora_backend": lora_backend,
|
||||
"disable_cuda_graph": disable_cuda_graph,
|
||||
"cuda_graph_max_bs": cuda_graph_max_bs,
|
||||
}
|
||||
|
||||
if mode == LoRAUpdateTestSessionMode.ENGINE:
|
||||
return LoRAUpdateEngineTestSession(**common_kwargs)
|
||||
return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
|
||||
elif mode == LoRAUpdateTestSessionMode.SERVER:
|
||||
return LoRAUpdateServerTestSession(**common_kwargs)
|
||||
return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized mode: {mode!r}")
|
||||
|
||||
@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
initial_adapters: List[str],
|
||||
max_loras_per_batch: int,
|
||||
op_sequence: List[Operation],
|
||||
max_lora_rank: Optional[int] = None,
|
||||
lora_target_modules: Optional[List[str]] = None,
|
||||
max_new_tokens: int = 32,
|
||||
) -> List[tuple]:
|
||||
"""
|
||||
@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
model_path=base,
|
||||
lora_paths=initial_adapters,
|
||||
max_loras_per_batch=max_loras_per_batch,
|
||||
max_lora_rank=max_lora_rank,
|
||||
lora_target_modules=lora_target_modules,
|
||||
) as session:
|
||||
for op in op_sequence:
|
||||
op_type = op.type
|
||||
data = op.data
|
||||
expected_error = op.expected_error
|
||||
print("-" * 100)
|
||||
print(
|
||||
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
|
||||
@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
result = session.load_lora_adapter(
|
||||
lora_name=data,
|
||||
lora_path=data,
|
||||
expected_error=expected_error,
|
||||
)
|
||||
elif op_type == OperationType.UNLOAD:
|
||||
result = session.unload_lora_adapter(
|
||||
@@ -615,91 +875,105 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
||||
)
|
||||
elif op_type == OperationType.FORWARD:
|
||||
prompts, adapters = zip(*data)
|
||||
result = session.forward(
|
||||
prompts=list(prompts),
|
||||
lora_paths=list(adapters),
|
||||
max_new_tokens=max_new_tokens,
|
||||
)
|
||||
forward_outputs.append(result)
|
||||
elif op_type == OperationType.EXPECT_ERROR:
|
||||
input_data, expected_error = data
|
||||
prompts, adapters = zip(*input_data)
|
||||
result = session.forward(
|
||||
prompts=list(prompts),
|
||||
lora_paths=list(adapters),
|
||||
max_new_tokens=max_new_tokens,
|
||||
expected_error=expected_error,
|
||||
)
|
||||
if not expected_error:
|
||||
forward_outputs.append(result)
|
||||
|
||||
return forward_outputs
|
||||
|
||||
def test_dynamic_adapter_updates(self):
|
||||
for case_idx, test_case in enumerate(TEST_CASES, start=1):
|
||||
for mode in [
|
||||
LoRAUpdateTestSessionMode.ENGINE,
|
||||
LoRAUpdateTestSessionMode.SERVER,
|
||||
]:
|
||||
print("=" * 100)
|
||||
print(f"Starting test case {case_idx} in {mode.value} mode.")
|
||||
print("=" * 100)
|
||||
def _run_dynamic_adapter_updates(
|
||||
self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
|
||||
):
|
||||
for case_idx, test_case in enumerate(test_cases, start=1):
|
||||
print("=" * 100)
|
||||
print(
|
||||
f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
|
||||
)
|
||||
print("=" * 100)
|
||||
|
||||
print(
|
||||
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
|
||||
)
|
||||
# Test dynamic loading of adapters
|
||||
# TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
|
||||
# we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
|
||||
dynamic_output = self._run_operation_sequence(
|
||||
mode=mode,
|
||||
initial_adapters=test_case.initial_adapters,
|
||||
base=test_case.base,
|
||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||
op_sequence=test_case.op_sequence,
|
||||
max_new_tokens=test_case.max_new_tokens,
|
||||
)
|
||||
print(
|
||||
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
|
||||
)
|
||||
# Test dynamic loading of adapters
|
||||
dynamic_output = self._run_operation_sequence(
|
||||
mode=mode,
|
||||
initial_adapters=test_case.initial_adapters,
|
||||
base=test_case.base,
|
||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||
op_sequence=test_case.op_sequence,
|
||||
max_new_tokens=test_case.max_new_tokens,
|
||||
max_lora_rank=test_case.max_lora_rank,
|
||||
lora_target_modules=test_case.lora_target_modules,
|
||||
)
|
||||
|
||||
# static loading
|
||||
forward_ops = [
|
||||
x for x in test_case.op_sequence if x.type == OperationType.FORWARD
|
||||
]
|
||||
# static loading
|
||||
forward_ops = [
|
||||
x
|
||||
for x in test_case.op_sequence
|
||||
if x.type == OperationType.FORWARD and x.expected_error is None
|
||||
]
|
||||
|
||||
print("=" * 100)
|
||||
print(
|
||||
f"\n--- Running static pass with {len(forward_ops)} operations ---"
|
||||
)
|
||||
static_output = self._run_operation_sequence(
|
||||
mode=mode,
|
||||
initial_adapters=test_case.all_adapters,
|
||||
base=test_case.base,
|
||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||
op_sequence=forward_ops,
|
||||
max_new_tokens=test_case.max_new_tokens,
|
||||
)
|
||||
print("=" * 100)
|
||||
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
|
||||
static_output = self._run_operation_sequence(
|
||||
mode=mode,
|
||||
initial_adapters=test_case.all_adapters,
|
||||
base=test_case.base,
|
||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||
op_sequence=forward_ops,
|
||||
max_new_tokens=test_case.max_new_tokens,
|
||||
)
|
||||
|
||||
print(f"Dynamic output: {dynamic_output}")
|
||||
print(f"Static output: {static_output}")
|
||||
print("=" * 100)
|
||||
print(f"Dynamic output: {dynamic_output}")
|
||||
print(f"Static output: {static_output}")
|
||||
print("=" * 100)
|
||||
self.assertEqual(
|
||||
len(dynamic_output),
|
||||
len(static_output),
|
||||
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
|
||||
)
|
||||
for i, (dynamic, static) in enumerate(
|
||||
zip(dynamic_output, static_output), start=1
|
||||
):
|
||||
self.assertEqual(
|
||||
len(dynamic_output),
|
||||
len(static_output),
|
||||
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
|
||||
len(dynamic),
|
||||
len(static),
|
||||
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
|
||||
)
|
||||
for i, (dynamic, static) in enumerate(
|
||||
zip(dynamic_output, static_output), start=1
|
||||
):
|
||||
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
|
||||
d_out = d_out.strip()
|
||||
s_out = s_out.strip()
|
||||
self.assertEqual(
|
||||
len(dynamic),
|
||||
len(static),
|
||||
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
|
||||
d_out,
|
||||
s_out,
|
||||
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
|
||||
)
|
||||
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
|
||||
d_out = d_out.strip()
|
||||
s_out = s_out.strip()
|
||||
self.assertEqual(
|
||||
d_out,
|
||||
s_out,
|
||||
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
|
||||
)
|
||||
|
||||
def test_dynamic_lora_update_engine(self):
|
||||
"""
|
||||
Test dynamic LoRA updates in engine mode.
|
||||
"""
|
||||
test_cases = ALL_TESTS
|
||||
self._run_dynamic_adapter_updates(
|
||||
mode=LoRAUpdateTestSessionMode.ENGINE,
|
||||
test_cases=test_cases,
|
||||
)
|
||||
|
||||
def test_dynamic_lora_update_server(self):
|
||||
"""
|
||||
Test dynamic LoRA updates in server mode.
|
||||
"""
|
||||
# In CI, we only run the basic test cases to save time, as the engine test should be mostly sufficient for ensuring correctness.
|
||||
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
|
||||
|
||||
self._run_dynamic_adapter_updates(
|
||||
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,7 +17,7 @@ suites = {
|
||||
TestFile("models/lora/test_lora_backend.py", 99),
|
||||
TestFile("models/lora/test_multi_lora_backend.py", 60),
|
||||
TestFile("models/lora/test_lora_cuda_graph.py", 250),
|
||||
TestFile("models/lora/test_lora_update.py", 400),
|
||||
TestFile("models/lora/test_lora_update.py", 700),
|
||||
TestFile("models/test_embedding_models.py", 73),
|
||||
# TestFile("models/test_clip_models.py", 52),
|
||||
TestFile("models/test_encoder_embedding_models.py", 100),
|
||||
|
||||