model: Support Janus-pro (#3203)
This commit is contained in:
@@ -408,7 +408,7 @@ def _get_and_verify_dtype(
|
||||
|
||||
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
|
||||
# We have two ways to determine whether a model is a generative model.
|
||||
# 1. Check the model architectue
|
||||
# 1. Check the model architecture
|
||||
# 2. check the `is_embedding` server args
|
||||
|
||||
if (
|
||||
@@ -424,18 +424,25 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
||||
return not is_embedding
|
||||
|
||||
|
||||
multimodal_model_archs = [
|
||||
"LlavaLlamaForCausalLM",
|
||||
"LlavaQwenForCausalLM",
|
||||
"LlavaMistralForCausalLM",
|
||||
"LlavaVidForCausalLM",
|
||||
"Grok1VForCausalLM",
|
||||
"Grok1AForCausalLM",
|
||||
"MllamaForConditionalGeneration",
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"MiniCPMV",
|
||||
"MultiModalityCausalLM",
|
||||
]
|
||||
|
||||
|
||||
def is_multimodal_model(model_architectures: List[str]):
|
||||
if (
|
||||
"LlavaLlamaForCausalLM" in model_architectures
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "Grok1VForCausalLM" in model_architectures
|
||||
or "Grok1AForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
or "Qwen2_5_VLForConditionalGeneration" in model_architectures
|
||||
or "MiniCPMV" in model_architectures
|
||||
if any(
|
||||
multi_model_arch in model_architectures
|
||||
for multi_model_arch in multimodal_model_archs
|
||||
):
|
||||
return True
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user