# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
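"""Per-architecture hooks that verify and update HuggingFace configs in-place.

Each hook is selected via ``MODELS_CONFIG_MAP`` (keyed by model class name)
and patches ``hf_config`` before vLLM loads the corresponding model.
"""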
from copy import deepcopy
from typing import TYPE_CHECKING

from vllm.logger import init_logger

if TYPE_CHECKING:
    from vllm.config import VllmConfig

logger = init_logger(__name__)


class VerifyAndUpdateConfig:
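    """Interface for hooks that verify and update a model's HF config."""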

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        raise NotImplementedError


class GteNewModelConfig(VerifyAndUpdateConfig):
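    """Hook for GTE ``NewConfig`` models: replace the plain GELU activation
    with GEGLU and collect rotary-embedding parameters into
    ``config.rotary_kwargs`` (keys presumably consumed when the model builds
    its rotary embedding)."""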

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NewConfig"
        assert config.hidden_act == "gelu"

        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
            "rope_scaling": getattr(config, "rope_scaling", None)
        }


class JinaRobertaModelConfig(VerifyAndUpdateConfig):
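    """Hook for Jina XLM-RoBERTa models: when rotary position embeddings are
    used, collect the rotary parameters into ``config.rotary_kwargs``."""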

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config

        if config.position_embedding_type == "rotary":
            assert config.__class__.__name__ == "XLMRobertaFlashConfig"

            head_dim = config.hidden_size // config.num_attention_heads
            config.rotary_kwargs = {
                "head_size": head_dim,
                "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
                "max_position": config.max_position_embeddings,
                "base": getattr(config, "rope_theta", config.rotary_emb_base),
                "rope_scaling": getattr(config, "rope_scaling", None)
            }


class NomicBertModelConfig(VerifyAndUpdateConfig):
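    """Hook for NomicBERT models: map the GPT-2-style config fields
    (``n_embd``, ``n_inner``, ``n_layer``, ...) onto the names vLLM expects,
    build ``config.rotary_kwargs``, and cap ``max_model_len`` at the trained
    position count unless context extension is configured via
    ``hf_overrides``."""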

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "NomicBertConfig"
        assert config.activation_function in ["swiglu", "gelu"]
        config.position_embedding_type = getattr(config,
                                                 "position_embedding_type",
                                                 "rope")

        if config.activation_function == "swiglu":
            config.hidden_act = "silu"
        else:
            config.hidden_act = config.activation_function

        assert (config.mlp_fc1_bias == config.mlp_fc2_bias ==
                config.qkv_proj_bias)
        config.bias = config.qkv_proj_bias

        assert config.rotary_emb_scale_base is None
        assert not config.rotary_emb_interleaved

        config.layer_norm_eps = config.layer_norm_epsilon
        config.intermediate_size = config.n_inner
        config.hidden_size = config.n_embd
        config.num_hidden_layers = config.n_layer

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = head_dim * config.rotary_emb_fraction
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": max_trained_positions,
            "base": getattr(config, "rope_theta", config.rotary_emb_base),
            "rope_scaling": getattr(config, "rope_scaling", None)
        }

        # config.rotary_scaling_factor is ignored so that, for datasets no
        # longer than max_trained_positions (2048), the results stay
        # consistent with SentenceTransformer.
        # Context extension instead uses vLLM-style rope_theta and
        # rope_scaling. See #17785 and #18755.
        if (not vllm_config.model_config.hf_overrides
                and vllm_config.model_config.original_max_model_len is None):
            # Default case: reset max_model_len to max_trained_positions.
            # For nomic-embed-text-v2-moe, the length is set to 512 by
            # sentence_bert_config.json.
            max_model_len_before = vllm_config.model_config.max_model_len
            max_model_len = min(vllm_config.model_config.max_model_len,
                                max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
                max_model_len_before, vllm_config.model_config.max_model_len)
        else:
            # Re-verify max_model_len to avoid lengths greater than what the
            # position embedding can cover.
            model_config = vllm_config.model_config
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", vllm_config.model_config.max_model_len)
            else:
                # hf_overrides_fn
                # This may be overridden by sentence_bert_config.json.
                max_model_len = vllm_config.model_config.max_model_len

            # Reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

            # sentence_bert_config.json takes precedence over
            # max_position_embeddings, so drop its max_seq_length here.
            encoder_config = deepcopy(model_config.encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            vllm_config.recalculate_max_model_len(max_model_len)


class Qwen3ForSequenceClassificationConfig(VerifyAndUpdateConfig):
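    """Hook for Qwen3 rerankers: when the original Qwen3 Reranker checkpoint
    is loaded, require the two classifier tokens and switch the scoring
    method to ``from_2_way_softmax``."""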

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config

        is_original_qwen3_reranker = getattr(config,
                                             "is_original_qwen3_reranker",
                                             False)

        if not is_original_qwen3_reranker:
            return

        tokens = getattr(config, "classifier_from_token", None)
        assert tokens is not None and len(tokens) == 2, \
            ("Trying to load the original Qwen3 Reranker? See: "
             "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/qwen3_reranker.py")
        vllm_config.model_config.hf_config.method = "from_2_way_softmax"


class SnowflakeGteNewModelConfig(VerifyAndUpdateConfig):
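    """Hook for Snowflake GTE models (``GteConfig``): same GELU-to-GEGLU swap
    and ``rotary_kwargs`` collection as ``GteNewModelConfig``, but for the
    Snowflake config class."""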

    @staticmethod
    def verify_and_update_config(vllm_config: "VllmConfig") -> None:
        config = vllm_config.model_config.hf_config

        assert config.__class__.__name__ == "GteConfig"
        assert config.hidden_act == "gelu"

        config.hidden_act = "geglu"

        head_dim = config.hidden_size // config.num_attention_heads
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
            "max_position": config.max_position_embeddings,
            "base": config.rope_theta,
            "rope_scaling": getattr(config, "rope_scaling", None)
        }


MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
    "GteModel": SnowflakeGteNewModelConfig,
    "GteNewModel": GteNewModelConfig,
    "NomicBertModel": NomicBertModelConfig,
    "Qwen3ForSequenceClassification": Qwen3ForSequenceClassificationConfig,
    "XLMRobertaModel": JinaRobertaModelConfig,
}
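
# Illustrative usage sketch (an assumption, not the actual call site in vLLM):
# the loader is expected to look up the model's architecture name and apply
# the matching hook before the model is instantiated, along the lines of
#
#     cls = MODELS_CONFIG_MAP.get(model_arch)  # model_arch: hypothetical name
#     if cls is not None:
#         cls.verify_and_update_config(vllm_config)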