################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Zhipu AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights.""" #from typing import Any, Optional # #import torch #from fastcore.basics import patch_to #from transformers import Glm4Config # #import vllm #from vllm.attention import Attention, AttentionType #from vllm.config import CacheConfig #from vllm.distributed import get_tensor_model_parallel_world_size #from vllm.model_executor.layers.linear import (QKVParallelLinear, # RowParallelLinear) #from vllm.model_executor.layers.quantization import QuantizationConfig #from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding, # RotaryEmbedding) #from vllm.model_executor.models.glm4 import Glm4Attention # #_ROPE_DICT: dict[tuple, RotaryEmbedding] = {} #def get_rope_0_9_2( # head_size: int, # rotary_dim: int, # max_position: int, # base: float, # is_neox_style: bool = True, # rope_scaling: Optional[dict[str, Any]] = None, # dtype: Optional[torch.dtype] = None, # partial_rotary_factor: float = 1.0, # dual_chunk_attention_config: Optional[dict[str, Any]] = None, #) -> RotaryEmbedding: # # if dtype is None: # dtype = torch.get_default_dtype() # if rope_scaling is not None: # # Transforms every value that is a list into a tuple for caching calls # rope_scaling_tuple = { # k: tuple(v) if isinstance(v, list) else v # for k, v in rope_scaling.items() # } # rope_scaling_args = tuple(rope_scaling_tuple.items()) # else: # rope_scaling_args = None # # if dual_chunk_attention_config is not None: # dual_chunk_attention_tuple = { # k: tuple(v) if isinstance(v, list) else v # for k, v in dual_chunk_attention_config.items() # if k != "sparse_attention_config" # } # dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items()) # else: # dual_chunk_attention_args = None # # if partial_rotary_factor < 1.0: # rotary_dim = int(rotary_dim * partial_rotary_factor) # key = (head_size, rotary_dim, max_position, base, is_neox_style, # rope_scaling_args, dual_chunk_attention_args, dtype) # if key in _ROPE_DICT: 
# return _ROPE_DICT[key] # # if dual_chunk_attention_config is not None: # extra_kwargs = { # k: v # for k, v in dual_chunk_attention_config.items() # if k in ("chunk_size", "local_size") # } # rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim, # max_position, base, # is_neox_style, dtype, # **extra_kwargs) # elif not rope_scaling: # rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, # is_neox_style, dtype) # else: # scaling_type = rope_scaling["rope_type"] # # if scaling_type == "llama3": # scaling_factor = rope_scaling["factor"] # low_freq_factor = rope_scaling["low_freq_factor"] # high_freq_factor = rope_scaling["high_freq_factor"] # original_max_position = rope_scaling[ # "original_max_position_embeddings"] # rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, # max_position, base, # is_neox_style, dtype, # scaling_factor, low_freq_factor, # high_freq_factor, # original_max_position) # elif scaling_type == "mllama4": # rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim, # max_position, base, # is_neox_style, dtype) # elif scaling_type == "default": # if "mrope_section" in rope_scaling: # rotary_emb = MRotaryEmbedding( # head_size, # rotary_dim, # max_position, # base, # is_neox_style, # dtype, # mrope_section=rope_scaling["mrope_section"], # ) # else: # rotary_emb = RotaryEmbedding( # head_size, # rotary_dim, # max_position, # base, # is_neox_style, # dtype, # ) # elif scaling_type == "linear": # scaling_factor = rope_scaling["factor"] # rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, # max_position, base, # is_neox_style, # scaling_factor, dtype) # elif scaling_type == "ntk": # scaling_factor = rope_scaling["factor"] # mixed_b = rope_scaling.get('mixed_b', None) # rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim, # max_position, base, # is_neox_style, # scaling_factor, dtype, # mixed_b) # elif scaling_type == "dynamic": # if "alpha" in rope_scaling: # scaling_alpha = 
rope_scaling["alpha"] # rotary_emb = DynamicNTKAlphaRotaryEmbedding( # head_size, rotary_dim, max_position, base, is_neox_style, # scaling_alpha, dtype) # elif "factor" in rope_scaling: # scaling_factor = rope_scaling["factor"] # rotary_emb = DynamicNTKScalingRotaryEmbedding( # head_size, rotary_dim, max_position, base, is_neox_style, # scaling_factor, dtype) # else: # raise ValueError("Dynamic rope scaling must contain either " # "'alpha' or 'factor' field") # elif scaling_type == "yarn": # scaling_factor = rope_scaling["factor"] # original_max_position = rope_scaling[ # "original_max_position_embeddings"] # extra_kwargs = { # k: v # for k, v in rope_scaling.items() # if k in ("extrapolation_factor", "attn_factor", "beta_fast", # "beta_slow") # } # rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, # original_max_position, # base, is_neox_style, # scaling_factor, dtype, # **extra_kwargs) # elif scaling_type == "deepseek_yarn": # scaling_factor = rope_scaling["factor"] # original_max_position = rope_scaling[ # "original_max_position_embeddings"] # # assert max_position == original_max_position * scaling_factor # extra_kwargs = { # k: v # for k, v in rope_scaling.items() # if k in ("extrapolation_factor", "attn_factor", "beta_fast", # "beta_slow", "mscale", "mscale_all_dim") # } # rotary_emb = DeepseekScalingRotaryEmbedding( # head_size, rotary_dim, original_max_position, base, # is_neox_style, scaling_factor, dtype, **extra_kwargs) # elif scaling_type == "longrope": # short_factor = rope_scaling["short_factor"] # long_factor = rope_scaling["long_factor"] # original_max_position = rope_scaling[ # "original_max_position_embeddings"] # extra_kwargs = { # k: v # for k, v in rope_scaling.items() # if k in ("short_mscale", "long_mscale") # } # rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( # head_size, rotary_dim, max_position, original_max_position, # base, is_neox_style, dtype, short_factor, long_factor, # **extra_kwargs) # else: # raise 
ValueError(f"Unknown RoPE scaling type {scaling_type}") # _ROPE_DICT[key] = rotary_emb # return rotary_emb # # #@patch_to(vllm.model_executor.models.glm4.Glm4Attention) #def __init__(self, # config: Glm4Config, # hidden_size: int, # num_heads: int, # num_kv_heads: int, # max_position: int = 4096 * 32, # head_dim: Optional[int] = None, # qkv_bias: bool = False, # rope_theta: float = 10000, # cache_config: Optional[CacheConfig] = None, # quant_config: Optional[QuantizationConfig] = None, # rope_scaling: Optional[tuple] = None, # prefix: str = "", # attn_type: str = AttentionType.DECODER) -> None: # super(Glm4Attention, self).__init__() # self.hidden_size = hidden_size # tp_size = get_tensor_model_parallel_world_size() # self.total_num_heads = num_heads # assert self.total_num_heads % tp_size == 0 # self.num_heads = self.total_num_heads // tp_size # self.total_num_kv_heads = num_kv_heads # if self.total_num_kv_heads >= tp_size: # # Number of KV heads is greater than TP size, so we partition # # the KV heads across multiple tensor parallel GPUs. # assert self.total_num_kv_heads % tp_size == 0 # else: # # Number of KV heads is less than TP size, so we replicate # # the KV heads across multiple tensor parallel GPUs. 
# assert tp_size % self.total_num_kv_heads == 0 # partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5) # self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) # self.head_dim = head_dim or hidden_size // self.total_num_heads # self.rotary_dim = self.head_dim # self.q_size = self.num_heads * self.head_dim # self.kv_size = self.num_kv_heads * self.head_dim # self.scaling = self.head_dim**-0.5 # self.rope_theta = rope_theta # self.qkv_proj = QKVParallelLinear( # hidden_size, # self.head_dim, # self.total_num_heads, # self.total_num_kv_heads, # bias=qkv_bias, # quant_config=quant_config, # prefix=f"{prefix}.qkv_proj", # ) # self.o_proj = RowParallelLinear( # self.total_num_heads * self.head_dim, # hidden_size, # bias=False, # quant_config=quant_config, # prefix=f"{prefix}.o_proj", # ) # self.rotary_emb = get_rope_0_9_2( # self.head_dim, # rotary_dim=self.rotary_dim, # max_position=max_position, # base=self.rope_theta, # rope_scaling=rope_scaling, # partial_rotary_factor=partial_rotary_factor, # is_neox_style=False, # ) # self.attn = Attention(self.num_heads, # self.head_dim, # self.scaling, # num_kv_heads=self.num_kv_heads, # cache_config=cache_config, # quant_config=quant_config, # prefix=f"{prefix}.attn", # attn_type=attn_type)