"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from enum import IntEnum, auto
from typing import Optional

from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length


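# MLA = Multi-head Latent Attention (the DeepSeek-V2 / MiniCPM3 scheme that caches a
# compressed KV latent); MHA = the standard multi-head / grouped-query attention path.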
class AttentionArch(IntEnum):
    MLA = auto()
    MHA = auto()


class ModelConfig:
    def __init__(
        self,
        path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[dict] = None,
    ) -> None:
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.model_override_args = model_override_args
        self.hf_config = get_config(
            self.path,
            trust_remote_code,
            revision,
            model_override_args=model_override_args,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)
        if context_length is not None:
            self.context_len = context_length
        else:
            self.context_len = get_context_length(self.hf_config)

        # Unify the config keys for hf_config
        self.head_dim = getattr(
            self.hf_config,
            "head_dim",
            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
        )
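        # For example, a config with hidden_size=4096 and num_attention_heads=32
        # (Llama-2-7B-like) falls back to head_dim=128 here.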

        # FIXME: temporary special-case handling for the DeepSeek-V2 MLA architecture
        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_config.num_attention_heads
        self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_config.hidden_size
        self.num_hidden_layers = self.hf_config.num_hidden_layers
        self.vocab_size = self.hf_config.vocab_size

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
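        # For example, 8 total KV heads with tensor_parallel_size=2 gives 4 per GPU,
        # while tensor_parallel_size=16 replicates them so each GPU still sees 1.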
        return max(1, total_num_kv_heads // tensor_parallel_size)


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multi-modal models.
    No-op for pure text models.
    """
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        # We support the non-HF version of LLaVA models, so we do not want to
        # read the wrong values from the unused default text_config.
        return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
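

# Minimal usage sketch, assuming an HF-format checkpoint is reachable under a
# hypothetical id such as "meta-llama/Llama-2-7b-hf"; it simply exercises
# ModelConfig and its KV-head helpers.
if __name__ == "__main__":
    config = ModelConfig("meta-llama/Llama-2-7b-hf", context_length=4096)
    print("context_len:", config.context_len)
    print("head_dim:", config.head_dim)
    print("total KV heads:", config.get_total_num_kv_heads())
    # With tensor parallelism, KV heads are split (or replicated) across GPUs.
    print("KV heads per GPU (tp=2):", config.get_num_kv_heads(2))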