From 3de617a75bc9682763ba4f5f402a679e0df5dd22 Mon Sep 17 00:00:00 2001
From: Lifu Huang
Date: Sat, 19 Jul 2025 13:14:08 -0700
Subject: [PATCH] Fix LoRA buffer contamination during adapter eviction (#8103)

The LoRA memory pool reuses buffer slots across adapters. When a newly
loaded adapter does not define a particular weight (e.g. it targets only
q_proj/v_proj), the corresponding slot previously kept the residual
weights of the evicted adapter, contaminating the new adapter's output.
Zero-initialize such slots on load, and add a regression test that
alternates between adapters with different target modules to force
eviction.

---
 python/sglang/srt/lora/mem_pool.py          |  51 +++++++---
 test/srt/models/lora/test_lora_eviction.py  | 111 +++++++++++++++++++++
 test/srt/run_suite.py                       |   1 +
 3 files changed, 148 insertions(+), 15 deletions(-)
 create mode 100644 test/srt/models/lora/test_lora_eviction.py

diff --git a/python/sglang/srt/lora/mem_pool.py b/python/sglang/srt/lora/mem_pool.py
index 713b03650..1b36cac5e 100644
--- a/python/sglang/srt/lora/mem_pool.py
+++ b/python/sglang/srt/lora/mem_pool.py
@@ -188,10 +188,18 @@ class LoRAMemoryPool:
         lora_adapter: LoRAAdapter,
         lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
-            assert (
-                buffer_view.shape == weight.shape
-            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+        def load_lora_weight_tensor(
+            buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
+        ):
+            if weight is None:
+                # If the weight is not present in the adapter, zero-initialize the buffer
+                # to avoid contamination from the residual weights of evicted adapters.
+                buffer_view.zero_()
+            else:
+                assert (
+                    buffer_view.shape == weight.shape
+                ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+                buffer_view.copy_(weight)

         if uid is None:
             for i in range(self.num_layer):
@@ -203,8 +211,12 @@ class LoRAMemoryPool:
         lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
-            temp_A_buffer: Dict[str, torch.Tensor] = {}
-            temp_B_buffer: Dict[str, torch.Tensor] = {}
+            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.A_buffer
+            }
+            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.B_buffer
+            }
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
@@ -220,6 +232,14 @@ class LoRAMemoryPool:
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+
+                    if temp_A_buffer[weight_name] is None:
+                        # Skip weight slicing if the weight is not present in the adapter.
+                        continue
+
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -231,9 +251,10 @@ class LoRAMemoryPool:
                             )
                         )
                     else:
-                        weight_name = get_weight_name(
-                            module_name, self.lora_weight_names, LoRAType.LORA_A
-                        )
+                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
+                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
+                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
+                        # FlashInfer LoRA backend.
                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                            temp_A_buffer[weight_name], self.tp_rank
                        )
@@ -246,8 +267,7 @@ class LoRAMemoryPool:
                 buffer_view = self.A_buffer[name][layer_id][buffer_id][
                     : lora_rank * c, :
                 ]
-                check_lora_weight_shape(buffer_view, weights)
-                buffer_view.copy_(weights)
+                load_lora_weight_tensor(buffer_view, weights)

             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
@@ -256,14 +276,15 @@ class LoRAMemoryPool:
                         buffer_view = self.B_buffer[name][layer_id][stacked_id][
                             buffer_id
                         ][:, :lora_rank]
-                        check_lora_weight_shape(buffer_view, weights[stacked_id])
-                        buffer_view.copy_(weights[stacked_id])
+                        weight_slice = (
+                            weights[stacked_id] if weights is not None else None
+                        )
+                        load_lora_weight_tensor(buffer_view, weight_slice)
                 else:
                     buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                         :, :lora_rank
                     ]
-                    check_lora_weight_shape(buffer_view, weights)
-                    buffer_view.copy_(weights)
+                    load_lora_weight_tensor(buffer_view, weights)

     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
