Unify the model type checking (#1905)
This commit is contained in:
@@ -204,56 +204,6 @@ def is_port_available(port):
|
||||
return False
|
||||
|
||||
|
||||
def is_multimodal_model(model_architectures):
|
||||
if (
|
||||
"LlavaLlamaForCausalLM" in model_architectures
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_attention_free_model(model_architectures):
|
||||
return False
|
||||
|
||||
|
||||
def model_has_inner_state(model_architectures):
|
||||
return False
|
||||
|
||||
|
||||
def is_embedding_model(model_architectures):
|
||||
if (
|
||||
"LlamaEmbeddingModel" in model_architectures
|
||||
or "MistralModel" in model_architectures
|
||||
or "LlamaForSequenceClassification" in model_architectures
|
||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_generation_model(model_architectures, is_embedding: bool = False):
|
||||
# We have two ways to determine whether a model is a generative model.
|
||||
# 1. Check the model architectue
|
||||
# 2. check the `is_embedding` server args
|
||||
|
||||
if (
|
||||
"LlamaEmbeddingModel" in model_architectures
|
||||
or "MistralModel" in model_architectures
|
||||
or "LlamaForSequenceClassification" in model_architectures
|
||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return not is_embedding
|
||||
|
||||
|
||||
def decode_video_base64(video_base64):
|
||||
from PIL import Image
|
||||
|
||||
|
||||
Reference in New Issue
Block a user