Dynamic model class loading (#101)
This commit is contained in:
@@ -318,3 +318,5 @@ class LlamaForCausalLM(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = LlamaForCausalLM
|
||||
|
||||
@@ -330,3 +330,5 @@ def monkey_path_clip_vision_embed_forward():
|
||||
"forward",
|
||||
clip_vision_embed_forward,
|
||||
)
|
||||
|
||||
EntryClass = LlavaLlamaForCausalLM
|
||||
|
||||
@@ -376,3 +376,5 @@ class MixtralForCausalLM(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = MixtralForCausalLM
|
||||
|
||||
@@ -258,3 +258,5 @@ class QWenLMHeadModel(nn.Module):
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
EntryClass = QWenLMHeadModel
|
||||
|
||||
Reference in New Issue
Block a user