Support pinning adapter via server args. (#9249)
This commit is contained in:
@@ -29,7 +29,7 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
|
"* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `lora_paths`: A mapping from each adaptor's name to its path, in the form of `{name}={path} {name}={path}`.\n",
|
"* `lora_paths`: The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>` | `<NAME>=<PATH>` | JSON with schema `{\"lora_name\":str,\"lora_path\":str,\"pinned\":bool}`.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
|
"* `max_loras_per_batch`: Maximum number of adaptors used by each batch. This argument can affect the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to be 8.\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -372,6 +372,15 @@
|
|||||||
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
|
"print(f\"Output from lora1 (updated): \\n{response.json()[1]['text']}\\n\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"terminate_process(server_process)"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
@@ -387,7 +396,40 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
|
"This can improve performance in scenarios where the same adapter is frequently used across requests, by avoiding repeated memory transfers and reinitialization overhead. However, since GPU pool slots are limited, pinning adapters reduces the flexibility of the system to dynamically load other adapters on demand. If too many adapters are pinned, it may lead to degraded performance, or in the most extreme case (`Number of pinned adapters == max-loras-per-batch`), halt all unpinned requests. Therefore, currently SGLang limits maximal number of pinned adapters to `max-loras-per-batch - 1` to prevent unexpected starvations. \n",
|
||||||
"\n",
|
"\n",
|
||||||
"In the example below, we unload `lora1` and reload it as a `pinned` adapter:"
|
"In the example below, we start a server with `lora0` loaded as pinned, and `lora1` and `lora2` loaded as regular (unpinned) adapters. Please note that we intentionally specify `lora1` and `lora2` in two different formats to demonstrate that both are supported."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"server_process, port = launch_server_cmd(\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
||||||
|
" --enable-lora \\\n",
|
||||||
|
" --cuda-graph-max-bs 8 \\\n",
|
||||||
|
" --max-loras-per-batch 3 --lora-backend triton \\\n",
|
||||||
|
" --max-lora-rank 256 \\\n",
|
||||||
|
" --lora-target-modules all \\\n",
|
||||||
|
" --lora-paths \\\n",
|
||||||
|
" {\"lora_name\":\"lora0\",\"lora_path\":\"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\"pinned\":true} \\\n",
|
||||||
|
" {\"lora_name\":\"lora1\",\"lora_path\":\"algoprog/fact-generation-llama-3.1-8b-instruct-lora\"} \\\n",
|
||||||
|
" lora2=philschmid/code-llama-3-1-8b-text-to-sql-lora\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"url = f\"http://127.0.0.1:{port}\"\n",
|
||||||
|
"wait_for_server(url)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"You can also specify an adapter as pinned during dynamic adapter loading. In the example below, we reload `lora1` as a pinned adapter:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -407,7 +449,7 @@
|
|||||||
" url + \"/load_lora_adapter\",\n",
|
" url + \"/load_lora_adapter\",\n",
|
||||||
" json={\n",
|
" json={\n",
|
||||||
" \"lora_name\": \"lora1\",\n",
|
" \"lora_name\": \"lora1\",\n",
|
||||||
" \"lora_path\": lora1,\n",
|
" \"lora_path\": \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\",\n",
|
||||||
" \"pinned\": True, # Pin the adapter to GPU\n",
|
" \"pinned\": True, # Pin the adapter to GPU\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
")"
|
")"
|
||||||
@@ -417,7 +459,7 @@
|
|||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"Verify that the result is identical as before:"
|
"Verify that the results are expected:"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -431,17 +473,19 @@
|
|||||||
" \"text\": [\n",
|
" \"text\": [\n",
|
||||||
" \"List 3 countries and their capitals.\",\n",
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
" \"List 3 countries and their capitals.\",\n",
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
|
" \"List 3 countries and their capitals.\",\n",
|
||||||
" ],\n",
|
" ],\n",
|
||||||
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
" \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
||||||
" # The first input uses lora0, and the second input uses lora1\n",
|
" # The three inputs use lora0, lora1, and lora2, respectively\n",
|
||||||
" \"lora_path\": [\"lora0\", \"lora1\"],\n",
|
" \"lora_path\": [\"lora0\", \"lora1\", \"lora2\"],\n",
|
||||||
"}\n",
|
"}\n",
|
||||||
"response = requests.post(\n",
|
"response = requests.post(\n",
|
||||||
" url + \"/generate\",\n",
|
" url + \"/generate\",\n",
|
||||||
" json=json_data,\n",
|
" json=json_data,\n",
|
||||||
")\n",
|
")\n",
|
||||||
"print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n",
|
"print(f\"Output from lora0 (pinned): \\n{response.json()[0]['text']}\\n\")\n",
|
||||||
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")"
|
"print(f\"Output from lora1 (pinned): \\n{response.json()[1]['text']}\\n\")\n",
|
||||||
|
"print(f\"Output from lora2 (not pinned): \\n{response.json()[2]['text']}\\n\")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -179,7 +179,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
|
|||||||
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
|
| `--enable-lora` | Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility. | False |
|
||||||
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
|
| `--max-lora-rank` | The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup. | None |
|
||||||
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
|
| `--lora-target-modules` | The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters. | None |
|
||||||
| `--lora-paths` | The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}. | None |
|
| `--lora-paths` | The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: `<PATH>`, `<NAME>=<PATH>`, or JSON with schema `{"lora_name":str,"lora_path":str,"pinned":bool}`. | None |
|
||||||
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
|
| `--max-loras-per-batch` | Maximum number of adapters for a running batch, include base-only request. | 8 |
|
||||||
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
|
| `--max-loaded-loras` | If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`. | None |
|
||||||
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
| `--lora-backend` | Choose the kernel backend for multi-LoRA serving. | triton |
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class LoRAManager:
|
|||||||
tp_rank: int = 0,
|
tp_rank: int = 0,
|
||||||
max_lora_rank: Optional[int] = None,
|
max_lora_rank: Optional[int] = None,
|
||||||
target_modules: Optional[Iterable[str]] = None,
|
target_modules: Optional[Iterable[str]] = None,
|
||||||
lora_paths: Optional[Dict[str, LoRARef]] = None,
|
lora_paths: Optional[List[LoRARef]] = None,
|
||||||
):
|
):
|
||||||
self.base_model: torch.nn.Module = base_model
|
self.base_model: torch.nn.Module = base_model
|
||||||
self.base_hf_config: AutoConfig = base_hf_config
|
self.base_hf_config: AutoConfig = base_hf_config
|
||||||
@@ -370,7 +370,7 @@ class LoRAManager:
|
|||||||
self,
|
self,
|
||||||
max_lora_rank: Optional[int] = None,
|
max_lora_rank: Optional[int] = None,
|
||||||
target_modules: Optional[Iterable[str]] = None,
|
target_modules: Optional[Iterable[str]] = None,
|
||||||
lora_paths: Optional[Dict[str, LoRARef]] = None,
|
lora_paths: Optional[List[LoRARef]] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the internal (mutable) state of the LoRAManager.
|
Initialize the internal (mutable) state of the LoRAManager.
|
||||||
@@ -392,7 +392,7 @@ class LoRAManager:
|
|||||||
self.init_memory_pool()
|
self.init_memory_pool()
|
||||||
self.update_lora_info()
|
self.update_lora_info()
|
||||||
|
|
||||||
def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
def init_lora_adapters(self, lora_paths: Optional[List[LoRARef]] = None):
|
||||||
# Configs of all active LoRA adapters, indexed by LoRA ID.
|
# Configs of all active LoRA adapters, indexed by LoRA ID.
|
||||||
self.configs: Dict[str, LoRAConfig] = {}
|
self.configs: Dict[str, LoRAConfig] = {}
|
||||||
|
|
||||||
@@ -406,7 +406,7 @@ class LoRAManager:
|
|||||||
self.num_pinned_loras: int = 0
|
self.num_pinned_loras: int = 0
|
||||||
|
|
||||||
if lora_paths:
|
if lora_paths:
|
||||||
for lora_ref in lora_paths.values():
|
for lora_ref in lora_paths:
|
||||||
result = self.load_lora_adapter(lora_ref)
|
result = self.load_lora_adapter(lora_ref)
|
||||||
if not result.success:
|
if not result.success:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
@@ -59,9 +59,9 @@ class LoRARegistry:
|
|||||||
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
|
update / eventual consistency model between the tokenizer manager process and the scheduler processes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
|
def __init__(self, lora_paths: Optional[List[LoRARef]] = None):
|
||||||
assert lora_paths is None or all(
|
assert lora_paths is None or all(
|
||||||
isinstance(lora, LoRARef) for lora in lora_paths.values()
|
isinstance(lora, LoRARef) for lora in lora_paths
|
||||||
), (
|
), (
|
||||||
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
|
"server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
|
||||||
"Please file an issue if you see this error."
|
"Please file an issue if you see this error."
|
||||||
@@ -78,7 +78,7 @@ class LoRARegistry:
|
|||||||
|
|
||||||
# Initialize the registry with provided LoRA paths, if present.
|
# Initialize the registry with provided LoRA paths, if present.
|
||||||
if lora_paths:
|
if lora_paths:
|
||||||
for lora_ref in lora_paths.values():
|
for lora_ref in lora_paths:
|
||||||
self._register_adapter(lora_ref)
|
self._register_adapter(lora_ref)
|
||||||
|
|
||||||
async def register(self, lora_ref: LoRARef):
|
async def register(self, lora_ref: LoRARef):
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ class TokenizerManager:
|
|||||||
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
|
# The registry dynamically updates as adapters are loaded / unloaded during runtime. It
|
||||||
# serves as the source of truth for available adapters and maps user-friendly LoRA names
|
# serves as the source of truth for available adapters and maps user-friendly LoRA names
|
||||||
# to internally used unique LoRA IDs.
|
# to internally used unique LoRA IDs.
|
||||||
self.lora_registry = LoRARegistry(self.server_args.lora_paths or {})
|
self.lora_registry = LoRARegistry(self.server_args.lora_paths)
|
||||||
# Lock to serialize LoRA update operations.
|
# Lock to serialize LoRA update operations.
|
||||||
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
# Please note that, unlike `model_update_lock`, this does not block inference, allowing
|
||||||
# LoRA updates and inference to overlap.
|
# LoRA updates and inference to overlap.
|
||||||
|
|||||||
@@ -153,7 +153,9 @@ class ServerArgs:
|
|||||||
enable_lora: Optional[bool] = None
|
enable_lora: Optional[bool] = None
|
||||||
max_lora_rank: Optional[int] = None
|
max_lora_rank: Optional[int] = None
|
||||||
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
||||||
lora_paths: Optional[Union[dict[str, str], dict[str, LoRARef], List[str]]] = None
|
lora_paths: Optional[
|
||||||
|
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
|
||||||
|
] = None
|
||||||
max_loaded_loras: Optional[int] = None
|
max_loaded_loras: Optional[int] = None
|
||||||
max_loras_per_batch: int = 8
|
max_loras_per_batch: int = 8
|
||||||
lora_backend: str = "triton"
|
lora_backend: str = "triton"
|
||||||
@@ -1319,7 +1321,7 @@ class ServerArgs:
|
|||||||
nargs="*",
|
nargs="*",
|
||||||
default=None,
|
default=None,
|
||||||
action=LoRAPathAction,
|
action=LoRAPathAction,
|
||||||
help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}.",
|
help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--max-loras-per-batch",
|
"--max-loras-per-batch",
|
||||||
@@ -2086,28 +2088,42 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.enable_lora:
|
if self.enable_lora:
|
||||||
# Normalize lora_paths to a dictionary if it is a list.
|
|
||||||
# TODO (lifuhuang): support specifying pinned adapters in server_args.
|
|
||||||
if isinstance(self.lora_paths, list):
|
if isinstance(self.lora_paths, list):
|
||||||
lora_paths = self.lora_paths
|
lora_paths = self.lora_paths
|
||||||
self.lora_paths = {}
|
self.lora_paths = []
|
||||||
for lora_path in lora_paths:
|
for lora_path in lora_paths:
|
||||||
if "=" in lora_path:
|
if isinstance(lora_path, str):
|
||||||
name, path = lora_path.split("=", 1)
|
if "=" in lora_path:
|
||||||
self.lora_paths[name] = LoRARef(
|
name, path = lora_path.split("=", 1)
|
||||||
lora_name=name, lora_path=path, pinned=False
|
lora_ref = LoRARef(
|
||||||
|
lora_name=name, lora_path=path, pinned=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
lora_ref = LoRARef(
|
||||||
|
lora_name=lora_path, lora_path=lora_path, pinned=False
|
||||||
|
)
|
||||||
|
elif isinstance(lora_path, dict):
|
||||||
|
assert (
|
||||||
|
"lora_name" in lora_path and "lora_path" in lora_path
|
||||||
|
), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
|
||||||
|
lora_ref = LoRARef(
|
||||||
|
lora_name=lora_path["lora_name"],
|
||||||
|
lora_path=lora_path["lora_path"],
|
||||||
|
pinned=lora_path.get("pinned", False),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.lora_paths[lora_path] = LoRARef(
|
raise ValueError(
|
||||||
lora_name=lora_path, lora_path=lora_path, pinned=False
|
f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
|
||||||
|
"Expected a string or a dictionary."
|
||||||
)
|
)
|
||||||
|
self.lora_paths.append(lora_ref)
|
||||||
elif isinstance(self.lora_paths, dict):
|
elif isinstance(self.lora_paths, dict):
|
||||||
self.lora_paths = {
|
self.lora_paths = [
|
||||||
k: LoRARef(lora_name=k, lora_path=v, pinned=False)
|
LoRARef(lora_name=k, lora_path=v, pinned=False)
|
||||||
for k, v in self.lora_paths.items()
|
for k, v in self.lora_paths.items()
|
||||||
}
|
]
|
||||||
elif self.lora_paths is None:
|
elif self.lora_paths is None:
|
||||||
self.lora_paths = {}
|
self.lora_paths = []
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
|
f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
|
||||||
@@ -2134,9 +2150,7 @@ class ServerArgs:
|
|||||||
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
|
"max_loaded_loras should be greater than or equal to max_loras_per_batch. "
|
||||||
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
|
f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
|
||||||
)
|
)
|
||||||
assert (
|
assert len(self.lora_paths) <= self.max_loaded_loras, (
|
||||||
not self.lora_paths or len(self.lora_paths) <= self.max_loaded_loras
|
|
||||||
), (
|
|
||||||
"The number of LoRA paths should not exceed max_loaded_loras. "
|
"The number of LoRA paths should not exceed max_loaded_loras. "
|
||||||
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
|
f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
|
||||||
)
|
)
|
||||||
@@ -2357,13 +2371,22 @@ class PortArgs:
|
|||||||
|
|
||||||
class LoRAPathAction(argparse.Action):
|
class LoRAPathAction(argparse.Action):
|
||||||
def __call__(self, parser, namespace, values, option_string=None):
|
def __call__(self, parser, namespace, values, option_string=None):
|
||||||
setattr(namespace, self.dest, {})
|
lora_paths = []
|
||||||
for lora_path in values:
|
if values:
|
||||||
if "=" in lora_path:
|
assert isinstance(values, list), "Expected a list of LoRA paths."
|
||||||
name, path = lora_path.split("=", 1)
|
for lora_path in values:
|
||||||
getattr(namespace, self.dest)[name] = path
|
lora_path = lora_path.strip()
|
||||||
else:
|
if lora_path.startswith("{") and lora_path.endswith("}"):
|
||||||
getattr(namespace, self.dest)[lora_path] = lora_path
|
obj = json.loads(lora_path)
|
||||||
|
assert "lora_path" in obj and "lora_name" in obj, (
|
||||||
|
f"{repr(lora_path)} looks like a JSON str, "
|
||||||
|
"but it does not contain 'lora_name' and 'lora_path' keys."
|
||||||
|
)
|
||||||
|
lora_paths.append(obj)
|
||||||
|
else:
|
||||||
|
lora_paths.append(lora_path)
|
||||||
|
|
||||||
|
setattr(namespace, self.dest, lora_paths)
|
||||||
|
|
||||||
|
|
||||||
class DeprecatedAction(argparse.Action):
|
class DeprecatedAction(argparse.Action):
|
||||||
|
|||||||
@@ -491,7 +491,7 @@ class SRTRunner:
|
|||||||
tp_size: int = 1,
|
tp_size: int = 1,
|
||||||
model_impl: str = "auto",
|
model_impl: str = "auto",
|
||||||
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
|
||||||
lora_paths: List[str] = None,
|
lora_paths: Optional[Union[List[str], List[dict[str, str]]]] = None,
|
||||||
max_loras_per_batch: int = 4,
|
max_loras_per_batch: int = 4,
|
||||||
attention_backend: Optional[str] = None,
|
attention_backend: Optional[str] = None,
|
||||||
prefill_attention_backend: Optional[str] = None,
|
prefill_attention_backend: Optional[str] = None,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
|
import json
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
import unittest
|
import unittest
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -89,8 +90,35 @@ BASIC_TESTS = [
|
|||||||
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
],
|
],
|
||||||
initial_adapters=["philschmid/code-llama-3-1-8b-text-to-sql-lora"],
|
initial_adapters=[
|
||||||
|
# Testing 3 supported lora-path formats.
|
||||||
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
{
|
||||||
|
"lora_name": "pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
"lora_path": "pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
"pinned": False,
|
||||||
|
},
|
||||||
|
],
|
||||||
op_sequence=[
|
op_sequence=[
|
||||||
|
Operation(
|
||||||
|
type=OperationType.FORWARD,
|
||||||
|
data=create_batch_data(
|
||||||
|
[
|
||||||
|
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
||||||
|
"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
"pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
]
|
||||||
|
),
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.UNLOAD,
|
||||||
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.UNLOAD,
|
||||||
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
data=create_batch_data("philschmid/code-llama-3-1-8b-text-to-sql-lora"),
|
||||||
@@ -147,6 +175,10 @@ BASIC_TESTS = [
|
|||||||
type=OperationType.UNLOAD,
|
type=OperationType.UNLOAD,
|
||||||
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
data="Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
|
||||||
),
|
),
|
||||||
|
Operation(
|
||||||
|
type=OperationType.UNLOAD,
|
||||||
|
data="pbevan11/llama-3.1-8b-ocr-correction",
|
||||||
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data(
|
data=create_batch_data(
|
||||||
@@ -157,18 +189,12 @@ BASIC_TESTS = [
|
|||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
data=create_batch_data("pbevan11/llama-3.1-8b-ocr-correction"),
|
||||||
),
|
expected_error="not loaded",
|
||||||
Operation(
|
|
||||||
type=OperationType.LOAD,
|
|
||||||
data="philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
||||||
),
|
),
|
||||||
Operation(
|
Operation(
|
||||||
type=OperationType.FORWARD,
|
type=OperationType.FORWARD,
|
||||||
data=create_batch_data(
|
data=create_batch_data(
|
||||||
[
|
None,
|
||||||
"philschmid/code-llama-3-1-8b-text-to-sql-lora",
|
|
||||||
"pbevan11/llama-3.1-8b-ocr-correction",
|
|
||||||
]
|
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
@@ -705,7 +731,7 @@ class LoRAUpdateTestSessionBase:
|
|||||||
*,
|
*,
|
||||||
testcase: Optional[TestCase],
|
testcase: Optional[TestCase],
|
||||||
model_path: str,
|
model_path: str,
|
||||||
lora_paths: list[str],
|
lora_paths: List[Union[str, dict]],
|
||||||
max_loras_per_batch: int,
|
max_loras_per_batch: int,
|
||||||
max_loaded_loras: Optional[int] = None,
|
max_loaded_loras: Optional[int] = None,
|
||||||
max_lora_rank: Optional[int],
|
max_lora_rank: Optional[int],
|
||||||
@@ -727,7 +753,17 @@ class LoRAUpdateTestSessionBase:
|
|||||||
self.cuda_graph_max_bs = cuda_graph_max_bs
|
self.cuda_graph_max_bs = cuda_graph_max_bs
|
||||||
self.enable_lora = enable_lora
|
self.enable_lora = enable_lora
|
||||||
|
|
||||||
self.expected_adapters = set(lora_paths or [])
|
self.expected_adapters = set()
|
||||||
|
if self.lora_paths:
|
||||||
|
for adapter in self.lora_paths:
|
||||||
|
if isinstance(adapter, dict):
|
||||||
|
lora_name = adapter["lora_name"]
|
||||||
|
elif "=" in adapter:
|
||||||
|
lora_name = adapter.split("=")[0]
|
||||||
|
else:
|
||||||
|
lora_name = adapter
|
||||||
|
self.expected_adapters.add(lora_name)
|
||||||
|
|
||||||
self.handle = None # Will be set in __enter__
|
self.handle = None # Will be set in __enter__
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@@ -926,7 +962,11 @@ class LoRAUpdateServerTestSession(LoRAUpdateTestSessionBase):
|
|||||||
if self.enable_lora:
|
if self.enable_lora:
|
||||||
other_args.append("--enable-lora")
|
other_args.append("--enable-lora")
|
||||||
if self.lora_paths:
|
if self.lora_paths:
|
||||||
other_args.extend(["--lora-paths"] + self.lora_paths)
|
other_args.append("--lora-paths")
|
||||||
|
for lora_path in self.lora_paths:
|
||||||
|
if isinstance(lora_path, dict):
|
||||||
|
lora_path = json.dumps(lora_path)
|
||||||
|
other_args.append(lora_path)
|
||||||
if self.disable_cuda_graph:
|
if self.disable_cuda_graph:
|
||||||
other_args.append("--disable-cuda-graph")
|
other_args.append("--disable-cuda-graph")
|
||||||
if self.max_lora_rank is not None:
|
if self.max_lora_rank is not None:
|
||||||
@@ -1093,7 +1133,7 @@ class TestLoRADynamicUpdate(CustomTestCase):
|
|||||||
self,
|
self,
|
||||||
mode: LoRAUpdateTestSessionMode,
|
mode: LoRAUpdateTestSessionMode,
|
||||||
base: str,
|
base: str,
|
||||||
initial_adapters: List[str],
|
initial_adapters: List[Union[str, dict]],
|
||||||
op_sequence: List[Operation],
|
op_sequence: List[Operation],
|
||||||
max_loras_per_batch: int,
|
max_loras_per_batch: int,
|
||||||
max_loaded_loras: Optional[int] = None,
|
max_loaded_loras: Optional[int] = None,
|
||||||
|
|||||||
Reference in New Issue
Block a user