support Llama4 with non-uniform intermediate size across layers for… (#10047)
@@ -104,12 +104,18 @@ class LoRAMemoryPool:
         return all(_can_support(x) for x in config)

     def get_lora_A_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        input_dim, _ = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             input_dim = divide(input_dim, self.tp_size)
@@ -120,12 +126,18 @@ class LoRAMemoryPool:
         )

     def get_lora_B_shape(
-        self, module_name: str, base_model: torch.nn.Module, max_lora_dim: int
+        self,
+        module_name: str,
+        base_model: torch.nn.Module,
+        max_lora_dim: int,
+        layer_idx: int,
     ) -> Tuple[int]:
         """
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
-        _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
+        _, output_dim = get_hidden_dim(
+            module_name, self.base_hf_config, base_model, layer_idx
+        )
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
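Note on the two hunks above: layer_idx is threaded through both shape helpers because, with a Llama4-style config, the MLP intermediate size is no longer uniform across layers, so the LoRA A (rank x input_dim) and B (output_dim x rank) buffers cannot share a single shape for every layer. Below is a minimal, self-contained sketch of that dependency; the toy_* helpers and the sizes are hypothetical, chosen only to show the shapes diverging per layer, and are not the pool's real code.

from typing import Tuple

def toy_hidden_dims(module_name: str, layer_idx: int) -> Tuple[int, int]:
    # Hypothetical per-layer config: even layers use a wider MLP than odd layers.
    hidden_size = 4096
    intermediate_size = 16384 if layer_idx % 2 == 0 else 8192
    if module_name in ("gate_proj", "up_proj"):
        return hidden_size, intermediate_size
    if module_name == "down_proj":
        return intermediate_size, hidden_size
    return hidden_size, hidden_size

def toy_lora_a_shape(module_name: str, max_lora_dim: int, layer_idx: int) -> Tuple[int, int]:
    # LoRA A maps input_dim -> rank, so its shape follows the module's input dim.
    input_dim, _ = toy_hidden_dims(module_name, layer_idx)
    return (max_lora_dim, input_dim)

def toy_lora_b_shape(module_name: str, max_lora_dim: int, layer_idx: int) -> Tuple[int, int]:
    # LoRA B maps rank -> output_dim, so its shape follows the module's output dim.
    _, output_dim = toy_hidden_dims(module_name, layer_idx)
    return (output_dim, max_lora_dim)

# down_proj LoRA A shapes now differ between adjacent layers:
assert toy_lora_a_shape("down_proj", 16, 0) == (16, 16384)
assert toy_lora_a_shape("down_proj", 16, 1) == (16, 8192)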
@@ -140,19 +152,21 @@ class LoRAMemoryPool:
         def init_buffer(
             buffer: Dict[str, List[torch.Tensor]],
             target_modules: Set[str],
-            get_lora_shape_fn: Callable[[str, torch.nn.Module, int], Tuple[int]],
+            get_lora_shape_fn: Callable[[str, torch.nn.Module, int, int], Tuple[int]],
         ):
             for module_name in target_modules:
-                lora_shape = get_lora_shape_fn(
-                    module_name, base_model, self.max_lora_rank
-                )
                 buffer[module_name] = [
                     torch.empty(
-                        lora_shape,
+                        get_lora_shape_fn(
+                            module_name,
+                            base_model,
+                            self.max_lora_rank,
+                            idx,
+                        ),
                         dtype=self.dtype,
                         device=device,
                     )
-                    for _ in range(self.num_layer)
+                    for idx in range(self.num_layer)
                 ]

         init_buffer(
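The init_buffer change above has the same motivation: each per-layer tensor in a module's buffer is now allocated with a shape looked up per layer index inside the comprehension, rather than one shape computed once and reused for all layers. A self-contained sketch under the same toy assumption (toy_a_shape is a made-up stand-in, not the pool's real shape function):

import torch

def toy_a_shape(module_name: str, max_lora_dim: int, layer_idx: int):
    # Hypothetical non-uniform config: even layers use a wider down_proj input dim.
    input_dim = 16384 if layer_idx % 2 == 0 else 8192
    return (max_lora_dim, input_dim)

num_layers, max_lora_dim = 4, 16
buffer = {
    "down_proj": [
        torch.empty(
            toy_a_shape("down_proj", max_lora_dim, idx),
            dtype=torch.bfloat16,
            device="cpu",
        )
        for idx in range(num_layers)  # idx is passed through, mirroring the diff
    ]
}
print([tuple(t.shape) for t in buffer["down_proj"]])
# [(16, 16384), (16, 8192), (16, 16384), (16, 8192)]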
@@ -48,14 +48,14 @@ def get_layer_id(name: str) -> int:


 def get_hidden_dim(
-    module_name: str, config: AutoConfig, base_model: torch.nn.Module
+    module_name: str, config: AutoConfig, base_model: torch.nn.Module, layer_idx: int
 ) -> Tuple[int]:
     """
     Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """

     if hasattr(base_model, "get_hidden_dim"):
-        return base_model.get_hidden_dim(module_name)
+        return base_model.get_hidden_dim(module_name, layer_idx)
     else:
         """
         WARNING: get_hidden_dim() is not defined,
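The fallback branch is truncated above, but the hasattr path implies that a model defining get_hidden_dim must now accept a layer index. A hedged sketch of what such a model-side method could look like for a Llama4-style config; the interleaving rule and the attribute names below are assumptions made for illustration, not the actual Llama4 modeling code:

from typing import Tuple

class ToyLlama4Model:
    """Toy model showing a per-layer get_hidden_dim; not the real implementation."""

    def __init__(self, hidden_size: int, moe_intermediate: int,
                 dense_intermediate: int, interleave_step: int):
        self.hidden_size = hidden_size
        self.moe_intermediate = moe_intermediate
        self.dense_intermediate = dense_intermediate
        self.interleave_step = interleave_step

    def _intermediate_size(self, layer_idx: int) -> int:
        # Assumed interleaving rule: every `interleave_step`-th layer uses the MoE
        # intermediate size, the rest use the dense MLP size, so the value varies per layer.
        is_moe = (layer_idx + 1) % self.interleave_step == 0
        return self.moe_intermediate if is_moe else self.dense_intermediate

    def get_hidden_dim(self, module_name: str, layer_idx: int) -> Tuple[int, int]:
        if module_name in ("gate_proj", "up_proj"):
            return self.hidden_size, self._intermediate_size(layer_idx)
        if module_name == "down_proj":
            return self._intermediate_size(layer_idx), self.hidden_size
        # Attention projections keep uniform dims across layers in this toy example.
        return self.hidden_size, self.hidden_size

model = ToyLlama4Model(hidden_size=5120, moe_intermediate=8192,
                       dense_intermediate=16384, interleave_step=2)
print(model.get_hidden_dim("down_proj", 0))  # (16384, 5120)
print(model.get_hidden_dim("down_proj", 1))  # (8192, 5120)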