2024-11-22 22:16:53 +08:00
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
2024-07-28 23:07:12 +10:00
2024-11-03 12:25:39 -08:00
import json
2024-10-25 03:40:36 +08:00
import logging
2024-08-05 01:40:33 +08:00
from enum import IntEnum , auto
2024-12-28 02:59:56 +08:00
from typing import List , Optional , Set , Union
2024-01-08 04:37:50 +00:00
2024-12-02 23:22:13 +08:00
import torch
2024-06-12 07:39:52 +08:00
from transformers import PretrainedConfig
2024-01-08 04:37:50 +00:00
2024-06-12 21:48:40 -07:00
from sglang . srt . hf_transformers_utils import get_config , get_context_length
2024-12-02 23:22:13 +08:00
from sglang . srt . layers . quantization import QUANTIZATION_METHODS
from sglang . srt . utils import get_bool_env_var , is_hip
2024-06-12 21:48:40 -07:00
2024-10-25 03:40:36 +08:00
logger = logging . getLogger ( __name__ )
2024-01-08 04:37:50 +00:00
2024-08-05 01:40:33 +08:00
class AttentionArch(IntEnum):
    """How the model materializes attention KV state.

    MLA (Multi-head Latent Attention, e.g. DeepSeek-V2/V3, MiniCPM3) stores a
    compressed latent KV; MHA is the standard multi-head attention layout.
    """

    # Explicit values match what `auto()` would assign (1, 2), so existing
    # serialized/compared values are unchanged.
    MLA = 1
    MHA = 2
