Support Llama4 with non-uniform intermediate size across layers for… (#10047)

This commit is contained in:
gongwei-130
2025-09-05 17:28:15 -07:00
committed by GitHub
parent 273b28344b
commit ab62b135c1
7 changed files with 123 additions and 13 deletions

View File

@@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int:
def get_hidden_dim(
module_name: str, config: AutoConfig, base_model: torch.nn.Module
module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
) -> Tuple[int]:
"""
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
"""
if hasattr(base_model, "get_hidden_dim"):
return base_model.get_hidden_dim(module_name)
return base_model.get_hidden_dim(module_name, layer_idx)
else:
"""
WARNING: get_hidden_dim() is not defined,