"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import os
from enum import IntEnum, auto
from typing import Optional

from transformers import PretrainedConfig

from sglang.srt.hf_transformers_utils import get_config, get_context_length

logger = logging.getLogger(__name__)


class AttentionArch(IntEnum):
    MLA = auto()  # Multi-head Latent Attention (e.g. DeepSeek-V2, MiniCPM3)
    MHA = auto()  # standard Multi-Head Attention


class ModelConfig:
    def __init__(
        self,
        path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[dict] = None,
    ) -> None:
        self.path = path
        self.trust_remote_code = trust_remote_code
        self.revision = revision
        self.model_override_args = model_override_args
        self.hf_config = get_config(
            self.path,
            trust_remote_code,
            revision,
            model_override_args=model_override_args,
        )
2024-06-12 07:39:52 +08:00
self . hf_text_config = get_hf_text_config ( self . hf_config )
2024-10-25 03:40:36 +08:00
derived_context_len = get_context_length ( self . hf_text_config )
allow_long_context = os . environ . get (
" SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN " , None
)
2024-02-20 18:22:56 -06:00
if context_length is not None :
2024-10-25 03:40:36 +08:00
if context_length > derived_context_len :
if allow_long_context :
logger . warning (
f " Warning: User-specified context_length ( { context_length } ) is greater than the derived context_length ( { derived_context_len } ). "
f " This may lead to incorrect model outputs or CUDA errors. "
)
self . context_len = context_length
else :
raise ValueError (
f " User-specified context_length ( { context_length } ) is greater than the derived context_length ( { derived_context_len } ). "
f " This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model ' s config. "
f " To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 "
)
else :
self . context_len = context_length
2024-02-20 18:22:56 -06:00
else :
2024-10-25 03:40:36 +08:00
self . context_len = derived_context_len
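
        # Illustrative example (hypothetical numbers, not executed): if the
        # derived context length is 4096, passing context_length=2048 uses
        # 2048 as-is, while context_length=8192 raises a ValueError unless
        # SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1 is set in the environment.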

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )
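        # For example (hypothetical values): hidden_size=4096 with
        # num_attention_heads=32 falls back to head_dim = 4096 // 32 = 128.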

        # FIXME: temporary special case for the DeepSeek-V2 MLA architecture
        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # For Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads

        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.vocab_size = self.hf_text_config.vocab_size

        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)
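
    # Illustrative arithmetic (not executed): with 8 total KV heads and
    # tensor_parallel_size=4, each GPU gets 8 // 4 = 2 KV heads; with 2 total
    # KV heads and tensor_parallel_size=8, max(1, 2 // 8) = 1, so each GPU
    # holds a replicated copy of a single KV head.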


def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to the LLM for multi-modal models.

    No-op for pure text models.
    """
    class_name = config.architectures[0]
    if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
        # We support non-hf versions of llava models, so we do not want to
        # read the wrong values from the unused default text_config.
        return config

    if hasattr(config, "text_config"):
        # The code operates under the assumption that text_config should have
        # `num_attention_heads` (among others). Assert here to fail early
        # if the transformers config doesn't align with this assumption.
        assert hasattr(config.text_config, "num_attention_heads")
        return config.text_config
    else:
        return config
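

# The block below is an illustrative usage sketch, not part of the original
# module. The model identifier "your-org/your-model" and the printed fields are
# assumptions for demonstration; any local path or Hugging Face Hub identifier
# that `transformers` can resolve should work the same way.
if __name__ == "__main__":
    config = ModelConfig(
        path="your-org/your-model",  # hypothetical model path or Hub identifier
        context_length=None,  # use the context length derived from the HF config
    )
    print("context_len:", config.context_len)
    print("head_dim:", config.head_dim)
    print("attention_arch:", config.attention_arch)
    print("total KV heads:", config.get_total_num_kv_heads())
    print("KV heads per GPU (tp=2):", config.get_num_kv_heads(2))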