796 lines
30 KiB
Python
796 lines
30 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
# Adapted from
|
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
|
|
# Copyright 2025 The vLLM team.
|
|
# Copyright 2025 The ZhipuAI Team.
|
|
# Copyright 2025 The HuggingFace Inc. team.
|
|
# All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
|
|
|
|
import math
|
|
from collections.abc import Iterable, Mapping
|
|
from functools import partial
|
|
from typing import Callable, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch_br
|
|
from einops import rearrange, repeat
|
|
from torch_br.contrib import SueagerScaledDotProductAttention
|
|
|
|
import vllm
|
|
import vllm.model_executor.models.glm4
|
|
import vllm.model_executor.models.llama
|
|
import vllm.model_executor.models.qwen2_vl
|
|
import vllm_br.envs as br_envs
|
|
from vllm.attention.layer import check_upstream_fa_availability
|
|
from vllm.config import VllmConfig
|
|
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
|
parallel_state)
|
|
from vllm.distributed import utils as dist_utils
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
default_weight_loader, maybe_remap_kv_scale_name)
|
|
from vllm.model_executor.models.glm4_1v import (Glm4vForConditionalGeneration,
|
|
Glm4vVisionBlock,
|
|
Glm4vVisionMLP,
|
|
Glm4vVisionTransformer)
|
|
from vllm.model_executor.models.utils import (init_vllm_registered_model,
|
|
is_pp_missing_parameter,
|
|
maybe_prefix)
|
|
from vllm.model_executor.models.vision import get_vit_attn_backend
|
|
from vllm.platforms import _Backend, current_platform
|
|
from ..layers.activation import SiluAndMul
|
|
from ..layers.br_utils import is_br166_device
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
def Glm4vVisionMLP_init_fit(self,
                            in_features: int,
                            hidden_features: int,
                            bias: bool = False,
                            quant_config: Optional[QuantizationConfig] = None,
                            prefix: str = "",
                            use_data_parallel: bool = False):
    """Replacement ``__init__`` for ``Glm4vVisionMLP``.

    Builds a merged gate/up column-parallel projection, a row-parallel down
    projection and a ``SiluAndMul`` activation module.

    Args:
        in_features: Input (and output) width of the MLP.
        hidden_features: Width of each of the gate and up projections.
        bias: Whether both projections carry a bias term.
        quant_config: Quantization config forwarded to both projections.
        prefix: Parameter-name prefix of this module.
        use_data_parallel: Accepted for signature compatibility; this
            implementation does not read it.
    """
    super(Glm4vVisionMLP, self).__init__()

    # Options shared by both projections.
    linear_opts = dict(bias=bias, quant_config=quant_config)

    self.gate_up_proj = MergedColumnParallelLinear(
        input_size=in_features,
        output_sizes=[hidden_features, hidden_features],
        prefix=f"{prefix}.gate_up_proj",
        **linear_opts,
    )
    self.down_proj = RowParallelLinear(
        hidden_features,
        in_features,
        prefix=f"{prefix}.down_proj",
        **linear_opts,
    )
    self.act_fn = SiluAndMul()
|
|
|
|
|
|
def Glm4vVisionMLP_forward_fit(self, x: torch.Tensor):
    """Replacement ``forward`` for ``Glm4vVisionMLP``.

    Feeds ``x`` through ``gate_up_proj`` and then ``down_proj``, keeping only
    the first element of each projection's returned pair. ``self.act_fn`` is
    deliberately not called here.
    NOTE(review): presumably the activation is fused into the backend's
    ``gate_up_proj`` implementation -- confirm.
    """
    gate_up_out = self.gate_up_proj(x)[0]
    down_out = self.down_proj(gate_up_out)[0]
    return down_out
