Introduce Stable LoRA ID System for Overlapped Updates and Prefix Caching (#8261)

This commit is contained in:
Lifu Huang
2025-07-23 00:32:16 -07:00
committed by GitHub
parent e885bfdc6a
commit 8abd3e77fe
11 changed files with 400 additions and 261 deletions

View File

@@ -12,6 +12,7 @@
# limitations under the License.
# ==============================================================================
import contextlib
import multiprocessing as mp
import unittest
from typing import Dict, List, Tuple
@@ -39,6 +40,16 @@ ADAPTERS = [
BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
@contextlib.contextmanager
def dynamically_loaded_adapter(runner, lora_path: str, lora_name: str):
"""A context manager to load and automatically unload a LoRA adapter."""
try:
runner.load_lora_adapter(lora_name=lora_name, lora_path=lora_path)
yield
finally:
runner.unload_lora_adapter(lora_name=lora_name)
class TestLoRAEviction(CustomTestCase):
def test_lora_eviction_with_different_target_modules(self):
"""
@@ -51,55 +62,80 @@ class TestLoRAEviction(CustomTestCase):
self._run_test(ADAPTERS, output_history, reverse=False)
self._run_test(ADAPTERS, output_history, reverse=True)
def test_lora_eviction_with_reused_lora_name(self):
"""
Test LoRA eviction with reused LoRA names.
This test runs inference against two LoRA adapters with the same name to ensure that the eviction behavior
works correctly when reusing LoRA names.
"""
output_history = {}
self._run_test(ADAPTERS, output_history, reuse_lora_name=True, repeat=1)
self._run_test(ADAPTERS, output_history, reuse_lora_name=False, repeat=1)
def _run_test(
self,
lora_paths: List[str],
output_history: Dict[Tuple[str, str], str],
reverse: bool,
reverse: bool = False,
repeat: int = 2,
reuse_lora_name: bool = False,
):
REUSED_LORA_NAME = "lora"
max_new_tokens = 256
backend = "triton"
torch_dtype = torch.float16
base_path = BASE_MODEL
assert len(lora_paths) >= 2
initial_lora_paths = lora_paths if not reuse_lora_name else None
# Initialize runners
with SRTRunner(
base_path,
torch_dtype=torch_dtype,
model_type="generation",
lora_paths=lora_paths,
lora_paths=initial_lora_paths,
max_loras_per_batch=1,
lora_backend=backend,
disable_radix_cache=True,
enable_lora=True,
max_lora_rank=256,
lora_target_modules=["all"],
) as srt_runner:
adapter_sequence = lora_paths if not reverse else lora_paths[::-1]
for i in range(repeat):
for j, adapter in enumerate(adapter_sequence):
for j, lora_path in enumerate(adapter_sequence):
print(
f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---"
f"\n========== Testing LoRA eviction with adapter '{lora_path}' (#{j + 1}/{len(adapter_sequence)}), reuse_lora_name: {reuse_lora_name}, reversed: {reverse}, repeat: {i + 1}/{repeat} ---"
)
for prompt in PROMPTS:
print("\nprompt:\n", prompt)
srt_outputs = srt_runner.forward(
[prompt],
max_new_tokens=max_new_tokens,
lora_paths=[adapter],
)
output = srt_outputs.output_strs[0].strip()
print("\noutput:\n", output)
prev_output = output_history.get((adapter, prompt))
if prev_output is not None:
self.assertEqual(
prev_output,
output,
f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
lora_name = REUSED_LORA_NAME if reuse_lora_name else lora_path
context = (
dynamically_loaded_adapter(srt_runner, lora_path, lora_name)
if reuse_lora_name
else contextlib.nullcontext()
)
with context:
for prompt in PROMPTS:
print("\nprompt:\n", prompt)
srt_outputs = srt_runner.forward(
[prompt],
max_new_tokens=max_new_tokens,
lora_paths=[lora_name],
)
else:
output_history[(adapter, prompt)] = output
output = srt_outputs.output_strs[0].strip()
print("\noutput:\n", output)
prev_output = output_history.get((lora_path, prompt))
if prev_output is not None:
self.assertEqual(
prev_output,
output,
f"Output mismatch for adapter {lora_path} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.",
)
else:
output_history[(lora_path, prompt)] = output
if __name__ == "__main__":

View File

@@ -14,7 +14,7 @@ class TestFile:
suites = {
"per-commit": [
TestFile("models/lora/test_lora.py", 200),
TestFile("models/lora/test_lora_eviction.py", 120),
TestFile("models/lora/test_lora_eviction.py", 200),
TestFile("models/lora/test_lora_backend.py", 99),
TestFile("models/lora/test_multi_lora_backend.py", 60),
TestFile("models/lora/test_lora_cuda_graph.py", 250),