diff --git a/test/srt/models/lora/test_lora_eviction.py b/test/srt/models/lora/test_lora_eviction.py
new file mode 100644
index 000000000..e74af0a0e
--- /dev/null
+++ b/test/srt/models/lora/test_lora_eviction.py
@@ -0,0 +1,111 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import multiprocessing as mp
+import unittest
+from typing import Dict, List, Tuple
+
+import torch
+
+from sglang.test.runners import SRTRunner
+from sglang.test.test_utils import CustomTestCase
+
+PROMPTS = [
+    "AI is a field of computer science focused on",
+    """
+    ### Instruction:
+    Compose a SQL query that uses the following table: users, and returns the user_id and name of all users whose name does not have a duplicate in the table.
+    ### Response:
+    SELECT user_id, name FROM users WHERE name LIKE 'A%';
+    """,
+]
+
+ADAPTERS = [
+    "faridlazuarda/valadapt-llama-3.1-8B-it-chinese",  # target_modules = q, v
+    "philschmid/code-llama-3-1-8b-text-to-sql-lora",  # target_modules = q, k, v, o, gate, up, down
+]
+
+BASE_MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+
+
+class TestLoRAEviction(CustomTestCase):
+    def test_lora_eviction_with_different_target_modules(self):
+        """
+        Test LoRA eviction across adapters with different target modules.
+
+        This test runs inference against two LoRA adapters in different orders to force eviction, and ensures
+        that the outputs for the same (adapter, prompt) pair are consistent across runs.
+ """ + output_history = {} + self._run_test(ADAPTERS, output_history, reverse=False) + self._run_test(ADAPTERS, output_history, reverse=True) + + def _run_test( + self, + lora_paths: List[str], + output_history: Dict[Tuple[str, str], str], + reverse: bool, + repeat: int = 2, + ): + max_new_tokens = 256 + backend = "triton" + torch_dtype = torch.float16 + base_path = BASE_MODEL + assert len(lora_paths) >= 2 + + # Initialize runners + with SRTRunner( + base_path, + torch_dtype=torch_dtype, + model_type="generation", + lora_paths=lora_paths, + max_loras_per_batch=1, + lora_backend=backend, + disable_radix_cache=True, + ) as srt_runner: + adapter_sequence = lora_paths if not reverse else lora_paths[::-1] + + for i in range(repeat): + for j, adapter in enumerate(adapter_sequence): + print( + f"\n========== Testing LoRA eviction with adapter '{adapter}' (#{j+1}/{len(adapter_sequence)}), reversed: {reverse}, repeat: {i+1}/{repeat} ---" + ) + for prompt in PROMPTS: + print("\nprompt:\n", prompt) + srt_outputs = srt_runner.forward( + [prompt], + max_new_tokens=max_new_tokens, + lora_paths=[adapter], + ) + output = srt_outputs.output_strs[0].strip() + print("\noutput:\n", output) + + prev_output = output_history.get((adapter, prompt)) + if prev_output is not None: + self.assertEqual( + prev_output, + output, + f"Output mismatch for adapter {adapter} and prompt '{prompt}' on repeat {j + 1}, previous: '{prev_output}', current: '{output}'.", + ) + else: + output_history[(adapter, prompt)] = output + + +if __name__ == "__main__": + try: + mp.set_start_method("spawn") + except RuntimeError: + pass + + unittest.main(warnings="ignore") diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e67362cf8..f59aed623 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -14,6 +14,7 @@ class TestFile: suites = { "per-commit": [ TestFile("models/lora/test_lora.py", 200), + TestFile("models/lora/test_lora_eviction.py", 120), TestFile("models/lora/test_lora_backend.py", 99), TestFile("models/lora/test_multi_lora_backend.py", 60), TestFile("models/lora/test_lora_cuda_graph.py", 250),