Refactor LoRA handling to support adapter tensors in fused format (#6585)

Author: Lifu Huang
Date: 2025-05-26 21:51:54 -07:00
Committed by: GitHub
parent 1a8f5f6836
commit 477a101cbd
6 changed files with 86 additions and 31 deletions
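For context: "fused format" here means adapter checkpoints whose tensors are keyed by the stacked modules the model actually executes (qkv_proj, gate_up_proj) rather than by the unfused sub-projections (q_proj/k_proj/v_proj, gate_proj/up_proj) listed in stacked_params_mapping. A minimal, hypothetical sketch of the two layouts follows; the key names mirror the diff below, but the shapes and the is_fused helper are illustrative only, not the commit's loader code:

# Hypothetical adapter state dicts illustrating the two layouts.
from typing import Dict

import torch

# Unfused layout: one lora_B tensor per sub-projection.
unfused: Dict[str, torch.Tensor] = {
    "language_model.model.layers.0.mlp.gate_proj.lora_B.weight": torch.zeros(16, 8),
    "language_model.model.layers.0.mlp.up_proj.lora_B.weight": torch.zeros(16, 8),
}

# Fused layout: the sub-projections are already concatenated along the
# output dimension into the stacked module the model runs.
fused: Dict[str, torch.Tensor] = {
    "language_model.model.layers.0.mlp.gate_up_proj.lora_B.weight": torch.zeros(32, 8),
}

def is_fused(adapter: Dict[str, torch.Tensor]) -> bool:
    # Fused checkpoints name the stacked module directly.
    return any(".gate_up_proj." in k or ".qkv_proj." in k for k in adapter)

assert not is_fused(unfused)
assert is_fused(fused)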


@@ -17,6 +17,7 @@
 import logging
 import math
 import re
+from collections.abc import Iterable
 from typing import List, Optional, Tuple
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
         "gate_up_proj": ["gate_proj", "up_proj"],
     }
+    lora_pattern = re.compile(
+        r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+    )
+
     def __init__(
         self,
         config: PretrainedConfig,
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
         return pattern.pad_input_tokens(input_ids, mm_inputs)

+    def should_apply_lora(self, module_name: str) -> Optional[str]:
+        return self.lora_pattern.match(module_name)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
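To see what the new gate admits, here is a standalone sketch of the pattern's behavior; the regex is copied from the diff, while the module names in the asserts are illustrative. Note that should_apply_lora returns the truthy re.Match object itself rather than a string, the Optional[str] annotation notwithstanding:

import re

# The pattern added to Phi4MMForCausalLM above.
lora_pattern = re.compile(
    r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)"
    r"\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
)

# Fused projections inside the language model take LoRA ...
assert lora_pattern.match("language_model.model.layers.3.self_attn.qkv_proj")
assert lora_pattern.match("language_model.model.layers.12.mlp.gate_up_proj")
# ... while unfused names and non-language-model modules do not.
assert lora_pattern.match("language_model.model.layers.3.self_attn.q_proj") is None
assert lora_pattern.match("model.embed_tokens") is None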