Unify the model type checking (#1905)

This commit is contained in:
Lianmin Zheng
2024-11-03 12:25:39 -08:00
committed by GitHub
parent c17c578108
commit 0abbf289a8
13 changed files with 146 additions and 160 deletions

View File

@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import json
import logging
import os
from enum import IntEnum, auto
from typing import Optional
from typing import List, Optional
from transformers import PretrainedConfig
@@ -38,18 +39,24 @@ class ModelConfig:
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_override_args: Optional[dict] = None,
is_embedding: Optional[bool] = None
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.model_override_args = model_override_args
# Parse args
self.model_override_args = json.loads(model_override_args)
self.hf_config = get_config(
self.path,
trust_remote_code,
revision,
model_override_args=model_override_args,
path,
trust_remote_code=trust_remote_code,
revision=revision,
model_override_args=self.model_override_args,
)
self.hf_text_config = get_hf_text_config(self.hf_config)
# Check model type
self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
# Derive context length
derived_context_len = get_context_length(self.hf_text_config)
allow_long_context = os.environ.get(
"SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
@@ -81,7 +88,7 @@ class ModelConfig:
self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
)
# FIXME: temporary special judge for deepseek v2 MLA architecture
# FIXME: temporary special judge for MLA architecture
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
@@ -112,8 +119,6 @@ class ModelConfig:
self.num_hidden_layers = self.hf_text_config.num_hidden_layers
self.vocab_size = self.hf_text_config.vocab_size
self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
def get_total_num_kv_heads(self) -> int:
"""Returns the total number of KV heads."""
@@ -163,7 +168,6 @@ class ModelConfig:
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads
# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
def get_num_kv_heads(self, tensor_parallel_size) -> int:
"""Returns the number of KV heads per GPU."""
total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
return config.text_config
else:
return config
def is_generation_model(
    model_architectures: Optional[List[str]],
    is_embedding: Optional[bool] = False,
) -> bool:
    """Return whether the model should be served as a generative model.

    Two signals are combined:

    1. The model architecture: the architectures listed below are never
       treated as generative.
    2. The ``is_embedding`` server argument: for any other architecture the
       model is generative unless ``is_embedding`` is truthy.

    Args:
        model_architectures: Architecture names taken from the model config.
            ``None`` is treated the same as an empty list.
        is_embedding: Value of the ``is_embedding`` server argument.
            ``None`` is treated as ``False``.

    Returns:
        ``True`` if the model is a generation model, ``False`` otherwise.
    """
    non_generative_architectures = (
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
    )
    # Guard against a missing architecture list so the membership test below
    # cannot raise TypeError.
    architectures = model_architectures or ()

    if any(name in architectures for name in non_generative_architectures):
        return False
    return not is_embedding
def is_multimodal_model(model_architectures: Optional[List[str]]) -> bool:
    """Return whether any of the given architectures is a multimodal model.

    Args:
        model_architectures: Architecture names taken from the model config.
            ``None`` is treated the same as an empty list.

    Returns:
        ``True`` if at least one architecture is in the multimodal list
        below, ``False`` otherwise.
    """
    multimodal_architectures = (
        "LlavaLlamaForCausalLM",
        "LlavaQwenForCausalLM",
        "LlavaMistralForCausalLM",
        "LlavaVidForCausalLM",
        "MllamaForConditionalGeneration",
        "Qwen2VLForConditionalGeneration",
    )
    # Guard against a missing architecture list so the membership test below
    # cannot raise TypeError.
    architectures = model_architectures or ()
    return any(name in architectures for name in multimodal_architectures)
def is_encoder_decoder_model(model_architectures: Optional[List[str]]) -> bool:
    """Return whether the given architectures include an encoder-decoder model.

    Only ``MllamaForConditionalGeneration`` is recognized.

    Args:
        model_architectures: Architecture names taken from the model config.
            ``None`` is treated the same as an empty list.

    Returns:
        ``True`` if ``MllamaForConditionalGeneration`` is listed, ``False``
        otherwise.
    """
    # Guard against a missing architecture list so the membership test
    # cannot raise TypeError.
    return "MllamaForConditionalGeneration" in (model_architectures or ())