refactor: minor refactors regarding multimodal processing (#6187)

This commit is contained in:
Mick
2025-05-18 13:53:20 +08:00
committed by GitHub
parent b3f3d610fd
commit 01dd39bac1
15 changed files with 140 additions and 98 deletions

View File

@@ -19,6 +19,7 @@ import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union
import torch
from huggingface_hub import snapshot_download
from transformers import (
AutoConfig,
@@ -65,6 +66,43 @@ def download_from_hf(model_path: str):
return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
def get_hf_text_config(config: "PretrainedConfig"):
    """Get the "sub" config relevant to llm for multi modal models.

    No op for pure text models: the top-level config is returned unchanged.

    Args:
        config: A HuggingFace ``PretrainedConfig`` (or any object with the
            same attribute conventions: ``architectures``, ``text_config``,
            ``language_config``, ``thinker_config``).

    Returns:
        The text/LLM sub-config when one exists, otherwise ``config`` itself.
    """
    # Truthiness check guards against both None and an empty list;
    # `architectures[0]` would raise IndexError on `[]`.
    if config.architectures:
        class_name = config.architectures[0]
        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
            # We support non-hf version of llava models, so we do not want to
            # read the wrong values from the unused default text_config.
            # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16`
            # for the weights, as `torch.float16` is default used for image
            # features in `python/sglang/srt/models/llava.py`.
            setattr(config, "torch_dtype", torch.float16)
            return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    if hasattr(config, "language_config"):
        return config.language_config
    if hasattr(config, "thinker_config"):
        # qwen2.5 omni: the LLM sub-config lives under thinker_config(.text_config).
        thinker_config = config.thinker_config
        if hasattr(thinker_config, "text_config"):
            # Propagate the thinker-level torch_dtype down so code reading
            # text_config.torch_dtype sees a consistent value (None if unset).
            setattr(
                thinker_config.text_config,
                "torch_dtype",
                getattr(thinker_config, "torch_dtype", None),
            )
            return thinker_config.text_config
        return thinker_config
    # Pure text model: no nested sub-config; the top-level config is the text config.
    return config
def get_config(
model: str,
trust_remote_code: bool,
@@ -80,13 +118,12 @@ def get_config(
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
)
text_config = get_hf_text_config(config=config)
# FIXME: Pour contents of janus-pro's language_config to first-level
if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
assert hasattr(config, "language_config")
for key, val in config.language_config.__dict__.items():
setattr(config, key, val)
setattr(config, "architectures", ["MultiModalityCausalLM"])
if isinstance(model, str) and text_config is not None:
for key, val in text_config.__dict__.items():
if not hasattr(config, key) and getattr(text_config, key, None) is not None:
setattr(config, key, val)
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
@@ -99,6 +136,9 @@ def get_config(
if not hasattr(config, key):
setattr(config, key, val)
if config.model_type == "multi_modality":
config.update({"architectures": ["MultiModalityCausalLM"]})
if model_override_args:
config.update(model_override_args)