Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -1,40 +1,57 @@
# coding=utf-8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
# https://github.com/zai-org/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Iterable, List, Optional, Tuple
import json
from collections.abc import Iterable
from itertools import islice
import torch
from torch import nn
from torch.nn import LayerNorm
from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (
AutoWeightsLoader,
WeightsMapper,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory,
make_layers,
maybe_prefix,
)
class GLMAttention(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
config: ChatGLMConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.hidden_size = config.hidden_size
@@ -43,9 +60,11 @@ class GLMAttention(nn.Module):
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.multi_query_attention = config.multi_query_attention
self.total_num_kv_heads = (config.multi_query_group_num
if config.multi_query_attention else
config.num_attention_heads)
self.total_num_kv_heads = (
config.multi_query_group_num
if config.multi_query_attention
else config.num_attention_heads
)
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
@@ -67,48 +86,52 @@ class GLMAttention(nn.Module):
self.total_num_kv_heads,
bias=config.add_bias_linear or config.add_qkv_bias,
quant_config=quant_config,
prefix=f"{prefix}.query_key_value",
)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense",
)
# https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
# https://huggingface.co/zai-org/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
rope_ratio = getattr(config, "rope_ratio", 1.0)
max_positions = getattr(config, "seq_length", 8192)
rope_parameters = {
"rope_type": "default",
"rope_theta": 10000 * rope_ratio,
"partial_rotary_factor": 0.5,
}
# NOTE: zai-org/cogagent-9b-20241220 uses original_rope=False,
# which is equivalent to is_neox_style=True
is_neox_style = not config.original_rope
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim // 2,
max_position=max_positions,
base=10000 * rope_ratio,
is_neox_style=False,
rope_parameters=rope_parameters,
is_neox_style=is_neox_style,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
context_layer = self.attn(
q,
k,
v,
kv_cache,
attn_metadata,
)
context_layer = self.attn(q, k, v)
attn_output, _ = self.dense(context_layer)
return attn_output
@@ -123,8 +146,9 @@ class GLMMLP(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
config: ChatGLMConfig,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -136,6 +160,7 @@ class GLMMLP(nn.Module):
[config.ffn_hidden_size] * 2,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_h_to_4h",
)
self.activation_func = SiluAndMul()
@@ -146,6 +171,7 @@ class GLMMLP(nn.Module):
config.hidden_size,
bias=config.add_bias_linear,
quant_config=quant_config,
prefix=f"{prefix}.dense_4h_to_h",
)
def forward(self, hidden_states):
@@ -166,37 +192,42 @@ class GLMBlock(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
config: ChatGLMConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.apply_residual_connection_post_layernorm = (
config.apply_residual_connection_post_layernorm)
config.apply_residual_connection_post_layernorm
)
self.fp32_residual_connection = config.fp32_residual_connection
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Layernorm on the input data.
self.input_layernorm = layer_norm_func(config.hidden_size,
eps=config.layernorm_epsilon)
self.input_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon
)
# Self attention.
self.self_attention = GLMAttention(config, quant_config)
self.self_attention = GLMAttention(
config, cache_config, quant_config, prefix=f"{prefix}.self_attention"
)
self.hidden_dropout = config.hidden_dropout
# Layernorm on the attention output
self.post_attention_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
config.hidden_size, eps=config.layernorm_epsilon
)
# MLP
self.mlp = GLMMLP(config, quant_config)
self.mlp = GLMMLP(config, quant_config, prefix=f"{prefix}.mlp")
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# hidden_states: [num_tokens, h]
# Layer norm at the beginning of the transformer layer.
@@ -205,8 +236,6 @@ class GLMBlock(nn.Module):
attention_output = self.self_attention(
hidden_states=layernorm_output,
position_ids=position_ids,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
)
# Residual connection.
@@ -236,8 +265,10 @@ class GLMTransformer(nn.Module):
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
config: ChatGLMConfig,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
self.post_layer_norm = config.post_layer_norm
@@ -246,30 +277,36 @@ class GLMTransformer(nn.Module):
self.num_layers = config.num_layers
# Transformer layers.
self.layers = nn.ModuleList(
[GLMBlock(config, quant_config) for i in range(self.num_layers)])
self.start_layer, self.end_layer, self.layers = make_layers(
self.num_layers,
lambda prefix: GLMBlock(config, cache_config, quant_config, prefix=prefix),
prefix=f"{prefix}.layers",
)
if self.post_layer_norm:
layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm
# Final layer norm before output.
self.final_layernorm = layer_norm_func(
config.hidden_size, eps=config.layernorm_epsilon)
config.hidden_size, eps=config.layernorm_epsilon
)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
["hidden_states"], config.hidden_size
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
for i in range(self.num_layers):
layer = self.layers[i]
) -> torch.Tensor | IntermediateTensors:
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states = layer(
hidden_states=hidden_states,
position_ids=position_ids,
kv_cache=kv_caches[i],
attn_metadata=attn_metadata,
hidden_states=hidden_states, position_ids=position_ids
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
# Final layer norm.
if self.post_layer_norm:
hidden_states = self.final_layernorm(hidden_states)
@@ -277,110 +314,189 @@ class GLMTransformer(nn.Module):
return hidden_states
class ChatGLMModel(nn.Module):
@support_torch_compile
class ChatGLMModel(nn.Module, SupportsQuant):
packed_modules_mapping = {
"linear_proj.merged_proj": [
"linear_proj.gate_proj",
"linear_proj.dense_h_to_4h",
]
}
def __init__(
self,
config,
quant_config: Optional[QuantizationConfig] = None,
):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.embedding = VocabParallelEmbedding(config.padded_vocab_size,
config.hidden_size)
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.embedding = VocabParallelEmbedding(
config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embedding",
)
self.num_layers = config.num_layers
self.multi_query_group_num = config.multi_query_group_num
self.kv_channels = config.kv_channels
self.encoder = GLMTransformer(config, quant_config)
self.output_layer = ParallelLMHead(config.padded_vocab_size,
config.hidden_size)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
inputs_embeds = self.embedding(input_ids)
# Run encoder.
hidden_states = self.encoder(
hidden_states=inputs_embeds,
position_ids=position_ids,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
self.encoder = GLMTransformer(
config, cache_config, quant_config, prefix=f"{prefix}.encoder"
)
return hidden_states
self.output_layer = ParallelLMHead(
config.padded_vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.output_layer",
)
class ChatGLMForCausalLM(nn.Module):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"]
}
# LoRA specific attributes
supported_lora_modules = [
"query_key_value",
"dense",
"dense_h_to_4h",
"dense_4h_to_h",
]
embedding_modules = {}
embedding_padding_modules = []
self.make_empty_intermediate_tensors = (
self.encoder.make_empty_intermediate_tensors
)
def __init__(
self,
config: ChatGLMConfig,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config: ChatGLMConfig = config
self.quant_config = quant_config
self.transformer = ChatGLMModel(config, quant_config)
self.lm_head_weight = self.transformer.output_layer.weight
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.sampler = Sampler()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embedding(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata)
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
# Run encoder.
hidden_states = self.encoder(
hidden_states=hidden_states,
position_ids=positions,
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("linear_proj.merged_proj", "linear_proj.gate_proj", 0),
("linear_proj.merged_proj", "linear_proj.dense_h_to_4h", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if "rotary_pos_emb.inv_freq" in name:
continue
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class ChatGLMBaseModel(nn.Module):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".word_embeddings": ""},
)
def __init__(
self,
*,
vllm_config: VllmConfig,
prefix: str = "",
transformer_type: type[ChatGLMModel] = ChatGLMModel,
) -> None:
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.config = config
self.multimodal_config = multimodal_config
self.quant_config = quant_config
self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
self.transformer = transformer_type(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "transformer")
)
if self.config.tie_word_embeddings:
self.transformer.output_layer.weight = self.transformer.embedding.weight
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config.padded_vocab_size)
self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.embed_input_ids(input_ids)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
def sample(
self,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "rotary_pos_emb.inv_freq" in name:
continue
if "word_embeddings" in name:
name = name.replace(".word_embeddings", "")
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
class ChatGLMForCausalLM(ChatGLMBaseModel, SupportsLoRA, SupportsPP, SupportsQuant):
packed_modules_mapping = {
"query_key_value": ["query_key_value"],
"dense_h_to_4h": ["dense_h_to_4h"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
if hasattr(config, "vision_config"):
hf_overrides = {"architectures": ["GLM4VForCausalLM"]}
raise RuntimeError(
"The configuration of this model indicates that it supports "
"vision inputs, but you instantiated the text-only version "
"of this model. Please use the vision model by setting "
f"`--hf-overrides '{json.dumps(hf_overrides)}'`"
)
super().__init__(vllm_config=vllm_config, prefix=prefix)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
hidden_states = self.transformer(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states