################################################################################
# Copyright (c) 2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Zhipu AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
#from typing import Any, Optional
#
#import torch
#from fastcore.basics import patch_to
#from transformers import Glm4Config
#
#import vllm
#from vllm.attention import Attention, AttentionType
#from vllm.config import CacheConfig
#from vllm.distributed import get_tensor_model_parallel_world_size
#from vllm.model_executor.layers.linear import (QKVParallelLinear,
# RowParallelLinear)
#from vllm.model_executor.layers.quantization import QuantizationConfig
#from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
# RotaryEmbedding)
#from vllm.model_executor.models.glm4 import Glm4Attention
#
#_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
#def get_rope_0_9_2(
# head_size: int,
# rotary_dim: int,
# max_position: int,
# base: float,
# is_neox_style: bool = True,
# rope_scaling: Optional[dict[str, Any]] = None,
# dtype: Optional[torch.dtype] = None,
# partial_rotary_factor: float = 1.0,
# dual_chunk_attention_config: Optional[dict[str, Any]] = None,
#) -> RotaryEmbedding:
#
# if dtype is None:
# dtype = torch.get_default_dtype()
# if rope_scaling is not None:
# # Transforms every value that is a list into a tuple for caching calls
# rope_scaling_tuple = {
# k: tuple(v) if isinstance(v, list) else v
# for k, v in rope_scaling.items()
# }
# rope_scaling_args = tuple(rope_scaling_tuple.items())
# else:
# rope_scaling_args = None
#
# if dual_chunk_attention_config is not None:
# dual_chunk_attention_tuple = {
# k: tuple(v) if isinstance(v, list) else v
# for k, v in dual_chunk_attention_config.items()
# if k != "sparse_attention_config"
# }
# dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
# else:
# dual_chunk_attention_args = None
#
# if partial_rotary_factor < 1.0:
# rotary_dim = int(rotary_dim * partial_rotary_factor)
# key = (head_size, rotary_dim, max_position, base, is_neox_style,
# rope_scaling_args, dual_chunk_attention_args, dtype)
# if key in _ROPE_DICT:
# return _ROPE_DICT[key]
#
# if dual_chunk_attention_config is not None:
# extra_kwargs = {
# k: v
# for k, v in dual_chunk_attention_config.items()
# if k in ("chunk_size", "local_size")
# }
# rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style, dtype,
# **extra_kwargs)
# elif not rope_scaling:
# rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
# is_neox_style, dtype)
# else:
# scaling_type = rope_scaling["rope_type"]
#
# if scaling_type == "llama3":
# scaling_factor = rope_scaling["factor"]
# low_freq_factor = rope_scaling["low_freq_factor"]
# high_freq_factor = rope_scaling["high_freq_factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style, dtype,
# scaling_factor, low_freq_factor,
# high_freq_factor,
# original_max_position)
# elif scaling_type == "mllama4":
# rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style, dtype)
# elif scaling_type == "default":
# if "mrope_section" in rope_scaling:
# rotary_emb = MRotaryEmbedding(
# head_size,
# rotary_dim,
# max_position,
# base,
# is_neox_style,
# dtype,
# mrope_section=rope_scaling["mrope_section"],
# )
# else:
# rotary_emb = RotaryEmbedding(
# head_size,
# rotary_dim,
# max_position,
# base,
# is_neox_style,
# dtype,
# )
# elif scaling_type == "linear":
# scaling_factor = rope_scaling["factor"]
# rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style,
# scaling_factor, dtype)
# elif scaling_type == "ntk":
# scaling_factor = rope_scaling["factor"]
# mixed_b = rope_scaling.get('mixed_b', None)
# rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style,
# scaling_factor, dtype,
# mixed_b)
# elif scaling_type == "dynamic":
# if "alpha" in rope_scaling:
# scaling_alpha = rope_scaling["alpha"]
# rotary_emb = DynamicNTKAlphaRotaryEmbedding(
# head_size, rotary_dim, max_position, base, is_neox_style,
# scaling_alpha, dtype)
# elif "factor" in rope_scaling:
# scaling_factor = rope_scaling["factor"]
# rotary_emb = DynamicNTKScalingRotaryEmbedding(
# head_size, rotary_dim, max_position, base, is_neox_style,
# scaling_factor, dtype)
# else:
# raise ValueError("Dynamic rope scaling must contain either "
# "'alpha' or 'factor' field")
# elif scaling_type == "yarn":
# scaling_factor = rope_scaling["factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
# "beta_slow")
# }
# rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
# original_max_position,
# base, is_neox_style,
# scaling_factor, dtype,
# **extra_kwargs)
# elif scaling_type == "deepseek_yarn":
# scaling_factor = rope_scaling["factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# # assert max_position == original_max_position * scaling_factor
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
# "beta_slow", "mscale", "mscale_all_dim")
# }
# rotary_emb = DeepseekScalingRotaryEmbedding(
# head_size, rotary_dim, original_max_position, base,
# is_neox_style, scaling_factor, dtype, **extra_kwargs)
# elif scaling_type == "longrope":
# short_factor = rope_scaling["short_factor"]
# long_factor = rope_scaling["long_factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("short_mscale", "long_mscale")
# }
# rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
# head_size, rotary_dim, max_position, original_max_position,
# base, is_neox_style, dtype, short_factor, long_factor,
# **extra_kwargs)
# else:
# raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
# _ROPE_DICT[key] = rotary_emb
# return rotary_emb
#
#
#@patch_to(vllm.model_executor.models.glm4.Glm4Attention)
#def __init__(self,
# config: Glm4Config,
# hidden_size: int,
# num_heads: int,
# num_kv_heads: int,
# max_position: int = 4096 * 32,
# head_dim: Optional[int] = None,
# qkv_bias: bool = False,
# rope_theta: float = 10000,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# rope_scaling: Optional[tuple] = None,
# prefix: str = "",
# attn_type: str = AttentionType.DECODER) -> None:
# super(Glm4Attention, self).__init__()
# self.hidden_size = hidden_size
# tp_size = get_tensor_model_parallel_world_size()
# self.total_num_heads = num_heads
# assert self.total_num_heads % tp_size == 0
# self.num_heads = self.total_num_heads // tp_size
# self.total_num_kv_heads = num_kv_heads
# if self.total_num_kv_heads >= tp_size:
# # Number of KV heads is greater than TP size, so we partition
# # the KV heads across multiple tensor parallel GPUs.
# assert self.total_num_kv_heads % tp_size == 0
# else:
# # Number of KV heads is less than TP size, so we replicate
# # the KV heads across multiple tensor parallel GPUs.
# assert tp_size % self.total_num_kv_heads == 0
# partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
# self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# self.head_dim = head_dim or hidden_size // self.total_num_heads
# self.rotary_dim = self.head_dim
# self.q_size = self.num_heads * self.head_dim
# self.kv_size = self.num_kv_heads * self.head_dim
# self.scaling = self.head_dim**-0.5
# self.rope_theta = rope_theta
# self.qkv_proj = QKVParallelLinear(
# hidden_size,
# self.head_dim,
# self.total_num_heads,
# self.total_num_kv_heads,
# bias=qkv_bias,
# quant_config=quant_config,
# prefix=f"{prefix}.qkv_proj",
# )
# self.o_proj = RowParallelLinear(
# self.total_num_heads * self.head_dim,
# hidden_size,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.o_proj",
# )
# self.rotary_emb = get_rope_0_9_2(
# self.head_dim,
# rotary_dim=self.rotary_dim,
# max_position=max_position,
# base=self.rope_theta,
# rope_scaling=rope_scaling,
# partial_rotary_factor=partial_rotary_factor,
# is_neox_style=False,
# )
# self.attn = Attention(self.num_heads,
# self.head_dim,
# self.scaling,
# num_kv_heads=self.num_kv_heads,
# cache_config=cache_config,
# quant_config=quant_config,
# prefix=f"{prefix}.attn",
# attn_type=attn_type)