|
|
|
|
|
|
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` over the TP group and interleave the chunks.

    Every rank's tensor is cut along the last dimension into pieces of width
    ``hidden_size // tp_size``. The pieces are then concatenated so that the
    first piece of every rank comes first, followed by the second piece of
    every rank, and so on.

    Args:
        local_tensor: This rank's tensor.
        hidden_size: Width used to derive the per-piece size.
        tp_size: Number of ranks in the tensor-parallel group.

    Returns:
        The interleaved concatenation along the last dimension.
    """
    import torch.distributed as dist

    per_rank = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        per_rank,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )

    piece_width = hidden_size // tp_size
    pieces_by_rank = [torch.split(t, piece_width, -1) for t in per_rank]

    # zip(*...) groups the k-th piece of every rank together.
    interleaved = []
    for same_position in zip(*pieces_by_rank):
        interleaved.extend(same_position)

    return torch.cat(interleaved, dim=-1)
|
|
|
|
|
|
class Glm4vVisionAttention_fit(nn.Module):
    """Vision self-attention that uses plain ``nn.Linear`` projections.

    The QKV and output projections are ordinary ``nn.Linear`` layers (the
    tensor-parallel variants are kept only as commented-out code). The
    attention kernel is chosen from ``self.attn_backend``; the
    ``TORCH_SDPA`` branch calls the ``torch_br`` attention op rather than
    PyTorch's SDPA.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Create the projections and select the attention backend.

        Args:
            embed_dim: Input width of ``qkv`` and output width of ``proj``.
            num_heads: Total number of attention heads.
            projection_size: Combined width of all heads.
            quant_config: Not read by the active code (only referenced in
                the commented-out parallel layers).
            prefix: Not read by the active code (same reason).
            use_data_parallel: When true, TP size/rank are forced to 1/0.

        Raises:
            RuntimeError: If the resolved backend is not FLASH_ATTN,
                TORCH_SDPA or XFORMERS.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_rank())
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        # Tensor-parallel projections, disabled in favour of nn.Linear below.
        #self.qkv = QKVParallelLinear(
        #    hidden_size=embed_dim,
        #    head_size=self.hidden_size_per_attention_head,
        #    total_num_heads=num_heads,
        #    total_num_kv_heads=num_heads,
        #    bias=False,
        #    quant_config=quant_config,
        #    prefix=f"{prefix}.qkv",
        #)
        #self.proj = RowParallelLinear(
        #    input_size=projection_size,
        #    output_size=embed_dim,
        #    quant_config=quant_config,
        #    prefix=f"{prefix}.proj",
        #    bias=False,
        #)
        # Q, K and V each span num_heads heads, hence 3 * num_heads in total.
        qkv_output_size = (num_heads +
                           2 * num_heads) * self.hidden_size_per_attention_head
        self.qkv = nn.Linear(embed_dim, qkv_output_size, bias=False)
        self.proj = nn.Linear(projection_size, embed_dim, bias=False)
        # NOTE(review): this module is stored but forward() calls the
        # functional torch_br op directly and never uses it.
        self.sueager_attention = SueagerScaledDotProductAttention()

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        # self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        # Recorded for reference; not read anywhere else in this class.
        self.use_upstream_fa = False
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN,
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
        }:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now.")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused QKV tensor into per-head Q, K and V.

        Args:
            qkv: Fused projection output, unpacked as ``[s, b, width]``.

        Returns:
            ``(q, k, v)``, each viewed as
            ``[s, b, num_attention_heads_per_partition, head_dim]``.
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            # NOTE(review): self.qkv is an nn.Linear here, and nn.Linear does
            # not define `hidden_size`; this branch looks like it would raise
            # AttributeError whenever tp_size > 1 -- confirm.
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
                                        self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            # Keep only this rank's slice of the last dimension.
            splitter = partial(
                dist_utils.split_tensor_along_last_dim,
                num_partitions=self.tp_size,
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run attention over the packed sequences described by ``cu_seqlens``.

        Args:
            x: Input hidden states, laid out ``[s, b, c]``.
            cu_seqlens: Cumulative sequence boundaries; consecutive entries
                delimit one sequence.
            rotary_pos_emb: Rotary frequencies; skipped when ``None``.
            max_seqlen: Longest sequence length (FLASH_ATTN branch only).
            seqlens: Per-sequence lengths (XFORMERS branch only).

        Returns:
            The output of ``self.proj`` applied to the attention context.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        # x, _ = self.qkv(x)
        # nn.Linear returns a single tensor, unlike the parallel linear layers.
        x = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        # Read while q is still [s, b, ...], so dim 1 is the batch size.
        batch_size = q.shape[1]

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))

        if rotary_pos_emb is not None:
            q = glm_apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = glm_apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.attn_backend == _Backend.FLASH_ATTN:
            from flash_attn import flash_attn_varlen_func

            # The varlen kernel takes batch and sequence flattened together.
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0,
                causal=False,
            )

            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            # Despite the backend name, each entry goes through the torch_br
            # attention op.
            outputs = []

            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(x, "b s h d -> s h b d")
                                 for x in [q_i, k_i, v_i])
                # squeeze() drops every size-1 dimension.
                # NOTE(review): this presumably assumes batch == 1 and that
                # no other dimension has size 1 -- confirm.
                output_i = torch_br.sueager_scaled_dot_product_attention_fwd(
                    q_i.squeeze(),
                    k_i.squeeze(),
                    v_i.squeeze(),
                    mask=None,
                    dropout_prob=0.0,
                    is_causal=False,
                    scale=1 / math.sqrt(q_i.shape[-1]),
                    algorithm="FMHA",
                )[0]
                output_i = output_i.unsqueeze(0)
                if is_br166_device():
                    # Copy the result into a freshly allocated torch_br
                    # buffer before the rearrange below.
                    output_tmp = torch_br._empty_ut_only(output_i.shape,
                                                         "COLMAJOR",
                                                         is_numa=False,
                                                         sbp="BB",
                                                         axis=0,
                                                         dtype=torch.bfloat16)
                    output_tmp.copy_(output_i)
                    output_i = output_tmp
                # NOTE(review): the axis labels here depend on the layout the
                # torch_br op returns -- verify against its documentation.
                output_i = rearrange(output_i, "b s h d -> h b s d")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        # Merge heads and return to the sequence-first layout.
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

        # output, _ = self.proj(context_layer)
        output = self.proj(context_layer)
        return output
|
|
|
|
|
|
def Glm4vVisionBlock_init_fit(
    self,
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    norm_layer: Optional[Callable[[int], nn.Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """Replacement ``__init__`` for ``Glm4vVisionBlock``.

    Creates two norm layers, a ``Glm4vVisionAttention_fit`` attention module
    and a ``Glm4vVisionMLP``.

    Args:
        dim: Hidden width of the block.
        num_heads: Number of attention heads.
        mlp_hidden_dim: Hidden width of the MLP.
        norm_layer: Factory taking a width and returning a norm module;
            defaults to ``nn.LayerNorm`` with ``eps=1e-6``.
        quant_config: Quantization config forwarded to the sub-modules.
        prefix: Parameter-name prefix of this block.
        use_data_parallel: Forwarded to the attention and MLP modules.
    """
    super(Glm4vVisionBlock, self).__init__()

    make_norm = (norm_layer if norm_layer is not None else partial(
        nn.LayerNorm, eps=1e-6))
    self.norm1 = make_norm(dim)
    self.norm2 = make_norm(dim)

    # Options passed unchanged to both sub-modules.
    shared_opts = dict(quant_config=quant_config,
                       use_data_parallel=use_data_parallel)

    self.attn = Glm4vVisionAttention_fit(
        embed_dim=dim,
        num_heads=num_heads,
        projection_size=dim,
        prefix=f"{prefix}.attn",
        **shared_opts,
    )
    self.mlp = Glm4vVisionMLP(
        dim,
        mlp_hidden_dim,
        bias=False,
        prefix=f"{prefix}.mlp",
        **shared_opts,
    )
|
|
|
|
|
|
def Glm4vVisionBlock_forward_fit(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    """Replacement ``forward`` for ``Glm4vVisionBlock``.

    Applies pre-norm attention followed by a pre-norm MLP, each with a
    residual connection. ``rotary_pos_emb`` is moved to the current
    ``torch.supa`` device before being passed to the attention module.

    Returns:
        The updated hidden states.
    """
    attn_input = self.norm1(x)
    target_device = torch.supa.current_device()

    attn_output = self.attn(
        attn_input,
        cu_seqlens=cu_seqlens,
        rotary_pos_emb=rotary_pos_emb.to(target_device),
        max_seqlen=max_seqlen,
        seqlens=seqlens,
    )
    x = x + attn_output

    mlp_output = self.mlp(self.norm2(x))
    return x + mlp_output
|
|
|
|
|
|
def Llama_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint tensors into this model's parameters.

    Each incoming ``(name, tensor)`` pair is handled by the first rule that
    applies:

    1. Rotary-embedding caches (``inv_freq``, ``cos_cached``, ``sin_cached``)
       are skipped.
    2. KV-cache quantization scales reported by ``self.quant_config`` are
       loaded directly.
    3. Separate q/k/v and gate/up shards are loaded into the fused
       ``qkv_proj`` / ``gate_up_proj`` parameters.
    4. A fused ``gate_up_proj`` checkpoint tensor is split row-wise into
       separate ``gate_proj`` and ``up_proj`` parameters.
    5. Anything else is loaded under its own name.

    Args:
        weights: Iterable of ``(checkpoint name, tensor)`` pairs.

    Returns:
        The set of parameter names that were loaded.
    """
    # (fused parameter suffix, checkpoint shard suffix, shard id)
    fused_from_shards = [
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    # (fused checkpoint suffix, gate parameter suffix, up parameter suffix)
    split_from_fused = [
        (".gate_up_proj", ".gate_proj", ".up_proj"),
    ]

    params = dict(self.named_parameters())
    loaded: set[str] = set()

    for name, tensor in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Checkpoints trained with ColossalAI may carry these tensors;
            # they are not parameters, so ignore them.
            continue

        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # KV-cache quantization scale: load the scalar value.
            target = params[scale_name]
            load = getattr(target, "weight_loader", default_weight_loader)
            load(target, tensor if tensor.dim() == 0 else tensor[0])
            loaded.add(scale_name)
            continue

        if "scale" in name:
            # FP8 kv-scale names may need remapping.
            name = maybe_remap_kv_scale_name(name, params)
            if name is None:
                continue

        for fused_suffix, shard_suffix, shard_id in fused_from_shards:
            if shard_suffix not in name:
                continue
            # The renamed value is kept even if this entry is skipped below.
            name = name.replace(shard_suffix, fused_suffix)
            # Skip extra GPTQ biases and parameters owned by other PP ranks.
            if ((name.endswith(".bias") and name not in params)
                    or is_pp_missing_parameter(name, self)):
                continue

            target = params[name]
            target.weight_loader(target, tensor, shard_id)
            loaded.add(name)
            break
        else:
            # No shard rule consumed the tensor; try splitting a fused one.
            for fused_suffix, gate_suffix, up_suffix in split_from_fused:
                if fused_suffix not in name:
                    continue
                gate_name = name.replace(fused_suffix, gate_suffix)
                up_name = name.replace(fused_suffix, up_suffix)
                if ((name.endswith(".bias") and name not in params)
                        or is_pp_missing_parameter(name, self)):
                    continue

                gate_param = params[gate_name]
                up_param = params[up_name]
                gate_rows = gate_param.shape[0]
                assert tensor.shape[0] == gate_rows + up_param.shape[
                    0], "gate up shape is not match"

                # Leading rows belong to gate, the remainder to up.
                gate_param.weight_loader(gate_param, tensor[:gate_rows])
                up_param.weight_loader(up_param, tensor[gate_rows:])

                loaded.add(gate_name)
                loaded.add(up_name)
                break
            else:
                # Plain parameter: load under its (possibly renamed) name.
                skip = ((name.endswith(".bias") and name not in params)
                        or is_pp_missing_parameter(name, self))
                if not skip:
                    target = params[name]
                    load = getattr(target, "weight_loader",
                                   default_weight_loader)
                    load(target, tensor)
                    loaded.add(name)

    return loaded
|
|
|
|
|
|
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate the last dimension of ``x`` as used by rotary embeddings.

    In the default layout the last dimension is cut into two halves
    ``(a, b)`` and ``(-b, a)`` is returned. In the interleaved layout the
    even/odd elements form the pairs, and each pair ``(a, b)`` becomes
    ``(-b, a)`` in place.
    """
    if interleaved:
        evens = x[..., ::2]
        odds = x[..., 1::2]
        paired = torch.stack((-odds, evens), dim=-1)
        return rearrange(paired, "... d two -> ... (d two)", two=2)

    front, back = x.chunk(2, dim=-1)
    return torch.cat((-back, front), dim=-1)
|
|
|
|
|
|
def glm_apply_rotary_emb_torch(x: torch.Tensor,
                               cos: torch.Tensor,
                               sin: torch.Tensor,
                               interleaved: bool = False) -> torch.Tensor:
    """Apply rotary position embedding to the leading channels of ``x``.

    The first ``2 * cos.shape[-1]`` channels of the last dimension are
    rotated; any remaining channels are passed through unchanged.

    Args:
        x: Input tensor; rotation acts on its last dimension.
        cos: Cosine table whose last dimension is half the rotary width.
        sin: Sine table with the same shape as ``cos``.
        interleaved: Selects the pairing layout (see ``rotate_half``).

    NOTE(review): ``cos``/``sin`` gain one singleton axis from ``repeat`` and
    another from ``unsqueeze(2)``; confirm the caller's tensor shapes
    broadcast against ``x`` as intended.
    """
    rotary_width = 2 * cos.shape[-1]
    assert rotary_width <= x.shape[-1]

    # Duplicate each frequency so the table covers the full rotary width.
    pattern = ("... d -> ... 1 (d 2)"
               if interleaved else "... d -> ... 1 (2 d)")
    cos = repeat(cos, pattern)
    sin = repeat(sin, pattern)
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)

    to_rotate = x[..., :rotary_width]
    untouched = x[..., rotary_width:]
    rotated = to_rotate * cos + rotate_half(to_rotate, interleaved) * sin
    return torch.cat([rotated, untouched], dim=-1)
|
|
|
|
|
|
def glm_apply_rotary_pos_emb_vision(t: torch.Tensor,
                                    freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding built from ``freqs`` to ``t``.

    The rotation is computed on a float32 copy of ``t`` and cast back to
    ``t``'s dtype. On CUDA platforms the vLLM flash-attn rotary kernel is
    used; otherwise the pure-torch ``glm_apply_rotary_emb_torch`` is used.
    """
    t_fp32 = t.float()
    cos_table = freqs.cos()
    sin_table = freqs.sin()

    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import (apply_rotary_emb as
                                                        rotary_impl)
    else:
        rotary_impl = glm_apply_rotary_emb_torch

    return rotary_impl(t_fp32, cos_table, sin_table).type_as(t)
|
|
|
|
|
|
def LlamaMLP_glm4_1v_forward(self, x):
    """Replacement ``forward`` for ``LlamaMLP``.

    Chains ``gate_up_proj`` and ``down_proj``, using only the first element
    of each returned pair. No separate activation call is made here.
    NOTE(review): presumably the activation is fused into the backend's
    ``gate_up_proj`` -- confirm.
    """
    hidden = self.gate_up_proj(x)[0]
    return self.down_proj(hidden)[0]
|
|
|
|
|
|
def Glm4Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
):
    """Replacement ``forward`` for ``Glm4Attention``.

    Projects to fused QKV, splits it, applies rotary embedding to Q and K,
    runs attention and applies the output projection. On BR166 devices Q
    and K are copied into freshly allocated ``torch_br`` buffers before the
    rotary step (``"SB"``/axis 2, then ``"BB"``/axis 0) and again after it
    (``"SB"``/axis 2).
    """

    def _new_buffer(width: int, sbp: str, axis: int):
        # Buffer shaped like qkv in its first two dims, `width` in the last.
        return torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], width),
            "COLMAJOR",
            is_numa=False,
            sbp=sbp,
            axis=axis,
            dtype=torch.bfloat16)

    def _copy_qk(q_src, k_src, sbp: str, axis: int):
        # Allocate both buffers first, then copy, as the layout steps expect.
        q_buf = _new_buffer(self.q_size, sbp, axis)
        k_buf = _new_buffer(self.kv_size, sbp, axis)
        q_buf.copy_(q_src)
        k_buf.copy_(k_src)
        return q_buf, k_buf

    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    if is_br166_device():
        q, k = _copy_qk(q, k, "SB", 2)
        q, k = _copy_qk(q, k, "BB", 0)

    q, k = self.rotary_emb(positions, q, k)

    if is_br166_device():
        q, k = _copy_qk(q, k, "SB", 2)

    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
|
|
|
|
|
|
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """Return the maximum token count per multimodal item type.

    The video limit is computed for a single frame at the image size that
    yields the most features. ``seq_len`` and ``mm_counts`` are accepted for
    interface compatibility and are not read.

    Returns:
        A mapping with the keys ``"image"`` and ``"video"``.
    """
    image_limit = self.get_max_image_tokens()
    width, height = self.get_image_size_with_most_features()
    video_limit = self.get_num_video_tokens(image_width=width,
                                            image_height=height,
                                            num_frames=1)
    limits = {"image": image_limit}
    limits["video"] = video_limit
    return limits
|
|
|
|
|
|
def glm4v_init(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Replacement ``__init__`` for ``Glm4vForConditionalGeneration``.

    Builds the vision transformer and the language model from
    ``vllm_config``. As a side effect it also sets the module-level flag
    ``br_envs.VLLM_BR_USE_MROPE_0_9_2`` to ``True``.

    Args:
        vllm_config: Full vLLM configuration.
        prefix: Parameter-name prefix of this model.
    """
    super(Glm4vForConditionalGeneration, self).__init__()

    model_config = vllm_config.model_config
    config = model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = model_config.multimodal_config

    self.config = config
    self.multimodal_config = multimodal_config
    self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

    self.visual = Glm4vVisionTransformer(
        config.vision_config,
        norm_eps=getattr(config, "rms_norm_eps", 1e-5),
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
        use_data_parallel=self.use_data_parallel,
    )

    # Language-model architecture per model type; None for anything else.
    arch_by_model_type = {
        "glm4v": ["Glm4ForCausalLM"],
        "glm4v_moe": ["Glm4MoeForCausalLM"],
    }
    architectures = arch_by_model_type.get(config.model_type)

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=architectures)

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)

    # Process-wide switch read elsewhere in vllm_br.
    br_envs.VLLM_BR_USE_MROPE_0_9_2 = True
|
|
|
|
|
|
def Glm4vPatchMerger_forward(self, x: torch.Tensor):
    """Replacement ``forward`` for ``Glm4vPatchMerger``.

    Applies ``proj``, then the post-projection norm and extra activation,
    then ``gate_up_proj`` followed directly by ``down_proj``. On BR166
    devices the projected tensor is first copied into a freshly allocated
    ``torch_br`` buffer. ``self.act_fn`` is not called between the gate/up
    and down projections.
    NOTE(review): presumably the activation is fused into the backend's
    ``gate_up_proj`` -- confirm.
    """
    projected = self.proj(x)[0]

    if is_br166_device():
        relaid = torch_br._empty_ut_only(projected.shape,
                                         "COLMAJOR",
                                         is_numa=False,
                                         sbp="BB",
                                         axis=0,
                                         dtype=torch.bfloat16)
        relaid.copy_(projected)
        projected = relaid

    normed = self.post_projection_norm(projected)
    activated = self.extra_activation_func(normed)

    gate_up = self.gate_up_proj(activated)[0]
    merged = self.down_proj(gate_up)[0]
    return merged
|
|
|
|
|
|
def Glm4vVisionEmbeddings_forward(self, embeddings, lengths, image_shapes,
                                  h_coords, w_coords) -> torch.Tensor:
    """Add bicubically interpolated 2D position embeddings to ``embeddings``.

    The learned position table is viewed as a square grid and sampled with
    ``F.grid_sample`` at every patch's ``(w, h)`` coordinate, normalized by
    the width/height of the image the patch belongs to.

    Args:
        embeddings: Patch embeddings, one row per patch.
        lengths: Number of patches per item (list or tensor).
        image_shapes: Per-item shape rows; column 1 is used as the target
            height and column 2 as the target width.
        h_coords: Row coordinate of each patch.
        w_coords: Column coordinate of each patch.

    Returns:
        ``embeddings`` plus the adapted position embeddings.
    """
    table = self.position_embedding.weight
    table_device = table.device
    embed_width = table.shape[1]
    num_patches = h_coords.shape[0]

    # Coordinates must live on the same device as the table.
    h_coords = h_coords.to(table_device)
    w_coords = w_coords.to(table_device)

    if num_patches == 0:
        # Nothing to interpolate: add an empty position tensor.
        no_positions = torch.empty(0,
                                   embed_width,
                                   device=table_device,
                                   dtype=table.dtype)
        return embeddings + no_positions

    if isinstance(lengths, list):
        lengths = torch.tensor(lengths,
                               device=table_device,
                               dtype=torch.long)
    if not isinstance(image_shapes, torch.Tensor):
        image_shapes = torch.tensor(image_shapes,
                                    device=table_device,
                                    dtype=torch.long)

    # View the table as a (1, C, side, side) image for grid_sample.
    side = int(table.shape[0]**0.5)
    table_grid = table.view(side, side, embed_width)
    table_grid = table_grid.permute(2, 0, 1).unsqueeze(0)
    table_grid = table_grid.to(torch.float32)

    # Per-patch target height/width. When there are more items than shape
    # rows (data-parallel mode), the available rows are cycled through.
    num_items = len(lengths)
    num_shapes = image_shapes.shape[0]
    cycle_shapes = num_items > num_shapes
    per_item_h = []
    per_item_w = []
    for i in range(num_items):
        row = i % num_shapes if cycle_shapes else i
        per_item_h.append(image_shapes[row, 1].repeat(lengths[i]))
        per_item_w.append(image_shapes[row, 2].repeat(lengths[i]))
    target_h = torch.cat(per_item_h).to(device=table_device,
                                        dtype=torch.float32)
    target_w = torch.cat(per_item_w).to(device=table_device,
                                        dtype=torch.float32)

    # Map patch centres into grid_sample's [-1, 1] coordinate range.
    h_coords = h_coords.to(device=table_device, dtype=torch.float32)
    w_coords = w_coords.to(device=table_device, dtype=torch.float32)
    norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
    norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

    # grid_sample expects (N, H_out, W_out, 2) with (x, y) ordering.
    sample_grid = torch.stack((norm_w, norm_h), dim=-1)
    sample_grid = sample_grid.unsqueeze(0).unsqueeze(2)

    sampled = F.grid_sample(
        table_grid,
        sample_grid,
        mode="bicubic",
        align_corners=False,
        padding_mode="border",
    )

    # (1, C, P, 1) -> (P, C), then back to the table dtype and onto the
    # embeddings' device.
    per_patch = sampled.squeeze(0).squeeze(-1).permute(1, 0)
    per_patch = per_patch.to(table.dtype).to(embeddings.device)

    return embeddings + per_patch
|
|
|
|
|
|
# Install the replacements defined above onto the upstream vLLM classes.
# These assignments run at import time, so importing this module changes the
# behaviour of the patched classes for the whole process.

# Disabled: Llama_load_weights is defined above but not installed.
#LlamaModel.load_weights = Llama_load_weights
# Language-model side.
vllm.model_executor.models.llama.LlamaMLP.forward = LlamaMLP_glm4_1v_forward
vllm.model_executor.models.glm4.Glm4Attention.forward = Glm4Attention_forward
# Disabled: the attention class is swapped in through
# Glm4vVisionBlock_init_fit instead of replacing the upstream class.
#vllm.model_executor.models.glm4_1v.Glm4vVisionAttention = Glm4vVisionAttention_fit
# Vision tower.
vllm.model_executor.models.glm4_1v.Glm4vVisionBlock.__init__ = Glm4vVisionBlock_init_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionBlock.forward = Glm4vVisionBlock_forward_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionMLP.forward = Glm4vVisionMLP_forward_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionMLP.__init__ = Glm4vVisionMLP_init_fit
# Multimodal processing info and top-level model construction.
vllm.model_executor.models.glm4_1v.Glm4vProcessingInfo.get_mm_max_tokens_per_item = get_mm_max_tokens_per_item
vllm.model_executor.models.glm4_1v.Glm4vForConditionalGeneration.__init__ = glm4v_init
# Patch merger and position-embedding interpolation.
vllm.model_executor.models.glm4_1v.Glm4vPatchMerger.forward = Glm4vPatchMerger_forward
vllm.model_executor.models.glm4_1v.Glm4vVisionEmbeddings.forward = Glm4vVisionEmbeddings_forward
|