Fix LoRA buffer contamination during adapter eviction (#8103)
@@ -188,10 +188,18 @@ class LoRAMemoryPool:
         lora_adapter: LoRAAdapter,
         lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]],
     ):
-        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
-            assert (
-                buffer_view.shape == weight.shape
-            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+        def load_lora_weight_tensor(
+            buffer_view: torch.Tensor, weight: Optional[torch.Tensor]
+        ):
+            if weight is None:
+                # If the particular weight is not present in the adapter, we initialize the buffer to zero
+                # to avoid contamination from the residual weight of the evicted adapters.
+                buffer_view.zero_()
+            else:
+                assert (
+                    buffer_view.shape == weight.shape
+                ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
+                buffer_view.copy_(weight)
 
         if uid is None:
             for i in range(self.num_layer):
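A minimal standalone sketch of the zero-or-copy pattern that `load_lora_weight_tensor` introduces above, assuming a pooled buffer slot that is reused across adapters (the helper name `load_into_slot` and the dict-of-views setup are illustrative only, not sglang APIs):

    import torch
    from typing import Dict, Optional

    def load_into_slot(
        buffer_slot: Dict[str, torch.Tensor],
        adapter_weights: Dict[str, Optional[torch.Tensor]],
    ) -> None:
        # A slot in the pool may still hold weights from a previously evicted
        # adapter, so weights missing from the incoming adapter are zeroed
        # instead of being left untouched.
        for name, buffer_view in buffer_slot.items():
            weight = adapter_weights.get(name)
            if weight is None:
                buffer_view.zero_()
            else:
                assert buffer_view.shape == weight.shape
                buffer_view.copy_(weight)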
@@ -203,8 +211,12 @@ class LoRAMemoryPool:
         lora_rank = lora_adapter.config.hf_config["r"]
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
-            temp_A_buffer: Dict[str, torch.Tensor] = {}
-            temp_B_buffer: Dict[str, torch.Tensor] = {}
+            temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.A_buffer
+            }
+            temp_B_buffer: Dict[str, Optional[torch.Tensor]] = {
+                weight_name: None for weight_name in self.B_buffer
+            }
             for name, weights in layer_weights.items():
                 if "lora_A" in name:
                     lora_weight_name = get_weight_name(
@@ -220,6 +232,14 @@ class LoRAMemoryPool:
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+
+                    if temp_A_buffer[weight_name] is None:
+                        # Skip weight slicing if the weight is not present in the adapter
+                        continue
+
                     if "qkv_proj" in module_name:
                         temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
                             temp_A_buffer["qkv_proj"], self.tp_rank
@@ -231,9 +251,10 @@ class LoRAMemoryPool:
                             )
                         )
                     else:
-                        weight_name = get_weight_name(
-                            module_name, self.lora_weight_names, LoRAType.LORA_A
-                        )
+                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
+                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
+                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
+                        # FlashInfer LoRA backend.
                         temp_A_buffer[weight_name] = module.slice_lora_a_weights(
                             temp_A_buffer[weight_name], self.tp_rank
                         )
@@ -246,8 +267,7 @@ class LoRAMemoryPool:
                 buffer_view = self.A_buffer[name][layer_id][buffer_id][
                     : lora_rank * c, :
                 ]
-                check_lora_weight_shape(buffer_view, weights)
-                buffer_view.copy_(weights)
+                load_lora_weight_tensor(buffer_view, weights)
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
@@ -256,14 +276,15 @@ class LoRAMemoryPool:
                         buffer_view = self.B_buffer[name][layer_id][stacked_id][
                             buffer_id
                         ][:, :lora_rank]
-                        check_lora_weight_shape(buffer_view, weights[stacked_id])
-                        buffer_view.copy_(weights[stacked_id])
+                        weight_slice = (
+                            weights[stacked_id] if weights is not None else None
+                        )
+                        load_lora_weight_tensor(buffer_view, weight_slice)
                 else:
                     buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
                         :, :lora_rank
                     ]
-                    check_lora_weight_shape(buffer_view, weights)
-                    buffer_view.copy_(weights)
+                    load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
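To illustrate the eviction scenario the commit title refers to, here is a small hypothetical demo built on the `load_into_slot` sketch above (again illustrative only, not sglang code): adapter X provides a weight, is evicted, and adapter Y without that weight is loaded into the same slot; without the zero-fill, adapter X's values would survive in the buffer.

    slot = {"gate_proj": torch.empty(8, 16)}

    adapter_x = {"gate_proj": torch.randn(8, 16)}  # provides gate_proj
    adapter_y = {}                                 # does not provide gate_proj

    load_into_slot(slot, adapter_x)  # slot now holds adapter X's weight
    load_into_slot(slot, adapter_y)  # with the fix, the stale values are zeroed
    assert torch.all(slot["gate_proj"] == 0)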