[Fix] Reduce memory usage for loading llava model & Remove EntryClassRemapping (#1308)

commit f64eae3a29 (parent a5a134f39f)
Author: Lianmin Zheng
Date: 2024-09-02 21:44:45 -07:00 (committed via GitHub)
17 changed files with 105 additions and 158 deletions

Below is the diff for the ExaoneForCausalLM model file, one of the 17 changed files:

@@ -297,7 +297,6 @@ class ExaoneForCausalLM(nn.Module):
         config,
         quant_config: Optional[QuantizationConfig] = None,
         cache_config: Optional[CacheConfig] = None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -345,9 +344,7 @@ class ExaoneForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)
 
-    def load_weights(
-        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -358,7 +355,7 @@ class ExaoneForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
 
-        def load_weights_per_param(name, loaded_weight):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
-                return
+                continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -368,6 +365,7 @@ class ExaoneForCausalLM(nn.Module):
             if name.startswith("model.vision_tower") and name not in params_dict:
-                return
+                continue
+            name = name.replace("attn.attention", "self_attn")
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
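
For context, the stacked-parameter mapping above rewrites checkpoint names such as "q_proj" onto the fused "qkv_proj" parameter and forwards a shard id to that parameter's weight_loader. The loop body follows the common fused-weight pattern, roughly (an illustrative sketch, not code from this diff):

    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        fused_name = name.replace(weight_name, param_name)
        param = params_dict[fused_name]
        # The fused parameter's weight_loader places this shard
        # ("q", "k", or "v") into the correct slice of qkv_proj.
        param.weight_loader(param, loaded_weight, shard_id)
        break
    else:
        # No stacked mapping matched: load the tensor as-is.
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
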
@@ -387,13 +385,5 @@ class ExaoneForCausalLM(nn.Module):
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
-        if name is None or loaded_weight is None:
-            for name, loaded_weight in weights:
-                name = name.replace("attn.attention", "self_attn")
-                load_weights_per_param(name, loaded_weight)
-        else:
-            name = name.replace("attn.attention", "self_attn")
-            load_weights_per_param(name, loaded_weight)
-
 
 EntryClass = ExaoneForCausalLM
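
With EntryClassRemapping removed, each model file advertises its model class only through the module-level EntryClass attribute, and the loader can resolve an architecture name by a direct scan. A sketch of that lookup, using a hypothetical resolve_model helper rather than sglang's actual registry code:

    import importlib
    import pkgutil

    def resolve_model(architecture: str, package: str = "sglang.srt.models"):
        # Hypothetical helper: walk the models package and compare each
        # module's EntryClass name against the requested architecture.
        pkg = importlib.import_module(package)
        for info in pkgutil.iter_modules(pkg.__path__):
            module = importlib.import_module(f"{package}.{info.name}")
            entry = getattr(module, "EntryClass", None)
            if entry is not None and entry.__name__ == architecture:
                return entry
        raise ValueError(f"Unsupported architecture: {architecture}")

    # resolve_model("ExaoneForCausalLM") returns the class above.
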