Dynamic model class loading (#101)

This commit is contained in:
Cody Yu
2024-01-25 15:29:07 -08:00
committed by GitHub
parent 0147f940dd
commit 3a581e9949
6 changed files with 40 additions and 28 deletions

View File

@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = LlamaForCausalLM

View File

@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
"forward",
clip_vision_embed_forward,
)
EntryClass = LlavaLlamaForCausalLM

View File

@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = MixtralForCausalLM

View File

@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = QWenLMHeadModel