Support LoRA in TestOpenAIVisionServer and fix fused kv_proj loading bug. (#6861)

Lifu Huang
2025-06-04 22:08:30 -07:00
committed by GitHub
parent 499f5e620c
commit 4474eaf552
4 changed files with 66 additions and 16 deletions


@@ -165,14 +165,19 @@ class LoRAAdapter(nn.Module):
                 self.base_hf_config.hidden_size
                 // self.base_hf_config.num_attention_heads
             )
-            weights[q_name], weights[kv_name] = torch.split(
+            weights[q_name], k_proj_weight, v_proj_weight = torch.split(
                 weights[qkv_name],
                 [
                     head_size * self.base_hf_config.num_attention_heads,
-                    head_size * self.base_hf_config.num_key_value_heads * 2,
+                    head_size * self.base_hf_config.num_key_value_heads,
+                    head_size * self.base_hf_config.num_key_value_heads,
                 ],
                 dim=0,
             )
+            weights[kv_name] = torch.stack(
+                [k_proj_weight, v_proj_weight],
+                dim=0,
+            )
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
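
For context, a minimal standalone sketch of the split-and-stack change above, using made-up toy dimensions (4 attention heads, 2 KV heads, head size 8, LoRA rank 4; none of these values come from the repo). It shows how the fused qkv_proj weight is now split into q, k, and v along its output dimension, with k and v stacked so the kv_proj weight gains a leading stacking dimension:

import torch

# Toy dimensions (hypothetical, for illustration only).
num_heads, num_kv_heads, head_size, rank = 4, 2, 8, 4

# A fused qkv_proj LoRA weight whose output dim covers q, k, and v.
qkv_weight = torch.randn((num_heads + 2 * num_kv_heads) * head_size, rank)

# Split q, k, and v separately (the old code split off a single flat kv chunk
# of size 2 * num_kv_heads * head_size instead).
q_weight, k_weight, v_weight = torch.split(
    qkv_weight,
    [
        num_heads * head_size,
        num_kv_heads * head_size,
        num_kv_heads * head_size,
    ],
    dim=0,
)

# Stack k and v so the kv_proj weight carries an explicit stacking dimension,
# matching the per-stacked-projection layout expected by the memory pool.
kv_weight = torch.stack([k_weight, v_weight], dim=0)

print(q_weight.shape)   # torch.Size([32, 4])
print(kv_weight.shape)  # torch.Size([2, 16, 4])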


@@ -157,6 +157,10 @@ class LoRAMemoryPool:
     def load_lora_weight_to_buffer(
         self, uid: str, buffer_id: int, lora_adapter: LoRAAdapter = None
     ):
+        def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
+            assert (
+                buffer_view.shape == weight.shape
+            ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."
 
         if uid is None:
             for i in range(self.num_layer):
@@ -208,21 +212,27 @@ class LoRAMemoryPool:
 
             for name, weights in temp_A_buffer.items():
                 c = get_stacked_multiply(name)
-                self.A_buffer[name][layer_id][buffer_id][: lora_rank * c, :].copy_(
-                    weights
-                )
+                buffer_view = self.A_buffer[name][layer_id][buffer_id][
+                    : lora_rank * c, :
+                ]
+                check_lora_weight_shape(buffer_view, weights)
+                buffer_view.copy_(weights)
 
             for name, weights in temp_B_buffer.items():
                 c = get_stacked_multiply(name)
                 if c > 1:
                     for stacked_id in range(c):
-                        self.B_buffer[name][layer_id][stacked_id][buffer_id][
-                            :, :lora_rank
-                        ].copy_(weights[stacked_id])
+                        buffer_view = self.B_buffer[name][layer_id][stacked_id][
+                            buffer_id
+                        ][:, :lora_rank]
+                        check_lora_weight_shape(buffer_view, weights[stacked_id])
+                        buffer_view.copy_(weights[stacked_id])
                 else:
-                    self.B_buffer[name][layer_id][0][buffer_id][:, :lora_rank].copy_(
-                        weights
-                    )
+                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
+                        :, :lora_rank
+                    ]
+                    check_lora_weight_shape(buffer_view, weights)
+                    buffer_view.copy_(weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
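
For reference, a minimal sketch of the view-then-check-then-copy pattern introduced above, with assumed buffer shapes that are not taken from the repo. The point of the added check is that a weight with the wrong shape (for example, the old flat kv layout indexed per stacked projection) now fails loudly instead of being silently broadcast into the buffer:

import torch

# Assumed toy shapes (hypothetical, for illustration only).
max_lora_rank, kv_dim, lora_rank = 16, 64, 4

def check_lora_weight_shape(buffer_view: torch.Tensor, weight: torch.Tensor):
    assert (
        buffer_view.shape == weight.shape
    ), f"LoRA buffer shape {buffer_view.shape} does not match weight shape {weight.shape}."

# Preallocated stacked kv_proj B buffer: one slot each for k and v.
B_buffer = torch.zeros(2, kv_dim, max_lora_rank)

# kv weight in the stacked layout produced by the torch.stack fix: (2, kv_dim, lora_rank).
kv_weight = torch.randn(2, kv_dim, lora_rank)

for stacked_id in range(2):
    # Take a view of the buffer region this adapter occupies, verify its shape
    # against the incoming weight, then copy in place (copy_ writes through the view).
    buffer_view = B_buffer[stacked_id][:, :lora_rank]
    check_lora_weight_shape(buffer_view, kv_weight[stacked_id])
    buffer_view.copy_(kv_weight[stacked_id])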