Refactor LoRA handling to support adapter tensors in fused format (#6585)
This commit is contained in:
@@ -17,6 +17,7 @@
|
||||
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Iterable
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -392,6 +393,10 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
"gate_up_proj": ["gate_proj", "up_proj"],
|
||||
}
|
||||
|
||||
lora_pattern = re.compile(
|
||||
r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
@@ -446,6 +451,9 @@ class Phi4MMForCausalLM(nn.Module):
|
||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens([im_token_id])
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def should_apply_lora(self, module_name: str) -> Optional[str]:
|
||||
return self.lora_pattern.match(module_name)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
|
||||
Reference in New Issue
Block a user