Fix incorrect LoRA weight loading for fused gate_up_proj (#6734)
@@ -680,8 +680,8 @@ register_conv_template(
 register_conv_template(
     Conversation(
         name="phi-4-mm",
-        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
-        system_template="<|system|>{system_message}<|end|>",
+        system_message="",
+        system_template="{system_message}",
         roles=("<|user|>", "<|assistant|>"),
         sep_style=SeparatorStyle.NO_COLON_SINGLE,
         sep="<|end|>",
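
Note: with system_message="" and a bare "{system_message}" system_template, the phi-4-mm template no longer emits a system block. A minimal sketch of the resulting prompt assembly (render_prompt is a hypothetical helper, not SGLang's API; it assumes NO_COLON_SINGLE concatenates role, text, and sep directly):

def render_prompt(messages, system_message=""):
    # system_template="{system_message}" renders to "" when no system prompt is set
    prompt = system_message
    for role, text in messages:
        # NO_COLON_SINGLE: role + text + sep, with sep="<|end|>"
        prompt += f"{role}{text}<|end|>"
    # Open the assistant turn to prompt generation.
    prompt += "<|assistant|>"
    return prompt

print(render_prompt([("<|user|>", "Describe this image.")]))
# -> <|user|>Describe this image.<|end|><|assistant|>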
@@ -209,4 +209,12 @@ class LoRAAdapter(nn.Module):
                 gate_up_name = weight_name
                 if "lora_A" in weight_name:
                     weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-                # else: "lora_B" is already stacked, no operations is needed.
+                else:
+                    output_dim = weights[gate_up_name].shape[0] // 2
+                    weights[gate_up_name] = torch.stack(
+                        [
+                            weights[gate_up_name][:output_dim, :],
+                            weights[gate_up_name][output_dim:, :],
+                        ],
+                        dim=0,
+                    )
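
Note: the deleted comment assumed lora_B for a fused gate_up_proj was already stacked, but adapter checkpoints store the gate and up halves concatenated along dim 0, which is what this fix corrects. A standalone sketch of the new handling (shapes invented for illustration):

import torch

r, output_dim = 8, 16
lora_b = torch.randn(2 * output_dim, r)  # checkpoint layout: [gate; up] concatenated

# Split into halves and stack into (2, output_dim, r), one slice per sub-projection.
stacked = torch.stack([lora_b[:output_dim, :], lora_b[output_dim:, :]], dim=0)
assert stacked.shape == (2, output_dim, r)
assert torch.equal(stacked[0], lora_b[:output_dim])  # gate half
assert torch.equal(stacked[1], lora_b[output_dim:])  # up half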
@@ -296,23 +296,30 @@ class Idefics2VisionTransformer(nn.Module):
     def compute_cu_seqlens(
         self,
         tgt_sizes: Optional[torch.Tensor] = None,
-        atch_attention_mask: Optional[torch.BoolTensor] = None,
+        input_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # shape: (batch_size,)
         if tgt_sizes is not None:
-            patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+            seqlen = tgt_sizes[:, 0] * tgt_sizes[:, 1]
+        elif input_embeds is not None:
+            seqlen = torch.full(
+                size=(input_embeds.shape[0],),
+                fill_value=input_embeds.shape[1],
+                dtype=torch.int32,
+                device=input_embeds.device,
+            )
         else:
-            patch_len = atch_attention_mask[:, :, 0].sum(dim=1) * atch_attention_mask[
-                :, 0, :
-            ].sum(dim=1)
+            raise ValueError(
+                "Either `tgt_sizes` or `input_embeds` must be provided to compute cu_seqlens."
+            )

         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=patch_len.device, dtype=torch.int32),
-                torch.cumsum(patch_len, dim=0, dtype=torch.int32),
+                torch.tensor([0], device=seqlen.device, dtype=torch.int32),
+                torch.cumsum(seqlen, dim=0, dtype=torch.int32),
             ],
             dim=0,
-        ).to(patch_len.device)
+        ).to(seqlen.device)
         return cu_seqlens

     def forward(
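
Note: cu_seqlens is the cumulative boundary array that varlen attention kernels use to locate each sequence in a packed batch; the method now derives per-image lengths from tgt_sizes, falls back to fixed-length sequences implied by input_embeds, and raises otherwise. A standalone sketch of the array it builds (values invented):

import torch

tgt_sizes = torch.tensor([[24, 32], [16, 16]])  # per-image (height, width) patch grid
seqlen = tgt_sizes[:, 0] * tgt_sizes[:, 1]      # tensor([768, 256])

cu_seqlens = torch.cat(
    [
        torch.tensor([0], dtype=torch.int32),
        torch.cumsum(seqlen, dim=0, dtype=torch.int32),
    ],
    dim=0,
)
print(cu_seqlens)  # tensor([   0,  768, 1024], dtype=torch.int32)
# Sequence i occupies positions cu_seqlens[i]:cu_seqlens[i+1] in the packed stream.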
@@ -326,7 +333,7 @@ class Idefics2VisionTransformer(nn.Module):
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )
-        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask)
+        cu_seqlens = self.compute_cu_seqlens(tgt_sizes, hidden_states)
         encoder_outputs = self.encoder(
             hidden_states,
             cu_seqlens=cu_seqlens,
@@ -451,8 +451,8 @@ class Phi4MMForCausalLM(nn.Module):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

-    def should_apply_lora(self, module_name: str) -> Optional[str]:
-        return self.lora_pattern.match(module_name)
+    def should_apply_lora(self, module_name: str) -> bool:
+        return bool(self.lora_pattern.match(module_name))

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
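
Note: re.match returns a Match object or None, so the old Optional[str] annotation was wrong and callers received a truthy Match rather than a boolean; wrapping it in bool() matches the corrected -> bool signature. A minimal sketch (the pattern is invented for illustration, not Phi4MM's actual lora_pattern):

import re

lora_pattern = re.compile(r".*(qkv_proj|o_proj|gate_up_proj|down_proj).*")

def should_apply_lora(module_name: str) -> bool:
    # bool(None) -> False; bool(<re.Match>) -> True
    return bool(lora_pattern.match(module_name))

print(should_apply_lora("model.layers.0.self_attn.qkv_proj"))  # True
print(should_apply_lora("model.embed_tokens"))                 # False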