Fix LoRA buffer contamination during adapter eviction (#8103)

This commit is contained in:
Lifu Huang
2025-07-19 13:14:08 -07:00
committed by GitHub
parent bb0e8a32b5
commit 3de617a75b
3 changed files with 148 additions and 15 deletions

View File

@@ -188,10 +188,18 @@ class LoRAMemoryPool:
lora_adapter: LoRAAdapter,
lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
):
def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
assert (
buffer_view.shape == weight.shape
), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
def load_lora_weight_tensor(
buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
):
if weight is None:
# If the particular weight is not present in the adapter, we initialize the buffer to zero
# to avoid contamination from the residual weight of the evicted adapters.
buffer_view.zero_()
else:
assert (
buffer_view.shape == weight.shape
), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
buffer_view.copy_(weight)
if uid is None:
for i in range(self.num_layer):
@@ -203,8 +211,12 @@ class LoRAMemoryPool:
lora_rank = lora_adapter.config.hf_config["r"]
for layer_id in range(self.num_layer):
layer_weights = lora_adapter.layers[layer_id].weights
temp_A_buffer: Dict[str, torch.Tensor] = {}
temp_B_buffer: Dict[str, torch.Tensor] = {}
temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
weight_name: None for weight_name in self.A_buffer
}
temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
weight_name: None for weight_name in self.B_buffer
}
for name, weights in layer_weights.items():
if "lora_A" in name:
lora_weight_name = get_weight_name(
@@ -220,6 +232,14 @@ class LoRAMemoryPool:
if self.tp_size > 1:
cur_layer_modules = lora_modules[layer_id]
for module_name, module in cur_layer_modules.items():
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
if temp_A_buffer[weight_name] is None:
# Skip weight slicing if the weight is not present in the adapter
continue
if "qkv_proj" in module_name:
temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
temp_A_buffer["qkv_proj"], self.tp_rank
@@ -231,9 +251,10 @@ class LoRAMemoryPool:
)
)
else:
weight_name = get_weight_name(
module_name, self.lora_weight_names, LoRAType.LORA_A
)
# TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
# Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
# B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
# FlashInfer LoRA backend.
temp_A_buffer[weight_name] = module.slice_lora_a_weights(
temp_A_buffer[weight_name], self.tp_rank
)
@@ -246,8 +267,7 @@ class LoRAMemoryPool:
buffer_view = self.A_buffer[name][layer_id][buffer_id][
: lora_rank * c, :
]
check_lora_weight_shape(buffer_view, weights)
buffer_view.copy_(weights)
load_lora_weight_tensor(buffer_view, weights)
for name, weights in temp_B_buffer.items():
c = get_stacked_multiply(name)
@@ -256,14 +276,15 @@ class LoRAMemoryPool:
buffer_view = self.B_buffer[name][layer_id][stacked_id][
buffer_id
][:, :lora_rank]
check_lora_weight_shape(buffer_view, weights[stacked_id])
buffer_view.copy_(weights[stacked_id])
weight_slice = (
weights[stacked_id] if weights is not None else None
)
load_lora_weight_tensor(buffer_view, weight_slice)
else:
buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
:, :lora_rank
]
check_lora_weight_shape(buffer_view, weights)
buffer_view.copy_(weights)
load_lora_weight_tensor(buffer_view, weights)
def get_tensor(
self, weight_name: str, layer_id: int, lora_type: LoRAType