Update step3v default config (#8626)
This commit is contained in:
@@ -112,6 +112,7 @@ class ModelConfig:
|
||||
mm_disabled_models = [
|
||||
"Gemma3ForConditionalGeneration",
|
||||
"Llama4ForConditionalGeneration",
|
||||
"Step3VLForConditionalGeneration",
|
||||
]
|
||||
if self.hf_config.architectures[0] in mm_disabled_models:
|
||||
enable_multimodal = False
|
||||
|
||||
@@ -868,7 +868,6 @@ class Step3VLForConditionalGeneration(nn.Module):
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# TODO:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", 0),
|
||||
@@ -901,9 +900,7 @@ class Step3VLForConditionalGeneration(nn.Module):
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "vision_model" in name:
|
||||
# 1.It’s not great, but let’s leave it like this for now
|
||||
name = name.replace("self_attn", "self_attn.attn")
|
||||
# 2.
|
||||
name = name.replace("out_proj", "proj")
|
||||
|
||||
# TODO: support vision model
|
||||
|
||||
@@ -2344,6 +2344,7 @@ def is_fa3_default_architecture(hf_config):
|
||||
"Qwen3ForCausalLM",
|
||||
"Qwen3MoeForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
"Step3VLForConditionalGeneration",
|
||||
}
|
||||
return architectures[0] in default_archs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user