Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (#6861)
@@ -165,14 +165,19 @@ class LoRAAdapter(nn.Module):
                     self.base_hf_config.hidden_size
                     // self.base_hf_config.num_attention_heads
                 )
-                weights[q_name], weights[kv_name] = torch.split(
+                weights[q_name], k_proj_weight, v_proj_weight = torch.split(
                     weights[qkv_name],
                     [
                         head_size * self.base_hf_config.num_attention_heads,
-                        head_size * self.base_hf_config.num_key_value_heads * 2,
+                        head_size * self.base_hf_config.num_key_value_heads,
+                        head_size * self.base_hf_config.num_key_value_heads,
                     ],
                     dim=0,
                 )
+                weights[kv_name] = torch.stack(
+                    [k_proj_weight, v_proj_weight],
+                    dim=0,
+                )
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
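The hunk above changes how a fused qkv_proj LoRA weight is unpacked: instead of leaving k and v fused as one 2x-sized chunk under kv_name, the k and v slices are split apart and stacked along a new leading dimension. A minimal standalone sketch of the shape difference, with made-up dimensions (4 attention heads, 2 KV heads, head_size 8, rank 16); the variable names mirror the config fields in the diff, everything else is illustrative:

import torch

num_attention_heads, num_key_value_heads, head_size, rank = 4, 2, 8, 16
qkv_lora_B = torch.randn(
    head_size * (num_attention_heads + 2 * num_key_value_heads), rank
)

# Old behavior: k and v stay fused in a single (2 * kv_dim, rank) chunk.
q_old, kv_fused = torch.split(
    qkv_lora_B,
    [head_size * num_attention_heads, head_size * num_key_value_heads * 2],
    dim=0,
)
print(kv_fused.shape)  # torch.Size([32, 16])

# Fixed behavior: split k and v apart, then stack them along a new dim 0,
# giving the (2, kv_dim, rank) layout the kv_proj buffer expects.
q_new, k, v = torch.split(
    qkv_lora_B,
    [
        head_size * num_attention_heads,
        head_size * num_key_value_heads,
        head_size * num_key_value_heads,
    ],
    dim=0,
)
kv_stacked = torch.stack([k, v], dim=0)
print(kv_stacked.shape)  # torch.Size([2, 16, 16])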
@@ -157,6 +157,10 @@ class LoRAMemoryPool:
     def load_lora_weight_to_buffer(
         self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
     ):
+        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
+            assert (
+                buffer_view.shape == weight.shape
+            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
 
         if uid is None:
             for i in range(self.num_layer):
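This hunk only defines the check_lora_weight_shape helper; the next hunk wires it in before every copy. A quick sketch of the error it surfaces, with shapes invented to match the kv_proj scenario above (a still-fused kv weight against the new stacked buffer layout):

import torch

def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
    assert (
        buffer_view.shape == weight.shape
    ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."

buffer_view = torch.zeros(2, 16, 16)  # a stacked kv_proj buffer slice
weight = torch.randn(32, 16)          # a fused kv weight does not fit it
check_lora_weight_shape(buffer_view, weight)
# AssertionError: LoRA buffer shape torch.Size([2, 16, 16]) does not match
# weight shape torch.Size([32, 16]).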
@@ -208,21 +212,27 @@ class LoRAMemoryPool:
 
         for name, weights in temp_A_buffer.items():
             c = get_stacked_multiply(name)
-            self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
-                weights
-            )
+            buffer_view = self.A_buffer[name][layer_id][buffer_id][
+                : lora_rank * c, :
+            ]
+            check_lora_weight_shape(buffer_view, weights)
+            buffer_view.copy_(weights)
 
         for name, weights in temp_B_buffer.items():
             c = get_stacked_multiply(name)
             if c > 1:
                 for stacked_id in range(c):
-                    self.B_buffer[name][layer_id][stacked_id][buffer_id][
-                        :, :lora_rank
-                    ].copy_(weights[stacked_id])
+                    buffer_view = self.B_buffer[name][layer_id][stacked_id][
+                        buffer_id
+                    ][:, :lora_rank]
+                    check_lora_weight_shape(buffer_view, weights[stacked_id])
+                    buffer_view.copy_(weights[stacked_id])
             else:
-                self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
-                    weights
-                )
+                buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
+                    :, :lora_rank
+                ]
+                check_lora_weight_shape(buffer_view, weights)
+                buffer_view.copy_(weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
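This hunk rewrites each copy site into the same three-step pattern: take a view of the rank-padded pool buffer trimmed to the adapter's actual rank, validate it against the incoming weight, then copy in place. A self-contained sketch of that pattern on a dummy A buffer; the dimensions and the stacking factor standing in for get_stacked_multiply are invented for illustration:

import torch

max_rank, input_dim = 64, 128
lora_rank = 16
c = 3  # stacking factor, e.g. for a qkv_proj-style fused module

# Pool buffer sized for the largest supported rank; a smaller adapter
# only occupies the first `lora_rank * c` rows.
a_buffer = torch.zeros(max_rank * c, input_dim)
weights = torch.randn(lora_rank * c, input_dim)

buffer_view = a_buffer[: lora_rank * c, :]
assert (
    buffer_view.shape == weights.shape
), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weights.shape}."
buffer_view.copy_(weights)  # in-place write; the padding rows stay zero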