From 094fbdacd5bb27e9fdb9336cd81f36ad4fcc11d9 Mon Sep 17 00:00:00 2001 From: Lifu Huang Date: Sat, 31 May 2025 13:41:44 -0700 Subject: [PATCH] Fix incorrect LoRA weight loading for fused gate_up_proj (#6734) --- python/sglang/srt/conversation.py | 4 ++-- python/sglang/srt/lora/lora.py | 10 +++++++++- python/sglang/srt/models/idefics2.py | 25 ++++++++++++++++--------- python/sglang/srt/models/phi4mm.py | 4 ++-- 4 files changed, 29 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index cf2b0e650..914d1e721 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -680,8 +680,8 @@ register_conv_template( register_conv_template( Conversation( name="phi-4-mm", - system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", - system_template="<|system|>{system_message}<|end|>", + system_message="", + system_template="{system_message}", roles=("<|user|>", "<|assistant|>"), sep_style=SeparatorStyle.NO_COLON_SINGLE, sep="<|end|>", diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 38cc08f71..a6cbc7a28 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -209,4 +209,12 @@ class LoRAAdapter(nn.Module): gate_up_name = weight_name if "lora_A" in weight_name: weights[gate_up_name] = weights[gate_up_name].repeat(2, 1) - # else: "lora_B" is already stacked, no operations is needed. + else: + output_dim = weights[gate_up_name].shape[0] // 2 + weights[gate_up_name] = torch.stack( + [ + weights[gate_up_name][:output_dim, :], + weights[gate_up_name][output_dim:, :], + ], + dim=0, + ) diff --git a/python/sglang/srt/models/idefics2.py b/python/sglang/srt/models/idefics2.py index 3b8059dfa..75922d05c 100644 --- a/python/sglang/srt/models/idefics2.py +++ b/python/sglang/srt/models/idefics2.py @@ -296,23 +296,30 @@ class Idefics2VisionTransformer(nn.Module): def compute_cu_seqlens( self, tgt_sizes: Optional[torch.Tensor] = None, - atch_attention_mask: Optional[torch.BoolTensor] = None, + input_embeds: Optional[torch.Tensor] = None, ) -> torch.Tensor: # shape: (batch_size,) if tgt_sizes is not None: - patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + seqlen = tgt_sizes[:, 0] * tgt_sizes[:, 1] + elif input_embeds is not None: + seqlen = torch.full( + size=(input_embeds.shape[0],), + fill_value=input_embeds.shape[1], + dtype=torch.int32, + device=input_embeds.device, + ) else: - patch_len = atch_attention_mask[:, :, 0].sum(dim=1) * atch_attention_mask[ - :, 0, : - ].sum(dim=1) + raise ValueError( + "Either `tgt_sizes` or `input_embeds` must be provided to compute cu_seqlens." + ) cu_seqlens = torch.cat( [ - torch.tensor([0], device=patch_len.device, dtype=torch.int32), - torch.cumsum(patch_len, dim=0, dtype=torch.int32), + torch.tensor([0], device=seqlen.device, dtype=torch.int32), + torch.cumsum(seqlen, dim=0, dtype=torch.int32), ], dim=0, - ).to(patch_len.device) + ).to(seqlen.device) return cu_seqlens def forward( @@ -326,7 +333,7 @@ class Idefics2VisionTransformer(nn.Module): patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, ) - cu_seqlens = self.compute_cu_seqlens(tgt_sizes, patch_attention_mask) + cu_seqlens = self.compute_cu_seqlens(tgt_sizes, hidden_states) encoder_outputs = self.encoder( hidden_states, cu_seqlens=cu_seqlens, diff --git a/python/sglang/srt/models/phi4mm.py b/python/sglang/srt/models/phi4mm.py index a574dc27c..2626e9641 100644 --- a/python/sglang/srt/models/phi4mm.py +++ b/python/sglang/srt/models/phi4mm.py @@ -451,8 +451,8 @@ class Phi4MMForCausalLM(nn.Module): pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id]) return pattern.pad_input_tokens(input_ids, mm_inputs) - def should_apply_lora(self, module_name: str) -> Optional[str]: - return self.lora_pattern.match(module_name) + def should_apply_lora(self, module_name: str) -> bool: + return bool(self.lora_pattern.match(module_name)) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [