Update step3v default config (#8626)

This commit is contained in:
Ke Bao
2025-08-01 00:49:26 +08:00
committed by GitHub
parent 3c307dc057
commit 8fbcfd0723
3 changed files with 2 additions and 3 deletions

View File

@@ -112,6 +112,7 @@ class ModelConfig:
mm_disabled_models = [
"Gemma3ForConditionalGeneration",
"Llama4ForConditionalGeneration",
"Step3VLForConditionalGeneration",
]
if self.hf_config.architectures[0] in mm_disabled_models:
enable_multimodal = False

View File

@@ -868,7 +868,6 @@ class Step3VLForConditionalGeneration(nn.Module):
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# TODO:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", 0),
@@ -901,9 +900,7 @@ class Step3VLForConditionalGeneration(nn.Module):
for name, loaded_weight in weights:
if "vision_model" in name:
# 1.Its not great, but lets leave it like this for now
name = name.replace("self_attn", "self_attn.attn")
# 2.
name = name.replace("out_proj", "proj")
# TODO: support vision model

View File

@@ -2344,6 +2344,7 @@ def is_fa3_default_architecture(hf_config):
"Qwen3ForCausalLM",
"Qwen3MoeForCausalLM",
"Glm4MoeForCausalLM",
"Step3VLForConditionalGeneration",
}
return architectures[0] in default_archs