2024-01-08 04:37:50 +00:00
class ModelConfig:
    """Parsed model configuration for the SGLang runtime.

    Loads the HuggingFace config for `model_path`, resolves dtype and
    quantization, and derives the runtime attributes (context length, head
    counts, attention architecture, special token ids).
    """

    def __init__(
        self,
        model_path: str,
        trust_remote_code: bool = True,
        revision: Optional[str] = None,
        context_length: Optional[int] = None,
        model_override_args: Optional[dict] = None,
        is_embedding: Optional[bool] = None,
        dtype: str = "auto",
        quantization: Optional[str] = None,
    ) -> None:
        self.model_path = model_path
        self.revision = revision
        self.quantization = quantization

        # Parse args.
        # Bug fix: the previous code called json.loads(model_override_args)
        # unconditionally, which raises TypeError for the default None and
        # fails for an already-parsed dict (the declared parameter type).
        # Accept None, a dict, or a JSON string.
        if model_override_args is None:
            self.model_override_args = {}
        elif isinstance(model_override_args, dict):
            self.model_override_args = model_override_args
        else:
            self.model_override_args = json.loads(model_override_args)

        self.hf_config = get_config(
            model_path,
            trust_remote_code=trust_remote_code,
            revision=revision,
            model_override_args=self.model_override_args,
        )
        self.hf_text_config = get_hf_text_config(self.hf_config)

        # Check model type
        self.is_generation = is_generation_model(
            self.hf_config.architectures, is_embedding
        )
        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

        # Derive context length. A user-specified value larger than the one
        # derived from the HF config is only honored behind an env-var
        # escape hatch, because it can produce wrong outputs or CUDA errors.
        derived_context_len = get_context_length(self.hf_text_config)
        if context_length is not None:
            if context_length > derived_context_len:
                if get_bool_env_var("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"):
                    logger.warning(
                        f"Warning: User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors."
                    )
                    self.context_len = context_length
                else:
                    raise ValueError(
                        f"User-specified context_length ({context_length}) is greater than the derived context_length ({derived_context_len}). "
                        f"This may lead to incorrect model outputs or CUDA errors. Note that the derived context_length may differ from max_position_embeddings in the model's config. "
                        f"To allow overriding this maximum, set the env var SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN=1"
                    )
            else:
                self.context_len = context_length
        else:
            self.context_len = derived_context_len

        # Unify the config keys for hf_text_config
        self.head_dim = getattr(
            self.hf_text_config,
            "head_dim",
            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
        )

        # FIXME: temporary special judge for MLA architecture
        if (
            "DeepseekV2ForCausalLM" in self.hf_config.architectures
            or "DeepseekV3ForCausalLM" in self.hf_config.architectures
        ):
            self.head_dim = 256
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
            self.head_dim = 128
            self.attention_arch = AttentionArch.MLA
            self.kv_lora_rank = self.hf_config.kv_lora_rank
            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
        else:
            self.attention_arch = AttentionArch.MHA

        self.num_attention_heads = self.hf_text_config.num_attention_heads
        self.num_key_value_heads = getattr(
            self.hf_text_config, "num_key_value_heads", None
        )

        # for Dbrx and MPT models
        if self.hf_config.model_type in ["dbrx", "mpt"]:
            self.num_key_value_heads = getattr(
                self.hf_config.attn_config, "kv_n_heads", None
            )

        if self.num_key_value_heads is None:
            self.num_key_value_heads = self.num_attention_heads
        self.hidden_size = self.hf_text_config.hidden_size
        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
        self.vocab_size = self.hf_text_config.vocab_size

        # Verify quantization
        self._verify_quantization()

        # Text attrs
        self.hf_eos_token_id = self.get_hf_eos_token_id()

        # Multimodal attrs
        self.image_token_id = getattr(self.hf_config, "image_token_id", None)

    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False)
        )
        if not new_decoder_arch_falcon and getattr(
            self.hf_text_config, "multi_query", False
        ):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # For DBRX and MPT
        if self.hf_config.model_type in ["mpt"]:
            if "kv_n_heads" in self.hf_config.attn_config:
                return self.hf_config.attn_config["kv_n_heads"]
            return self.hf_config.num_attention_heads
        if self.hf_config.model_type in ["dbrx"]:
            return getattr(
                self.hf_config.attn_config,
                "kv_n_heads",
                self.hf_config.num_attention_heads,
            )

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_text_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_text_config.num_attention_heads

    def get_num_kv_heads(self, tensor_parallel_size) -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1, total_num_kv_heads // tensor_parallel_size)

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _parse_quant_hf_config(self):
        """Return the quantization config embedded in the HF config, if any."""
        quant_cfg = getattr(self.hf_config, "quantization_config", None)
        if quant_cfg is None:
            # compressed-tensors uses a "compression_config" key
            quant_cfg = getattr(self.hf_config, "compression_config", None)
        return quant_cfg

    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
    def _verify_quantization(self) -> None:
        """Reconcile the requested quantization with the checkpoint's.

        Raises ValueError when the requested method is unknown, unsupported
        on ROCm, or contradicts the method recorded in the model config.
        """
        supported_quantization = [*QUANTIZATION_METHODS]
        rocm_supported_quantization = [
            "awq",
            "gptq",
            "fp8",
            "compressed_tensors",
            "compressed-tensors",
            "fbgemm_fp8",
        ]
        optimized_quantization_methods = [
            "fp8",
            "marlin",
            "modelopt",
            "gptq_marlin_24",
            "gptq_marlin",
            "awq_marlin",
            "fbgemm_fp8",
            "compressed_tensors",
            "compressed-tensors",
            "experts_int8",
        ]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        quant_cfg = self._parse_quant_hf_config()

        if quant_cfg is not None:
            quant_method = quant_cfg.get("quant_method", "").lower()

            # Detect which checkpoint is it
            for _, method in QUANTIZATION_METHODS.items():
                quantization_override = method.override_quantization_method(
                    quant_cfg, self.quantization
                )
                if quantization_override:
                    quant_method = quantization_override
                    self.quantization = quantization_override
                    break

            # Verify quantization configurations.
            if self.quantization is None:
                self.quantization = quant_method
            elif self.quantization != quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization})."
                )

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}."
                )
            if is_hip() and self.quantization not in rocm_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm."
                )
            if self.quantization not in optimized_quantization_methods:
                logger.warning(
                    "%s quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.",
                    self.quantization,
                )

    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        """Return the EOS token id(s) from the HF config as a set, or None."""
        eos_ids = getattr(self.hf_config, "eos_token_id", None)
        if eos_ids:
            # it can be either int or list of int
            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
        return eos_ids
