Refactor dynamic LoRA update to fix incorrect handling of variant weight shapes (#7844)
This commit is contained in:
@@ -33,6 +33,10 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
"* `lora_backend`: The backend of running GEMM kernels for Lora modules. It can be one of `triton` or `flashinfer`, and set to `triton` by default. For better performance and stability, we recommend using the Triton LoRA backend. In the future, faster backend built upon Cutlass or Cuda kernels will be added.\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
||||||
|
"\n",
|
||||||
|
"* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n",
|
||||||
|
"\n",
|
||||||
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
|
"* `tp_size`: LoRA serving along with Tensor Parallelism is supported by SGLang. `tp_size` controls the number of GPUs for tensor parallelism. More details on the tensor sharding strategy can be found in [S-Lora](https://arxiv.org/pdf/2311.03285) paper.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
|
"From client side, the user needs to provide a list of strings as input batch, and a list of adaptor names that each input sequence corresponds to."
|
||||||
@@ -176,6 +180,241 @@
|
|||||||
"terminate_process(server_process)"
|
"terminate_process(server_process)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Dynamic LoRA loading"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Basic Usage\n",
|
||||||
|
"\n",
|
||||||
|
"Instead of specifying all adapters during server startup via `--lora-paths`. You can also load & unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` API.\n",
|
||||||
|
"\n",
|
||||||
|
"(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"server_process, port = launch_server_cmd(\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
|
" --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n",
|
||||||
|
" --cuda-graph-max-bs 2 \\\n",
|
||||||
|
" --max-loras-per-batch 2 --lora-backend triton \\\n",
|
||||||
|
" --disable-radix-cache\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||||
|
"wait_for_server(url)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/load_lora_adapter\",\n",
|
||||||
|
" json={\n",
|
||||||
|
" \"lora_name\": \"lora1\",\n",
|
||||||
|
" \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n",
|
||||||
|
" },\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"if response.status_code == 200:\n",
|
||||||
|
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"Failed to load LoRA adapter.\", response.json())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/generate\",\n",
|
||||||
|
" json={\n",
|
||||||
|
" \"text\": [\n",
|
||||||
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||||
|
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
|
||||||
|
" },\n",
|
||||||
|
")\n",
|
||||||
|
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
|
||||||
|
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/unload_lora_adapter\",\n",
|
||||||
|
" json={\n",
|
||||||
|
" \"lora_name\": \"lora0\",\n",
|
||||||
|
" },\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/load_lora_adapter\",\n",
|
||||||
|
" json={\n",
|
||||||
|
" \"lora_name\": \"lora2\",\n",
|
||||||
|
" \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n",
|
||||||
|
" },\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"if response.status_code == 200:\n",
|
||||||
|
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"Failed to load LoRA adapter.\", response.json())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/generate\",\n",
|
||||||
|
" json={\n",
|
||||||
|
" \"text\": [\n",
|
||||||
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||||
|
" \"lora_path\": [\"lora1\", \"lora2\"],\n",
|
||||||
|
" },\n",
|
||||||
|
")\n",
|
||||||
|
"print(f\"Output from lora1: {response.json()[0]['text']}\")\n",
|
||||||
|
"print(f\"Output from lora2: {response.json()[1]['text']}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(server_process)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"### Advanced: hosting adapters of different shapes\n",
|
||||||
|
"\n",
|
||||||
|
"In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n",
|
||||||
|
"\n",
|
||||||
|
"For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n",
|
||||||
|
"lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n",
|
||||||
|
"# We are adding it here just to demonstrate usage.\n",
|
||||||
|
"server_process, port = launch_server_cmd(\n",
|
||||||
|
" f\"\"\"\n",
|
||||||
|
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
|
" --lora-paths lora0={lora0} \\\n",
|
||||||
|
" --cuda-graph-max-bs 2 \\\n",
|
||||||
|
" --max-loras-per-batch 2 --lora-backend triton \\\n",
|
||||||
|
" --disable-radix-cache\n",
|
||||||
|
" --max-lora-rank 64\n",
|
||||||
|
" --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||||
|
"wait_for_server(url)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/load_lora_adapter\",\n",
|
||||||
|
" json={\n",
|
||||||
|
" \"lora_name\": \"lora1\",\n",
|
||||||
|
" \"lora_path\": lora1,\n",
|
||||||
|
" },\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"if response.status_code == 200:\n",
|
||||||
|
" print(\"LoRA adapter loaded successfully.\", response.json())\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(\"Failed to load LoRA adapter.\", response.json())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||||
|
"json_data = {\n",
|
||||||
|
" \"text\": [\n",
|
||||||
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
|
" \"AI is a field of computer science focused on\",\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||||
|
" # The first input uses lora0, and the second input uses lora1\n",
|
||||||
|
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
|
||||||
|
"}\n",
|
||||||
|
"response = requests.post(\n",
|
||||||
|
" url + \"/generate\",\n",
|
||||||
|
" json=json_data,\n",
|
||||||
|
")\n",
|
||||||
|
"print(f\"Output from lora0: {response.json()[0]['text']}\")\n",
|
||||||
|
"print(f\"Output from lora1: {response.json()[1]['text']}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(server_process)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
|
|||||||
@@ -167,6 +167,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
|
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
|
||||||
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
|
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
|
||||||
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
||||||
|
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
|
||||||
|
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. | None |
|
||||||
|
|
||||||
## Kernel backend
|
## Kernel backend
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
# and "Punica: Multi-Tenant LoRA Serving"
|
# and "Punica: Multi-Tenant LoRA Serving"
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Set, Tuple
|
from typing import Dict, Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -53,6 +53,8 @@ class LoRAManager:
|
|||||||
lora_backend: str = "triton",
|
lora_backend: str = "triton",
|
||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
|
max_lora_rank: Optional[int] = None,
|
||||||
|
target_modules: Optional[Iterable[str]] = None,
|
||||||
):
|
):
|
||||||
self.base_model: torch.nn.Module = base_model
|
self.base_model: torch.nn.Module = base_model
|
||||||
self.base_hf_config: AutoConfig = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
@@ -62,6 +64,10 @@ class LoRAManager:
|
|||||||
self.device: torch.device = next(self.base_model.parameters()).device
|
self.device: torch.device = next(self.base_model.parameters()).device
|
||||||
self.tp_size: int = tp_size
|
self.tp_size: int = tp_size
|
||||||
self.tp_rank: int = tp_rank
|
self.tp_rank: int = tp_rank
|
||||||
|
self.max_lora_rank: Optional[int] = max_lora_rank
|
||||||
|
self.target_modules: Optional[Set[str]] = (
|
||||||
|
set(target_modules) if target_modules else None
|
||||||
|
)
|
||||||
|
|
||||||
# LoRA backend for running sgemm kernels
|
# LoRA backend for running sgemm kernels
|
||||||
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
|
||||||
@@ -153,7 +159,9 @@ class LoRAManager:
|
|||||||
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
|
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.configs[lora_name] = LoRAConfig(lora_path)
|
new_adapter = LoRAConfig(lora_path)
|
||||||
|
self.validate_new_adapter(lora_name, new_adapter)
|
||||||
|
self.configs[lora_name] = new_adapter
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
success = False
|
success = False
|
||||||
error_message = (
|
error_message = (
|
||||||
@@ -168,6 +176,21 @@ class LoRAManager:
|
|||||||
error_message=error_message,
|
error_message=error_message,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
|
||||||
|
"""
|
||||||
|
Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
incompatible = self.memory_pool and not self.memory_pool.can_support(
|
||||||
|
lora_config
|
||||||
|
)
|
||||||
|
if incompatible:
|
||||||
|
raise ValueError(
|
||||||
|
f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration."
|
||||||
|
"We are still working on supporting dynamically updating LoRA shapes. If you expect to use adapters of different shapes, "
|
||||||
|
"You can specify expected configs via --max_lora_rank and --enable_lora_modules."
|
||||||
|
)
|
||||||
|
|
||||||
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
|
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
|
||||||
"""
|
"""
|
||||||
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
||||||
@@ -214,7 +237,7 @@ class LoRAManager:
|
|||||||
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
|
||||||
if lora_path is not None:
|
if lora_path is not None:
|
||||||
lora = self.loras[lora_path]
|
lora = self.loras[lora_path]
|
||||||
lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
|
lora_ranks[weight_indices[i]] = lora.config.r
|
||||||
scalings[weight_indices[i]] = lora.scaling
|
scalings[weight_indices[i]] = lora.scaling
|
||||||
|
|
||||||
# Use pinned memory to avoid synchronizations during host-to-device transfer
|
# Use pinned memory to avoid synchronizations during host-to-device transfer
|
||||||
@@ -319,7 +342,7 @@ class LoRAManager:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
weight_name = get_weight_name(
|
weight_name = get_weight_name(
|
||||||
module_name, self.lora_weight_names, LoRAType.LORA_A
|
module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
|
||||||
)
|
)
|
||||||
module.set_lora_info(
|
module.set_lora_info(
|
||||||
self.memory_pool.get_tensor(
|
self.memory_pool.get_tensor(
|
||||||
@@ -351,58 +374,67 @@ class LoRAManager:
|
|||||||
i: {} for i in range(self.base_hf_config.num_hidden_layers)
|
i: {} for i in range(self.base_hf_config.num_hidden_layers)
|
||||||
}
|
}
|
||||||
|
|
||||||
# Initialize memory pool
|
# The LoRA memory pool that manages the GPU buffers for active LoRA weights.
|
||||||
self.memory_pool = LoRAMemoryPool(
|
# It is initialized lazily when the first LoRA adapter is loaded.
|
||||||
self.base_hf_config,
|
self.memory_pool: Optional[LoRAMemoryPool] = None
|
||||||
self.max_loras_per_batch,
|
|
||||||
self.dtype,
|
|
||||||
self.tp_size,
|
|
||||||
self.tp_rank,
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_state_from_configs(self):
|
def update_state_from_configs(self):
|
||||||
"""
|
"""
|
||||||
Update the internal state of the LoRAManager based on the current `self.configs`. This method
|
Update the internal state of the LoRAManager based on the current `self.configs`. This method
|
||||||
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
|
should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
|
||||||
|
|
||||||
This includes:
|
|
||||||
- Initializing LoRA adapters if they are not already loaded.
|
|
||||||
- Collect all LoRA weight names based on the current loaded adapters.
|
|
||||||
- Lazily monkey-patching the base model to use LoRA layers where applicable.
|
|
||||||
- Preparing the GPU buffer pool for active LoRA weights.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Target module names in huggingface lora configs.
|
|
||||||
# e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
|
|
||||||
hf_target_module_names: Set[str] = set()
|
|
||||||
for config in self.configs.values():
|
|
||||||
hf_target_module_names.update(config.target_modules)
|
|
||||||
max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
|
|
||||||
|
|
||||||
# Loads / unloads LoRA adapters based on the latest configs.
|
# Loads / unloads LoRA adapters based on the latest configs.
|
||||||
self.update_lora_adapters()
|
self.update_lora_adapters()
|
||||||
|
# Apply the latest LoRA configurations to the internal state for inferencing.
|
||||||
|
self.apply_lora_configs()
|
||||||
|
|
||||||
# Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
|
def apply_lora_configs(self):
|
||||||
#
|
"""
|
||||||
# Please note that the following update operations are "monotonic" by design, meaning that we update
|
Apply the LoRA configurations to the base model and internal states of the LoRAManager for inferencing.
|
||||||
# multiple places to support the new weight names when the first adapter targeting such weight names
|
|
||||||
# is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
|
|
||||||
# even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
|
|
||||||
# list of LoRA weight names is expected to be extremely finite and stable.
|
|
||||||
self.update_lora_weight_names(hf_target_module_names)
|
|
||||||
self.update_lora_modules(hf_target_module_names)
|
|
||||||
self.update_memory_buffers(max_lora_dim)
|
|
||||||
|
|
||||||
def update_lora_weight_names(self, hf_target_names: Set[str]):
|
Notes:
|
||||||
|
- Currently, this method is effectively only invoked during the initialization phase of the LoRAManager as
|
||||||
|
we do not yet support dynamically updating adapter shape configs, which has a dependency on (1) FlashInfer
|
||||||
|
LoRA backend deprecation and (2) CUDA graph recapture support. We are targeting completing these work in
|
||||||
|
early CY25H2.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if self.memory_pool is None:
|
||||||
|
# Infer max_lora_rank and target_modules if not explicitly specified in server args.
|
||||||
|
if self.target_modules is None:
|
||||||
|
self.target_modules = set()
|
||||||
|
for config in self.configs.values():
|
||||||
|
self.target_modules.update(config.target_modules)
|
||||||
|
|
||||||
|
if self.max_lora_rank is None:
|
||||||
|
self.max_lora_rank = max(
|
||||||
|
[x.hf_config["r"] for x in self.configs.values()],
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.update_lora_weight_names()
|
||||||
|
self.update_lora_modules()
|
||||||
|
self.update_memory_buffers()
|
||||||
|
else:
|
||||||
|
# No-op if the memory pool can support the current LoRA configurations.
|
||||||
|
# TODO (lifuhuang): support reinitializing the memory pool when the maximum LoRA rank or target
|
||||||
|
# module is changed once FlashInfer backend is deprecated.
|
||||||
|
assert self.memory_pool.can_support(self.configs.values()), (
|
||||||
|
"LoRA memory pool cannot support the current LoRA configuration. "
|
||||||
|
"This should never happen as we should have validated adapter compatibility. "
|
||||||
|
"Please create a Github issue to report.",
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_lora_weight_names(self):
|
||||||
"""
|
"""
|
||||||
Add new LoRA weight names if needed based on the current `self.configs`.
|
Add new LoRA weight names if needed based on the current `self.configs`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Target lora weight names for lora_a and lora_b modules respectively.
|
# Target lora weight names for lora_a and lora_b modules respectively.
|
||||||
for module in hf_target_names:
|
lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
|
||||||
lora_A, lora_B = get_normalized_lora_weight_names(module)
|
self.lora_weight_names[0].update(lora_A)
|
||||||
self.lora_weight_names[0].update(lora_A)
|
self.lora_weight_names[1].update(lora_B)
|
||||||
self.lora_weight_names[1].update(lora_B)
|
|
||||||
|
|
||||||
def update_lora_adapters(self):
|
def update_lora_adapters(self):
|
||||||
"""
|
"""
|
||||||
@@ -434,21 +466,23 @@ class LoRAManager:
|
|||||||
# Additional checks for flashinfer backend
|
# Additional checks for flashinfer backend
|
||||||
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
# FIXME remove the restrictions after supporting multi-rank for flashinfer backend
|
||||||
if self.lora_backend == "flashinfer":
|
if self.lora_backend == "flashinfer":
|
||||||
lora_dims = set(x.hf_config["r"] for x in self.configs.values())
|
lora_dims = set(x.r for x in self.configs.values())
|
||||||
scalings = set(x.scaling for x in self.loras.values())
|
scalings = set(x.scaling for x in self.loras.values())
|
||||||
assert (
|
assert (
|
||||||
len(lora_dims) == 1 and len(scalings) == 1
|
len(lora_dims) == 1 and len(scalings) == 1
|
||||||
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
|
), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
|
||||||
|
|
||||||
def update_memory_buffers(self, max_lora_dim: int):
|
def update_memory_buffers(self):
|
||||||
"""
|
"""(Re)initialize the LoRA memory pool based on the current configurations."""
|
||||||
Update the LoRA memory pool buffers based on the current LoRA configurations and update
|
self.memory_pool = LoRAMemoryPool(
|
||||||
LoRA modules to use the new buffers. This method should be called after the LoRA configurations
|
base_hf_config=self.base_hf_config,
|
||||||
are set or updated.
|
max_loras_per_batch=self.max_loras_per_batch,
|
||||||
"""
|
dtype=self.dtype,
|
||||||
|
tp_size=self.tp_size,
|
||||||
self.memory_pool.init_buffers(
|
tp_rank=self.tp_rank,
|
||||||
self.lora_weight_names, self.base_model, max_lora_dim
|
max_lora_rank=self.max_lora_rank,
|
||||||
|
lora_weight_names=self.lora_weight_names,
|
||||||
|
base_model=self.base_model,
|
||||||
)
|
)
|
||||||
|
|
||||||
def set_lora_module(self, module_name, module):
|
def set_lora_module(self, module_name, module):
|
||||||
@@ -456,11 +490,11 @@ class LoRAManager:
|
|||||||
replace_submodule(self.base_model, module_name, lora_module)
|
replace_submodule(self.base_model, module_name, lora_module)
|
||||||
return lora_module
|
return lora_module
|
||||||
|
|
||||||
def update_lora_modules(self, hf_target_names: Set[str]):
|
def update_lora_modules(self):
|
||||||
# Target module names of customized layers defined in python/sglang/srt/layers
|
# Target module names of customized layers defined in python/sglang/srt/layers
|
||||||
# e.g., {"qkv_proj", "o_proj"}
|
# e.g., {"qkv_proj", "o_proj"}
|
||||||
customized_target_names = get_customized_names_from_hf_names(
|
customized_target_names = get_customized_names_from_hf_names(
|
||||||
hf_target_names, self.base_model
|
self.target_modules, self.base_model
|
||||||
)
|
)
|
||||||
|
|
||||||
for module_name, module in self.base_model.named_modules():
|
for module_name, module in self.base_model.named_modules():
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Callable, Dict, List, Optional, Set, Tuple
|
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -6,10 +6,12 @@ from sglang.srt.distributed import divide
|
|||||||
from sglang.srt.hf_transformers_utils import AutoConfig
|
from sglang.srt.hf_transformers_utils import AutoConfig
|
||||||
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
from sglang.srt.lora.layers import BaseLayerWithLoRA
|
||||||
from sglang.srt.lora.lora import LoRAAdapter
|
from sglang.srt.lora.lora import LoRAAdapter
|
||||||
|
from sglang.srt.lora.lora_config import LoRAConfig
|
||||||
from sglang.srt.lora.utils import (
|
from sglang.srt.lora.utils import (
|
||||||
ROW_PARALLELISM_LINEAR_LORA_NAMES,
|
ROW_PARALLELISM_LINEAR_LORA_NAMES,
|
||||||
LoRAType,
|
LoRAType,
|
||||||
get_hidden_dim,
|
get_hidden_dim,
|
||||||
|
get_normalized_lora_weight_names,
|
||||||
get_stacked_multiply,
|
get_stacked_multiply,
|
||||||
get_weight_name,
|
get_weight_name,
|
||||||
)
|
)
|
||||||
@@ -25,6 +27,9 @@ class LoRAMemoryPool:
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
tp_size: int,
|
tp_size: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
max_lora_rank: int,
|
||||||
|
lora_weight_names: Tuple[Set[str], Set[str]],
|
||||||
|
base_model: torch.nn.Module,
|
||||||
):
|
):
|
||||||
self.base_hf_config: AutoConfig = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
self.num_layer: int = base_hf_config.num_hidden_layers
|
self.num_layer: int = base_hf_config.num_hidden_layers
|
||||||
@@ -32,6 +37,10 @@ class LoRAMemoryPool:
|
|||||||
self.dtype: torch.dtype = dtype
|
self.dtype: torch.dtype = dtype
|
||||||
self.tp_size: int = tp_size
|
self.tp_size: int = tp_size
|
||||||
self.tp_rank: int = tp_rank
|
self.tp_rank: int = tp_rank
|
||||||
|
self.max_lora_rank: int = max_lora_rank
|
||||||
|
|
||||||
|
# lora weight names for LoRA A and B respectively.
|
||||||
|
self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
|
||||||
|
|
||||||
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
# Both A_buffer and B_buffer maps lora weight names to its buffer space.
|
||||||
# A_buffer contains num_layer number of row-major tensors with shape
|
# A_buffer contains num_layer number of row-major tensors with shape
|
||||||
@@ -49,6 +58,31 @@ class LoRAMemoryPool:
|
|||||||
# Here we don't initialize to None since None is a valid uid
|
# Here we don't initialize to None since None is a valid uid
|
||||||
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
|
||||||
|
|
||||||
|
self.init_buffers(base_model)
|
||||||
|
|
||||||
|
def can_support(self, config: Union[LoRAConfig, Iterable[LoRAConfig]]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the memory pool can support the given LoRA adapters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _can_support(config: LoRAConfig) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the memory pool can support a single LoRA adapter.
|
||||||
|
"""
|
||||||
|
if config.r > self.max_lora_rank:
|
||||||
|
return False
|
||||||
|
weights_a, weights_b = get_normalized_lora_weight_names(
|
||||||
|
config.target_modules
|
||||||
|
)
|
||||||
|
return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
|
||||||
|
self.lora_weight_names[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(config, LoRAConfig):
|
||||||
|
return _can_support(config)
|
||||||
|
else:
|
||||||
|
return all(_can_support(x) for x in config)
|
||||||
|
|
||||||
def get_lora_A_shape(
|
def get_lora_A_shape(
|
||||||
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
|
||||||
) -> Tuple[int]:
|
) -> Tuple[int]:
|
||||||
@@ -82,25 +116,18 @@ class LoRAMemoryPool:
|
|||||||
max_lora_dim,
|
max_lora_dim,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_buffers(
|
def init_buffers(self, base_model: torch.nn.Module):
|
||||||
self,
|
|
||||||
lora_weight_names: Tuple[Set[str]],
|
|
||||||
base_model: torch.nn.Module,
|
|
||||||
max_lora_dim: int,
|
|
||||||
):
|
|
||||||
# lora_weight_names is a set of name pairs indicating each pair of lora modules to load
|
|
||||||
# e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj"), ("o_proj", "o_proj")}
|
|
||||||
self.lora_weight_names: Tuple[Set[str]] = lora_weight_names
|
|
||||||
device = next(base_model.parameters()).device
|
device = next(base_model.parameters()).device
|
||||||
|
|
||||||
def update_buffer(
|
def init_buffer(
|
||||||
buffer: Dict[str, List[torch.Tensor]],
|
buffer: Dict[str, List[torch.Tensor]],
|
||||||
lora_weight_names: Set[str],
|
lora_weight_names: Set[str],
|
||||||
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
|
||||||
):
|
):
|
||||||
new_weight_names = lora_weight_names - buffer.keys()
|
for module_name in lora_weight_names:
|
||||||
for module_name in new_weight_names:
|
lora_shape = get_lora_shape_fn(
|
||||||
lora_shape = get_lora_shape_fn(module_name, base_model, max_lora_dim)
|
module_name, base_model, self.max_lora_rank
|
||||||
|
)
|
||||||
buffer[module_name] = [
|
buffer[module_name] = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
lora_shape,
|
lora_shape,
|
||||||
@@ -110,15 +137,15 @@ class LoRAMemoryPool:
|
|||||||
for _ in range(self.num_layer)
|
for _ in range(self.num_layer)
|
||||||
]
|
]
|
||||||
|
|
||||||
update_buffer(
|
init_buffer(
|
||||||
self.A_buffer,
|
self.A_buffer,
|
||||||
lora_weight_names[0],
|
self.lora_weight_names[0],
|
||||||
self.get_lora_A_shape,
|
self.get_lora_A_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
update_buffer(
|
init_buffer(
|
||||||
self.B_buffer,
|
self.B_buffer,
|
||||||
lora_weight_names[1],
|
self.lora_weight_names[1],
|
||||||
self.get_lora_B_shape,
|
self.get_lora_B_shape,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Optional, Set, Tuple
|
from typing import Iterable, Optional, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -106,9 +106,11 @@ def get_hidden_dim(
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
def get_normalized_lora_weight_names(
|
||||||
|
target_modules: Iterable[str],
|
||||||
|
) -> Tuple[set[str], set[str]]:
|
||||||
"""
|
"""
|
||||||
Mapping a target module name to names of the normalized LoRA weights.
|
Mapping a list of target module name to names of the normalized LoRA weights.
|
||||||
Returned tuple contains (name for Lora A, name for Lora B)
|
Returned tuple contains (name for Lora A, name for Lora B)
|
||||||
"""
|
"""
|
||||||
params_mapping = {
|
params_mapping = {
|
||||||
@@ -120,8 +122,13 @@ def get_normalized_lora_weight_names(name: str) -> Tuple[List[str], List[str]]:
|
|||||||
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
"qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
|
||||||
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
"gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
|
||||||
}
|
}
|
||||||
stacked = params_mapping.get(name, ([name], [name]))
|
|
||||||
return stacked
|
result = (set(), set())
|
||||||
|
for name in target_modules:
|
||||||
|
lora_a, lora_b = params_mapping.get(name, ([name], [name]))
|
||||||
|
result[0].update(lora_a)
|
||||||
|
result[1].update(lora_b)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_stacked_multiply(module_name: str) -> int:
|
def get_stacked_multiply(module_name: str) -> int:
|
||||||
|
|||||||
@@ -891,6 +891,8 @@ class ModelRunner:
|
|||||||
lora_backend=self.server_args.lora_backend,
|
lora_backend=self.server_args.lora_backend,
|
||||||
tp_size=self.tp_size,
|
tp_size=self.tp_size,
|
||||||
tp_rank=self.tp_rank,
|
tp_rank=self.tp_rank,
|
||||||
|
max_lora_rank=self.server_args.max_lora_rank,
|
||||||
|
target_modules=self.server_args.lora_target_modules,
|
||||||
)
|
)
|
||||||
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
|
result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
|
||||||
if result.success:
|
if result.success:
|
||||||
|
|||||||
@@ -134,6 +134,8 @@ class ServerArgs:
|
|||||||
preferred_sampling_params: Optional[str] = None
|
preferred_sampling_params: Optional[str] = None
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
|
max_lora_rank: Optional[int] = None
|
||||||
|
lora_target_modules: Optional[List[str]] = None
|
||||||
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
|
lora_paths: Optional[Union[dict[str, str], List[str]]] = None
|
||||||
max_loras_per_batch: int = 8
|
max_loras_per_batch: int = 8
|
||||||
lora_backend: str = "triton"
|
lora_backend: str = "triton"
|
||||||
@@ -1129,6 +1131,28 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-lora-rank",
|
||||||
|
default=ServerArgs.max_lora_rank,
|
||||||
|
type=int,
|
||||||
|
help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--lora-target-modules",
|
||||||
|
type=str,
|
||||||
|
choices=[
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
"down_proj",
|
||||||
|
],
|
||||||
|
nargs="*",
|
||||||
|
default=None,
|
||||||
|
help="The union set of all target modules where LoRA should be applied. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lora-paths",
|
"--lora-paths",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
@@ -505,6 +505,8 @@ class SRTRunner:
|
|||||||
torchao_config: Optional[str] = None,
|
torchao_config: Optional[str] = None,
|
||||||
cuda_graph_max_bs: int = 4,
|
cuda_graph_max_bs: int = 4,
|
||||||
sleep_on_idle=False,
|
sleep_on_idle=False,
|
||||||
|
max_lora_rank: Optional[int] = None,
|
||||||
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.is_generation = model_type == "generation"
|
self.is_generation = model_type == "generation"
|
||||||
@@ -543,6 +545,8 @@ class SRTRunner:
|
|||||||
cuda_graph_max_bs=cuda_graph_max_bs,
|
cuda_graph_max_bs=cuda_graph_max_bs,
|
||||||
disable_custom_all_reduce=disable_custom_all_reduce,
|
disable_custom_all_reduce=disable_custom_all_reduce,
|
||||||
sleep_on_idle=sleep_on_idle,
|
sleep_on_idle=sleep_on_idle,
|
||||||
|
max_lora_rank=max_lora_rank,
|
||||||
|
lora_target_modules=lora_target_modules,
|
||||||
**spec_kwargs,
|
**spec_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import multiprocessing as mp
|
|||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
@@ -27,6 +27,7 @@ from sglang.test.test_utils import (
|
|||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
|
is_in_ci,
|
||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,24 +46,28 @@ class OperationType(Enum):
|
|||||||
LOAD = "load"
|
LOAD = "load"
|
||||||
UNLOAD = "unload"
|
UNLOAD = "unload"
|
||||||
FORWARD = "forward"
|
FORWARD = "forward"
|
||||||
EXPECT_ERROR = "expect_error"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Operation:
|
class Operation:
|
||||||
# Operation type, can be LOAD, UNLOAD, FORWARD, or EXPECT_ERROR
|
# Operation type, can be LOAD, UNLOAD, FORWARD
|
||||||
type: OperationType
|
type: OperationType
|
||||||
# Data associated with the operation. Exact type varies depending on the operation
|
# Data associated with the operation. Exact type varies depending on the operation
|
||||||
data: Optional[Any]
|
data: Optional[Any]
|
||||||
|
# If the operation is expected to fail, this is the error message to expect
|
||||||
|
expected_error: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TestCase:
|
class TestCase:
|
||||||
|
description: str
|
||||||
base: str
|
base: str
|
||||||
max_loras_per_batch: int
|
max_loras_per_batch: int
|
||||||
all_adapters: List[str]
|
all_adapters: List[str]
|
||||||
initial_adapters: List[str]
|
initial_adapters: List[str]
|
||||||
op_sequence: List[Operation]
|
op_sequence: List[Operation]
|
||||||
|
max_lora_rank: Optional[int] = None
|
||||||
|
lora_target_modules: Optional[List] = None
|
||||||
max_new_tokens: int = 32
|
max_new_tokens: int = 32
|
||||||
|
|
||||||
|
|
||||||
@@ -72,9 +77,9 @@ def create_batch_data(adapters: Union[str, list]) -> List[tuple[str, str]]:
|
|||||||
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
return [(prompt, adapter) for prompt in PROMPTS for adapter in adapters]
|
||||||
|
|
||||||
|
|
||||||
TEST_CASES = [
|
BASIC_TESTS = [
|
||||||
# basic test, no eviction
|
|
||||||
TestCase(
|
TestCase(
|
||||||
|
description="dynamic lora update with initial lora_paths",
|
||||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
max_loras_per_batch=3,
|
max_loras_per_batch=3,
|
||||||
all_adapters=[
|
all_adapters=[
|
||||||
@@ -89,20 +94,16 @@ TEST_CASES = [
|
|||||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data(
|
||||||
create_batch_data(
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
||||||
),
|
|
||||||
"not loaded",
|
|
||||||
),
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
expected_error="not loaded",
|
||||||
"not loaded",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.LOAD,
|
type=OperationType.LOAD,
|
||||||
@@ -127,11 +128,9 @@ TEST_CASES = [
|
|||||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
expected_error="not loaded",
|
||||||
"not loaded",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
@@ -147,13 +146,11 @@ TEST_CASES = [
|
|||||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data(
|
||||||
create_batch_data(
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
||||||
),
|
|
||||||
"not loaded",
|
|
||||||
),
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
@@ -174,8 +171,8 @@ TEST_CASES = [
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
# Eviction
|
|
||||||
TestCase(
|
TestCase(
|
||||||
|
description="dynamic lora update with evictions",
|
||||||
base="meta-llama/Llama-3.1-8B-Instruct",
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
max_loras_per_batch=1,
|
max_loras_per_batch=1,
|
||||||
all_adapters=[
|
all_adapters=[
|
||||||
@@ -190,20 +187,16 @@ TEST_CASES = [
|
|||||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data(
|
||||||
create_batch_data(
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
|
||||||
),
|
|
||||||
"not loaded",
|
|
||||||
),
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
expected_error="not loaded",
|
||||||
"not loaded",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.LOAD,
|
type=OperationType.LOAD,
|
||||||
@@ -214,11 +207,9 @@ TEST_CASES = [
|
|||||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.EXPECT_ERROR,
|
type=OperationType.FORWARD,
|
||||||
data=(
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
expected_error="not loaded",
|
||||||
"not loaded",
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
@@ -263,6 +254,253 @@ TEST_CASES = [
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
TARGET_MODULE_TESTS = [
|
||||||
|
TestCase(
|
||||||
|
description="Test explicitly specified lora-target-modules.",
|
||||||
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
max_loras_per_batch=3,
|
||||||
|
lora_target_modules=[
|
||||||
|
"q_proj",
|
||||||
|
"k_proj",
|
||||||
|
"v_proj",
|
||||||
|
"o_proj",
|
||||||
|
"gate_proj",
|
||||||
|
"up_proj",
|
||||||
|
"down_proj",
|
||||||
|
],
|
||||||
|
max_lora_rank=64,
|
||||||
|
all_adapters=[
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
||||||
|
],
|
||||||
|
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
|
||||||
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
description="Test inferred lora-target-modules - start with larger adapter",
|
||||||
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
max_loras_per_batch=3,
|
||||||
|
max_lora_rank=64,
|
||||||
|
all_adapters=[
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
||||||
|
],
|
||||||
|
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
|
||||||
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
||||||
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
description="Test inferred lora-target-modules - start with smaller adapter",
|
||||||
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
max_loras_per_batch=3,
|
||||||
|
max_lora_rank=64,
|
||||||
|
all_adapters=[
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # target_modules = q, k, v, o, gate, up, down
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora", # target_modules = q, k, v, o, gate
|
||||||
|
],
|
||||||
|
initial_adapters=["algoprog/fact-generation-llama-3.1-8b-instruct-lora"],
|
||||||
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
expected_error="updating LoRA shapes",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"algoprog/fact-generation-llama-3.1-8b-instruct-lora",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
MAX_LORA_RANK_TESTS = [
|
||||||
|
TestCase(
|
||||||
|
description="Test explicitly specified max-lora-rank.",
|
||||||
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
max_loras_per_batch=3,
|
||||||
|
max_lora_rank=32,
|
||||||
|
all_adapters=[
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
|
||||||
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
|
||||||
|
],
|
||||||
|
initial_adapters=["Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"],
|
||||||
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
|
expected_error="updating LoRA shapes",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
|
),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
description="test implicitly inferred max-lora-rank",
|
||||||
|
base="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
|
max_loras_per_batch=3,
|
||||||
|
all_adapters=[
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16", # r = 4
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction", # r = 32
|
||||||
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora", # r = 256
|
||||||
|
],
|
||||||
|
initial_adapters=["pbevan11/llama-3.1-8b-ocr-correction"],
|
||||||
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
|
expected_error="updating LoRA shapes",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
|
expected_error="not loaded",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.LOAD,
|
||||||
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
None,
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
ALL_TESTS = BASIC_TESTS + TARGET_MODULE_TESTS + MAX_LORA_RANK_TESTS
|
||||||
|
|
||||||
|
|
||||||
class LoRAUpdateTestSessionMode(Enum):
|
class LoRAUpdateTestSessionMode(Enum):
|
||||||
@@ -281,7 +519,9 @@ class LoRAUpdateTestSessionBase:
|
|||||||
testcase: Optional[TestCase],
|
testcase: Optional[TestCase],
|
||||||
model_path: str,
|
model_path: str,
|
||||||
lora_paths: list[str],
|
lora_paths: list[str],
|
||||||
max_loras_per_batch: int = 1,
|
max_loras_per_batch: int,
|
||||||
|
max_lora_rank: Optional[int],
|
||||||
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
lora_backend: str = "triton",
|
lora_backend: str = "triton",
|
||||||
disable_cuda_graph: bool = False,
|
disable_cuda_graph: bool = False,
|
||||||
cuda_graph_max_bs: int = 4,
|
cuda_graph_max_bs: int = 4,
|
||||||
@@ -289,6 +529,8 @@ class LoRAUpdateTestSessionBase:
|
|||||||
self.testcase = testcase
|
self.testcase = testcase
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self.lora_paths = lora_paths
|
self.lora_paths = lora_paths
|
||||||
|
self.max_lora_rank = max_lora_rank
|
||||||
|
self.lora_target_modules = lora_target_modules
|
||||||
self.max_loras_per_batch = max_loras_per_batch
|
self.max_loras_per_batch = max_loras_per_batch
|
||||||
self.lora_backend = lora_backend
|
self.lora_backend = lora_backend
|
||||||
self.disable_cuda_graph = disable_cuda_graph
|
self.disable_cuda_graph = disable_cuda_graph
|
||||||
@@ -304,7 +546,12 @@ class LoRAUpdateTestSessionBase:
|
|||||||
# Don't suppress exceptions by default
|
# Don't suppress exceptions by default
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
|
def load_lora_adapter(
|
||||||
|
self,
|
||||||
|
lora_name: str,
|
||||||
|
lora_path: Optional[str] = None,
|
||||||
|
expected_error: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load a LoRA adapter by name and path.
|
Load a LoRA adapter by name and path.
|
||||||
"""
|
"""
|
||||||
@@ -321,6 +568,7 @@ class LoRAUpdateTestSessionBase:
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
lora_paths: List[str],
|
lora_paths: List[str],
|
||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
|
expected_error: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||||
@@ -339,6 +587,8 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
|||||||
model_path=self.model_path,
|
model_path=self.model_path,
|
||||||
model_type="generation",
|
model_type="generation",
|
||||||
lora_paths=self.lora_paths,
|
lora_paths=self.lora_paths,
|
||||||
|
max_lora_rank=self.max_lora_rank,
|
||||||
|
lora_target_modules=self.lora_target_modules,
|
||||||
lora_backend=self.lora_backend,
|
lora_backend=self.lora_backend,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
mem_fraction_static=MEM_FRACTION_STATIC,
|
mem_fraction_static=MEM_FRACTION_STATIC,
|
||||||
@@ -357,24 +607,32 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
|||||||
# don't suppress exceptions
|
# don't suppress exceptions
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
|
def load_lora_adapter(
|
||||||
|
self,
|
||||||
|
lora_name: str,
|
||||||
|
lora_path: Optional[str] = None,
|
||||||
|
expected_error: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load a LoRA adapter by name and path.
|
Load a LoRA adapter by name and path.
|
||||||
"""
|
"""
|
||||||
if lora_path is None:
|
if lora_path is None:
|
||||||
lora_path = lora_name
|
lora_path = lora_name
|
||||||
|
|
||||||
self.expected_adapters.add(lora_name)
|
|
||||||
|
|
||||||
response = self.handle.load_lora_adapter(
|
response = self.handle.load_lora_adapter(
|
||||||
lora_name=lora_name,
|
lora_name=lora_name,
|
||||||
lora_path=lora_path,
|
lora_path=lora_path,
|
||||||
)
|
)
|
||||||
self.testcase.assertTrue(response.success)
|
if expected_error:
|
||||||
loaded_adapters = set(response.loaded_adapters)
|
self.testcase.assertFalse(response.success)
|
||||||
|
self.testcase.assertIn(expected_error, response.error_message)
|
||||||
print(f"loaded_adapters: {loaded_adapters}")
|
print(f"Received error as expected: {response.error_message}")
|
||||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
else:
|
||||||
|
self.expected_adapters.add(lora_name)
|
||||||
|
self.testcase.assertTrue(response.success)
|
||||||
|
loaded_adapters = set(response.loaded_adapters)
|
||||||
|
print(f"loaded_adapters: {loaded_adapters}")
|
||||||
|
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||||
|
|
||||||
def unload_lora_adapter(self, lora_name: str):
|
def unload_lora_adapter(self, lora_name: str):
|
||||||
"""
|
"""
|
||||||
@@ -396,7 +654,7 @@ class LoRAUpdateEngineTestSession(LoRAUpdateTestSessionBase):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
lora_paths: List[str],
|
lora_paths: List[str],
|
||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
expected_error: str = None,
|
expected_error: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||||
@@ -448,6 +706,10 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
]
|
]
|
||||||
if self.disable_cuda_graph:
|
if self.disable_cuda_graph:
|
||||||
other_args.append("--disable-cuda-graph")
|
other_args.append("--disable-cuda-graph")
|
||||||
|
if self.max_lora_rank is not None:
|
||||||
|
other_args.extend(["--max-lora-rank", str(self.max_lora_rank)])
|
||||||
|
if self.lora_target_modules is not None:
|
||||||
|
other_args.extend(["--lora-target-modules"] + self.lora_target_modules)
|
||||||
|
|
||||||
# launch external server
|
# launch external server
|
||||||
self.handle = popen_launch_server(
|
self.handle = popen_launch_server(
|
||||||
@@ -464,24 +726,32 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
# don't suppress exceptions
|
# don't suppress exceptions
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def load_lora_adapter(self, lora_name: str, lora_path: Optional[str] = None):
|
def load_lora_adapter(
|
||||||
|
self,
|
||||||
|
lora_name: str,
|
||||||
|
lora_path: Optional[str] = None,
|
||||||
|
expected_error: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Load a LoRA adapter by name and path.
|
Load a LoRA adapter by name and path.
|
||||||
"""
|
"""
|
||||||
if lora_path is None:
|
if lora_path is None:
|
||||||
lora_path = lora_name
|
lora_path = lora_name
|
||||||
|
|
||||||
self.expected_adapters.add(lora_name)
|
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
|
DEFAULT_URL_FOR_TEST + "/load_lora_adapter",
|
||||||
json={"lora_name": lora_name, "lora_path": lora_path},
|
json={"lora_name": lora_name, "lora_path": lora_path},
|
||||||
)
|
)
|
||||||
self.testcase.assertTrue(response.ok)
|
if expected_error:
|
||||||
loaded_adapters = set(response.json()["loaded_adapters"])
|
self.testcase.assertEqual(response.status_code, 400)
|
||||||
|
self.testcase.assertIn(expected_error, response.text)
|
||||||
print(f"loaded_adapters: {loaded_adapters}")
|
print(f"Received error as expected: {response.text}")
|
||||||
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
else:
|
||||||
|
self.expected_adapters.add(lora_name)
|
||||||
|
self.testcase.assertTrue(response.ok)
|
||||||
|
loaded_adapters = set(response.json()["loaded_adapters"])
|
||||||
|
print(f"loaded_adapters: {loaded_adapters}")
|
||||||
|
self.testcase.assertEqual(loaded_adapters, self.expected_adapters)
|
||||||
|
|
||||||
def unload_lora_adapter(self, lora_name: str):
|
def unload_lora_adapter(self, lora_name: str):
|
||||||
"""
|
"""
|
||||||
@@ -504,7 +774,7 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
prompts: List[str],
|
prompts: List[str],
|
||||||
lora_paths: List[str],
|
lora_paths: List[str],
|
||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
expected_error: str = None,
|
expected_error: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
Perform a batch forward pass with the current set of loaded LoRA adapters.
|
||||||
@@ -537,30 +807,14 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
|
|
||||||
# Factory function to create the appropriate LoRA test session based on mode
|
# Factory function to create the appropriate LoRA test session based on mode
|
||||||
def LoRAUpdateTestSession(
|
def LoRAUpdateTestSession(
|
||||||
*,
|
|
||||||
testcase: Optional[TestCase],
|
testcase: Optional[TestCase],
|
||||||
mode: LoRAUpdateTestSessionMode,
|
mode: LoRAUpdateTestSessionMode,
|
||||||
model_path: str,
|
**kwargs: Any,
|
||||||
lora_paths: list[str],
|
|
||||||
max_loras_per_batch: int = 1,
|
|
||||||
lora_backend: str = "triton",
|
|
||||||
disable_cuda_graph: bool = False,
|
|
||||||
cuda_graph_max_bs: int = 4,
|
|
||||||
):
|
):
|
||||||
common_kwargs = {
|
|
||||||
"testcase": testcase,
|
|
||||||
"model_path": model_path,
|
|
||||||
"lora_paths": lora_paths,
|
|
||||||
"max_loras_per_batch": max_loras_per_batch,
|
|
||||||
"lora_backend": lora_backend,
|
|
||||||
"disable_cuda_graph": disable_cuda_graph,
|
|
||||||
"cuda_graph_max_bs": cuda_graph_max_bs,
|
|
||||||
}
|
|
||||||
|
|
||||||
if mode == LoRAUpdateTestSessionMode.ENGINE:
|
if mode == LoRAUpdateTestSessionMode.ENGINE:
|
||||||
return LoRAUpdateEngineTestSession(**common_kwargs)
|
return LoRAUpdateEngineTestSession(testcase=testcase, **kwargs)
|
||||||
elif mode == LoRAUpdateTestSessionMode.SERVER:
|
elif mode == LoRAUpdateTestSessionMode.SERVER:
|
||||||
return LoRAUpdateServerTestSession(**common_kwargs)
|
return LoRAUpdateServerTestSession(testcase=testcase, **kwargs)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unrecognized mode: {mode!r}")
|
raise ValueError(f"Unrecognized mode: {mode!r}")
|
||||||
|
|
||||||
@@ -582,6 +836,8 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
initial_adapters: List[str],
|
initial_adapters: List[str],
|
||||||
max_loras_per_batch: int,
|
max_loras_per_batch: int,
|
||||||
op_sequence: List[Operation],
|
op_sequence: List[Operation],
|
||||||
|
max_lora_rank: Optional[int] = None,
|
||||||
|
lora_target_modules: Optional[List[str]] = None,
|
||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
) -> List[tuple]:
|
) -> List[tuple]:
|
||||||
"""
|
"""
|
||||||
@@ -596,10 +852,13 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
model_path=base,
|
model_path=base,
|
||||||
lora_paths=initial_adapters,
|
lora_paths=initial_adapters,
|
||||||
max_loras_per_batch=max_loras_per_batch,
|
max_loras_per_batch=max_loras_per_batch,
|
||||||
|
max_lora_rank=max_lora_rank,
|
||||||
|
lora_target_modules=lora_target_modules,
|
||||||
) as session:
|
) as session:
|
||||||
for op in op_sequence:
|
for op in op_sequence:
|
||||||
op_type = op.type
|
op_type = op.type
|
||||||
data = op.data
|
data = op.data
|
||||||
|
expected_error = op.expected_error
|
||||||
print("-" * 100)
|
print("-" * 100)
|
||||||
print(
|
print(
|
||||||
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
|
f"Running operation: {op_type} --- data: {data} --- mode: {mode} ---"
|
||||||
@@ -608,6 +867,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
result = session.load_lora_adapter(
|
result = session.load_lora_adapter(
|
||||||
lora_name=data,
|
lora_name=data,
|
||||||
lora_path=data,
|
lora_path=data,
|
||||||
|
expected_error=expected_error,
|
||||||
)
|
)
|
||||||
elif op_type == OperationType.UNLOAD:
|
elif op_type == OperationType.UNLOAD:
|
||||||
result = session.unload_lora_adapter(
|
result = session.unload_lora_adapter(
|
||||||
@@ -615,91 +875,105 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
)
|
)
|
||||||
elif op_type == OperationType.FORWARD:
|
elif op_type == OperationType.FORWARD:
|
||||||
prompts, adapters = zip(*data)
|
prompts, adapters = zip(*data)
|
||||||
result = session.forward(
|
|
||||||
prompts=list(prompts),
|
|
||||||
lora_paths=list(adapters),
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
)
|
|
||||||
forward_outputs.append(result)
|
|
||||||
elif op_type == OperationType.EXPECT_ERROR:
|
|
||||||
input_data, expected_error = data
|
|
||||||
prompts, adapters = zip(*input_data)
|
|
||||||
result = session.forward(
|
result = session.forward(
|
||||||
prompts=list(prompts),
|
prompts=list(prompts),
|
||||||
lora_paths=list(adapters),
|
lora_paths=list(adapters),
|
||||||
max_new_tokens=max_new_tokens,
|
max_new_tokens=max_new_tokens,
|
||||||
expected_error=expected_error,
|
expected_error=expected_error,
|
||||||
)
|
)
|
||||||
|
if not expected_error:
|
||||||
|
forward_outputs.append(result)
|
||||||
|
|
||||||
return forward_outputs
|
return forward_outputs
|
||||||
|
|
||||||
def test_dynamic_adapter_updates(self):
|
def _run_dynamic_adapter_updates(
|
||||||
for case_idx, test_case in enumerate(TEST_CASES, start=1):
|
self, mode: LoRAUpdateTestSessionMode, test_cases: Iterable[TestCase]
|
||||||
for mode in [
|
):
|
||||||
LoRAUpdateTestSessionMode.ENGINE,
|
for case_idx, test_case in enumerate(test_cases, start=1):
|
||||||
LoRAUpdateTestSessionMode.SERVER,
|
print("=" * 100)
|
||||||
]:
|
print(
|
||||||
print("=" * 100)
|
f"Starting test case {case_idx} in {mode.value} mode. Test description: {test_case.description}"
|
||||||
print(f"Starting test case {case_idx} in {mode.value} mode.")
|
)
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
|
f"--- Running dynamic update pass with {len(test_case.op_sequence)} operations ---"
|
||||||
)
|
)
|
||||||
# Test dynamic loading of adapters
|
# Test dynamic loading of adapters
|
||||||
# TODO (lifuhuang): currently at least one LoRA path is required during initialization to enable lora,
|
dynamic_output = self._run_operation_sequence(
|
||||||
# we should fix this in the future https://github.com/sgl-project/sglang/issues/7463.
|
mode=mode,
|
||||||
dynamic_output = self._run_operation_sequence(
|
initial_adapters=test_case.initial_adapters,
|
||||||
mode=mode,
|
base=test_case.base,
|
||||||
initial_adapters=test_case.initial_adapters,
|
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||||
base=test_case.base,
|
op_sequence=test_case.op_sequence,
|
||||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
max_new_tokens=test_case.max_new_tokens,
|
||||||
op_sequence=test_case.op_sequence,
|
max_lora_rank=test_case.max_lora_rank,
|
||||||
max_new_tokens=test_case.max_new_tokens,
|
lora_target_modules=test_case.lora_target_modules,
|
||||||
)
|
)
|
||||||
|
|
||||||
# static loading
|
# static loading
|
||||||
forward_ops = [
|
forward_ops = [
|
||||||
x for x in test_case.op_sequence if x.type == OperationType.FORWARD
|
x
|
||||||
]
|
for x in test_case.op_sequence
|
||||||
|
if x.type == OperationType.FORWARD and x.expected_error is None
|
||||||
|
]
|
||||||
|
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
print(
|
print(f"\n--- Running static pass with {len(forward_ops)} operations ---")
|
||||||
f"\n--- Running static pass with {len(forward_ops)} operations ---"
|
static_output = self._run_operation_sequence(
|
||||||
)
|
mode=mode,
|
||||||
static_output = self._run_operation_sequence(
|
initial_adapters=test_case.all_adapters,
|
||||||
mode=mode,
|
base=test_case.base,
|
||||||
initial_adapters=test_case.all_adapters,
|
max_loras_per_batch=test_case.max_loras_per_batch,
|
||||||
base=test_case.base,
|
op_sequence=forward_ops,
|
||||||
max_loras_per_batch=test_case.max_loras_per_batch,
|
max_new_tokens=test_case.max_new_tokens,
|
||||||
op_sequence=forward_ops,
|
)
|
||||||
max_new_tokens=test_case.max_new_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"Dynamic output: {dynamic_output}")
|
print(f"Dynamic output: {dynamic_output}")
|
||||||
print(f"Static output: {static_output}")
|
print(f"Static output: {static_output}")
|
||||||
print("=" * 100)
|
print("=" * 100)
|
||||||
|
self.assertEqual(
|
||||||
|
len(dynamic_output),
|
||||||
|
len(static_output),
|
||||||
|
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
|
||||||
|
)
|
||||||
|
for i, (dynamic, static) in enumerate(
|
||||||
|
zip(dynamic_output, static_output), start=1
|
||||||
|
):
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(dynamic_output),
|
len(dynamic),
|
||||||
len(static_output),
|
len(static),
|
||||||
f"Dynamic output length {len(dynamic_output)} does not match static output length {len(static_output)}",
|
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
|
||||||
)
|
)
|
||||||
for i, (dynamic, static) in enumerate(
|
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
|
||||||
zip(dynamic_output, static_output), start=1
|
d_out = d_out.strip()
|
||||||
):
|
s_out = s_out.strip()
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
len(dynamic),
|
d_out,
|
||||||
len(static),
|
s_out,
|
||||||
f"Output length mismatch at batch {i}:\n- Dynamic={len(dynamic)}\n- Static={len(static)}",
|
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
|
||||||
)
|
)
|
||||||
for j, (d_out, s_out) in enumerate(zip(dynamic, static), start=1):
|
|
||||||
d_out = d_out.strip()
|
def test_dynamic_lora_update_engine(self):
|
||||||
s_out = s_out.strip()
|
"""
|
||||||
self.assertEqual(
|
Test dynamic LoRA updates in engine mode.
|
||||||
d_out,
|
"""
|
||||||
s_out,
|
test_cases = ALL_TESTS
|
||||||
f"Output mismatch at batch {i}, prompt {j}:\n- Dynamic: '{d_out}'\n- Static: '{s_out}'",
|
self._run_dynamic_adapter_updates(
|
||||||
)
|
mode=LoRAUpdateTestSessionMode.ENGINE,
|
||||||
|
test_cases=test_cases,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dynamic_lora_update_server(self):
|
||||||
|
"""
|
||||||
|
Test dynamic LoRA updates in server mode.
|
||||||
|
"""
|
||||||
|
# In CI, we only run the first test case to save time, as the engine test should be mostly sufficient for ensuring correctness.
|
||||||
|
test_cases = BASIC_TESTS if is_in_ci() else ALL_TESTS
|
||||||
|
|
||||||
|
self._run_dynamic_adapter_updates(
|
||||||
|
mode=LoRAUpdateTestSessionMode.SERVER, test_cases=test_cases
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ suites = {
|
|||||||
TestFile("models/lora/test_lora_backend.py", 99),
|
TestFile("models/lora/test_lora_backend.py", 99),
|
||||||
TestFile("models/lora/test_multi_lora_backend.py", 60),
|
TestFile("models/lora/test_multi_lora_backend.py", 60),
|
||||||
TestFile("models/lora/test_lora_cuda_graph.py", 250),
|
TestFile("models/lora/test_lora_cuda_graph.py", 250),
|
||||||
TestFile("models/lora/test_lora_update.py", 400),
|
TestFile("models/lora/test_lora_update.py", 700),
|
||||||
TestFile("models/test_embedding_models.py", 73),
|
TestFile("models/test_embedding_models.py", 73),
|
||||||
# TestFile("models/test_clip_models.py", 52),
|
# TestFile("models/test_clip_models.py", 52),
|
||||||
TestFile("models/test_encoder_embedding_models.py", 100),
|
TestFile("models/test_encoder_embedding_models.py", 100),
|
||||||
|
|||||||
Reference in New Issue
Block a user