Support Llama4 with non-uniform intermediate sizes across layers for… (#10047)
This commit is contained in:
@@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int:
|
||||
|
||||
|
||||
def get_hidden_dim(
|
||||
module_name: str, config: AutoConfig, base_model: torch.nn.Module
|
||||
module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
|
||||
) -> Tuple[int]:
|
||||
"""
|
||||
Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
|
||||
"""
|
||||
|
||||
if hasattr(base_model, "get_hidden_dim"):
|
||||
return base_model.get_hidden_dim(module_name)
|
||||
return base_model.get_hidden_dim(module_name, layer_idx)
|
||||
else:
|
||||
"""
|
||||
WARNING: get_hidden_dim() is not defined,
|
||||
|
||||
Reference in New Issue
Block a user