refactor: minor refactors regarding multimodal processing (#6187)
@@ -19,6 +19,7 @@ import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union

import torch
from huggingface_hub import snapshot_download
from transformers import (
    AutoConfig,
@@ -65,6 +66,43 @@ def download_from_hf(model_path: str):
    return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])

def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.
    No op for pure text models.
    """
    if config.architectures is not None:
        class_name = config.architectures[0]
        if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
            # We support non-hf version of llava models, so we do not want to
            # read the wrong values from the unused default text_config.
            # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
            # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
            setattr(config, "torch_dtype", torch.float16)
            return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    if hasattr(config, "language_config"):
        return config.language_config
    if hasattr(config, "thinker_config"):
        # qwen2.5 omni
        thinker_config = config.thinker_config
        if hasattr(thinker_config, "text_config"):
            setattr(
                thinker_config.text_config,
                "torch_dtype",
                getattr(thinker_config, "torch_dtype", None),
            )
            return thinker_config.text_config
        return thinker_config
    else:
        return config


def get_config(
    model: str,
    trust_remote_code: bool,
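Editor's note: a minimal usage sketch of the new helper (not part of the commit), using get_hf_text_config as defined above. It assumes transformers' LlavaConfig defaults leave `architectures` unset and nest a LlamaConfig as `text_config`; the plain PretrainedConfig case shows the no-op path for pure text models.

from transformers import LlavaConfig, PretrainedConfig

# Multimodal config: the helper returns the nested text sub-config.
llava_cfg = LlavaConfig()
assert get_hf_text_config(llava_cfg) is llava_cfg.text_config
assert hasattr(llava_cfg.text_config, "num_attention_heads")

# Pure text config: no text_config / language_config / thinker_config, so it is returned unchanged.
plain_cfg = PretrainedConfig()
assert get_hf_text_config(plain_cfg) is plain_cfg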
@@ -80,13 +118,12 @@ def get_config(
    config = AutoConfig.from_pretrained(
        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
    )
    text_config = get_hf_text_config(config=config)

    # FIXME: Pour contents of janus-pro's language_config to first-level
    if isinstance(model, str) and model.lower().startswith("deepseek-ai/janus-pro"):
        assert hasattr(config, "language_config")
        for key, val in config.language_config.__dict__.items():
            setattr(config, key, val)
        setattr(config, "architectures", ["MultiModalityCausalLM"])
    if isinstance(model, str) and text_config is not None:
        for key, val in text_config.__dict__.items():
            if not hasattr(config, key) and getattr(text_config, key, None) is not None:
                setattr(config, key, val)

    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
@@ -99,6 +136,9 @@ def get_config(
            if not hasattr(config, key):
                setattr(config, key, val)

    if config.model_type == "multi_modality":
        config.update({"architectures": ["MultiModalityCausalLM"]})

    if model_override_args:
        config.update(model_override_args)
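Editor's note: a hedged end-to-end sketch of calling get_config from this module (not part of the commit). It assumes the remaining parameters (revision, model_override_args) are optional keyword arguments, as the body above suggests, and uses an illustrative Hugging Face model id that requires network access to fetch.

# Hypothetical call: overrides are applied last via config.update, so they win over loaded values.
config = get_config(
    "llava-hf/llava-1.5-7b-hf",                     # illustrative multimodal model id
    trust_remote_code=False,
    model_override_args={"max_position_embeddings": 4096},
)
assert config.max_position_embeddings == 4096

# Text-config fields such as num_attention_heads should now also resolve on the top-level config,
# thanks to the generic copy loop added in this commit.
print(getattr(config, "num_attention_heads", None))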