support Llama4 with non-uniform intermediate size across layers for… (#10047)

This commit is contained in:
gongwei-130
2025-09-05 17:28:15 -07:00
committed by GitHub
parent 273b28344b
commit ab62b135c1
7 changed files with 123 additions and 13 deletions

View File

@@ -104,12 +104,18 @@ class LoRAMemoryPool:
return all(_can_support(x) for x in config)
def get_lora_A_shape(
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
self,
module_name: str,
base_model: torch.nn.Module,
max_lora_dim: int,
layer_idx: int,
) -> Tuple[int]:
"""
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
"""
input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
input_dim, _ = get_hidden_dim(
module_name, self.base_hf_config, base_model, layer_idx
)
c = get_stacked_multiply(module_name)
if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
)
def get_lora_B_shape(
self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
self,
module_name: str,
base_model: torch.nn.Module,
max_lora_dim: int,
layer_idx: int,
) -> Tuple[int]:
"""
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
"""
_, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
_, output_dim = get_hidden_dim(
module_name, self.base_hf_config, base_model, layer_idx
)
if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
output_dim = divide(output_dim, self.tp_size)
return (
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
def init_buffer(
buffer: Dict[str, List[torch.Tensor]],
target_modules: Set[str],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
):
for module_name in target_modules:
lora_shape = get_lora_shape_fn(
module_name, base_model, self.max_lora_rank
)
buffer[module_name] = [
torch.empty(
lora_shape,
get_lora_shape_fn(
module_name,
base_model,
self.max_lora_rank,
idx,
),
dtype=self.dtype,
device=device,
)
for _ in range(self.num_layer)
for idx in range(self.num_layer)
]
init_buffer(

View File

@@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int:
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
) -> Tuple[int]:
"""
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
"""
if hasattr(base_model, "get_hidden_dim"):
return base_model.get_hidden_dim(module_name)
return base_model.get_hidden_dim(module_name, layer_idx)
else:
"""
WARNING: get_hidden_dim() is not defined,

View File

@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
def should_apply_lora(self, module_name: str) -> bool:
return bool(self.lora_pattern.match(module_name))
def get_hidden_dim(self, module_name):
def get_hidden_dim(self, module_name, layer_idx):
# return input_dim, output_dim
if module_name == "qkv_proj":
return (

View File

@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
return self.config.num_local_experts > 0
return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
def get_intermediate_size(self) -> int:
    """Return the feed-forward intermediate width used by this decoder layer.

    A layer whose ``feed_forward`` is a ``Llama4MoE`` reports
    ``config.intermediate_size``; every other layer reports
    ``config.intermediate_size_mlp``.
    """
    uses_moe = isinstance(self.feed_forward, Llama4MoE)
    # The conditional expression keeps attribute access lazy: only the
    # config field for the selected branch is read.
    return (
        self.config.intermediate_size
        if uses_moe
        else self.config.intermediate_size_mlp
    )
def forward(
self,
positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
def get_input_embeddings(self):
return self.model.embed_tokens
def get_layers(self):
    """Return the ``layers`` attribute of the wrapped ``self.model``."""
    decoder_layers = self.model.layers
    return decoder_layers
def _init_model(
self,
config: Llama4TextConfig,

View File

@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
def set_embed(self, embed):
return self.language_model.set_embed(embed)
def get_hidden_dim(self, module_name, layer_idx):
    """Return ``(input_dim, output_dim)`` for a LoRA target module of one layer.

    Args:
        module_name: Target module name; one of ``"qkv_proj"``, ``"o_proj"``,
            ``"gate_up_proj"`` or ``"down_proj"``.
        layer_idx: Index into ``self.language_model.get_layers()``. Only the
            MLP projections consult it, because their intermediate width is
            resolved per decoder layer.

    Returns:
        A 2-tuple ``(input_dim, output_dim)``.

    Raises:
        NotImplementedError: If ``module_name`` is not one of the names above.
    """
    config = self.config

    if module_name == "qkv_proj":
        # Fused projection: query heads plus one key and one value head group.
        fused_heads = config.num_attention_heads + config.num_key_value_heads * 2
        return config.hidden_size, config.head_dim * fused_heads

    if module_name == "o_proj":
        return config.head_dim * config.num_attention_heads, config.hidden_size

    if module_name in ("gate_up_proj", "down_proj"):
        # gate/up output and down input are the same intermediate dimension,
        # so both must use the per-layer width. Reading the global
        # ``config.intermediate_size`` for gate_up_proj would disagree with
        # down_proj on layers that report a different intermediate size.
        decoder_layer = self.language_model.get_layers()[layer_idx]
        intermediate_size = decoder_layer.get_intermediate_size()
        if module_name == "gate_up_proj":
            # Stacked gate and up projections double the output width.
            return config.hidden_size, intermediate_size * 2
        return intermediate_size, config.hidden_size

    raise NotImplementedError(
        f"get_hidden_dim does not support module {module_name!r}"
    )
EntryClass = Llama4ForConditionalGeneration