2024-06-12 07:39:52 +08:00
def get_hf_text_config(config: PretrainedConfig):
    """Get the "sub" config relevant to llm for multi modal models.

    No op for pure text models.
    """
    architecture = config.architectures[0]

    # We support non-hf version of llava models, so we do not want to
    # read the wrong values from the unused default text_config.
    if architecture.startswith("Llava") and architecture.endswith("ForCausalLM"):
        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
        config.torch_dtype = torch.float16
        return config

    if not hasattr(config, "text_config"):
        return config

    # The code operates under the assumption that text_config should have
    # `num_attention_heads` (among others). Assert here to fail early
    # if transformers config doesn't align with this assumption.
    assert hasattr(config.text_config, "num_attention_heads")
    return config.text_config
2024-11-03 12:25:39 -08:00
2024-12-02 23:22:13 +08:00
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
_STR_DTYPE_TO_TORCH_DTYPE = {
" half " : torch . float16 ,
" float16 " : torch . float16 ,
" float " : torch . float32 ,
" float32 " : torch . float32 ,
" bfloat16 " : torch . bfloat16 ,
}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the checkpoint's dtype.

    "auto" follows the checkpoint dtype, except that float32 checkpoints are
    downcast to float16 (bfloat16 for Gemma 2). Any other cast between the
    requested and checkpoint dtype is logged; casting between float16 and
    bfloat16 additionally warns.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    elif isinstance(dtype, str):
        requested = dtype.lower()
        if requested == "auto":
            if config_dtype != torch.float32:
                torch_dtype = config_dtype
            elif config.model_type == "gemma2":
                logger.info(
                    "For Gemma 2, we downcast float32 to bfloat16 instead "
                    "of float16 by default. Please specify `dtype` if you "
                    "want to use float16."
                )
                torch_dtype = torch.bfloat16
            else:
                # Following the common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
        else:
            try:
                torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[requested]
            except KeyError:
                raise ValueError(f"Unknown dtype: {requested}")
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    # Verify the dtype and log any cast away from the checkpoint dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
    return torch_dtype
2024-11-03 12:25:39 -08:00
def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
    """Return True if the model should run as a generative model.

    Two signals decide this:
    1. the model architecture (known embedding/reward architectures are
       never generative);
    2. the `is_embedding` server argument.
    """
    embedding_archs = (
        "LlamaEmbeddingModel",
        "MistralModel",
        "LlamaForSequenceClassification",
        "LlamaForSequenceClassificationWithNormal_Weights",
        "InternLM2ForRewardModel",
    )
    if any(arch in model_architectures for arch in embedding_archs):
        return False
    return not is_embedding
def is_multimodal_model(model_architectures: List[str]):
    """Return True if any architecture is a known multimodal model."""
    multimodal_archs = frozenset(
        (
            "LlavaLlamaForCausalLM",
            "LlavaQwenForCausalLM",
            "LlavaMistralForCausalLM",
            "LlavaVidForCausalLM",
            "MllamaForConditionalGeneration",
            "Qwen2VLForConditionalGeneration",
        )
    )
    # True iff the two sets share at least one architecture name.
    return not multimodal_archs.isdisjoint(model_architectures)
def is_encoder_decoder_model(model_architectures: List[str]):
    """Return True if any architecture is a known encoder-decoder model."""
    encoder_decoder_archs = ("MllamaForConditionalGeneration",)
    return any(arch in model_architectures for arch in encoder_decoder_archs)