Unify the model type checking (#1905)
This commit is contained in:
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from enum import IntEnum, auto
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
@@ -38,18 +39,24 @@ class ModelConfig:
|
||||
revision: Optional[str] = None,
|
||||
context_length: Optional[int] = None,
|
||||
model_override_args: Optional[dict] = None,
|
||||
is_embedding: Optional[bool] = None
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.revision = revision
|
||||
self.model_override_args = model_override_args
|
||||
# Parse args
|
||||
self.model_override_args = json.loads(model_override_args)
|
||||
self.hf_config = get_config(
|
||||
self.path,
|
||||
trust_remote_code,
|
||||
revision,
|
||||
model_override_args=model_override_args,
|
||||
path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=revision,
|
||||
model_override_args=self.model_override_args,
|
||||
)
|
||||
self.hf_text_config = get_hf_text_config(self.hf_config)
|
||||
|
||||
# Check model type
|
||||
self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
|
||||
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
|
||||
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
|
||||
|
||||
# Derive context length
|
||||
derived_context_len = get_context_length(self.hf_text_config)
|
||||
allow_long_context = os.environ.get(
|
||||
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
|
||||
@@ -81,7 +88,7 @@ class ModelConfig:
|
||||
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
|
||||
)
|
||||
|
||||
# FIXME: temporary special judge for deepseek v2 MLA architecture
|
||||
# FIXME: temporary special judge for MLA architecture
|
||||
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
@@ -112,8 +119,6 @@ class ModelConfig:
|
||||
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
|
||||
self.vocab_size = self.hf_text_config.vocab_size
|
||||
|
||||
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
|
||||
def get_total_num_kv_heads(self) -> int:
|
||||
"""Returns the total number of KV heads."""
|
||||
@@ -163,7 +168,6 @@ class ModelConfig:
|
||||
# equal to the number of attention heads.
|
||||
return self.hf_text_config.num_attention_heads
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
|
||||
def get_num_kv_heads(self, tensor_parallel_size) -> int:
|
||||
"""Returns the number of KV heads per GPU."""
|
||||
total_num_kv_heads = self.get_total_num_kv_heads()
|
||||
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
|
||||
return config.text_config
|
||||
else:
|
||||
return config
|
||||
|
||||
|
||||
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
|
||||
# We have two ways to determine whether a model is a generative model.
|
||||
# 1. Check the model architectue
|
||||
# 2. check the `is_embedding` server args
|
||||
|
||||
if (
|
||||
"LlamaEmbeddingModel" in model_architectures
|
||||
or "MistralModel" in model_architectures
|
||||
or "LlamaForSequenceClassification" in model_architectures
|
||||
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return not is_embedding
|
||||
|
||||
|
||||
def is_multimodal_model(model_architectures: List[str]):
|
||||
if (
|
||||
"LlavaLlamaForCausalLM" in model_architectures
|
||||
or "LlavaQwenForCausalLM" in model_architectures
|
||||
or "LlavaMistralForCausalLM" in model_architectures
|
||||
or "LlavaVidForCausalLM" in model_architectures
|
||||
or "MllamaForConditionalGeneration" in model_architectures
|
||||
or "Qwen2VLForConditionalGeneration" in model_architectures
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]):
|
||||
return "MllamaForConditionalGeneration" in model_architectures
|
||||
|
||||
Reference in New Issue
Block a user