first commit
This commit is contained in:
46
vllm_br/model_executor/models/__init__.py
Normal file
46
vllm_br/model_executor/models/__init__.py
Normal file
@@ -0,0 +1,46 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import (bert, bert_with_rope, chatglm, clip, config, deepseek_mtp,
|
||||
deepseek_v2, glm4, glm4_1v, glm4_moe, gpt_oss, intern_vit,
|
||||
internlm2, llama, qwen2, qwen2_5_vl, qwen2_vl, qwen3, qwen3_moe,
|
||||
qwen3_vl, qwen3_vl_moe, registry, utils)
|
||||
|
||||
# Public submodule API; kept in sync with the `from . import (...)` list above.
# Importing each submodule applies its monkey-patches to vLLM.
__all__ = [
    "bert_with_rope",
    "bert",
    "chatglm",
    "clip",
    "config",
    "deepseek_mtp",
    "deepseek_v2",
    "glm4_1v",
    "glm4_moe",
    "glm4",
    "gpt_oss",
    "intern_vit",
    "internlm2",
    "llama",
    "qwen2_5_vl",
    "qwen2_vl",
    "qwen2",
    "qwen3_moe",
    "qwen3",
    "qwen3_vl",
    "qwen3_vl_moe",
    "registry",
    "utils",
]
|
||||
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/bert.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/bert.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/clip.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/clip.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/config.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/glm4.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/glm4.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/llama.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/llama.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/qwen2.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/qwen2.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/qwen3.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/qwen3.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm_br/model_executor/models/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm_br/model_executor/models/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
42
vllm_br/model_executor/models/bert.py
Normal file
42
vllm_br/model_executor/models/bert.py
Normal file
@@ -0,0 +1,42 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.models.bert import BertModel
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
@patch_to(BertModel)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Patched BertModel.forward for this backend.

    Embeds `input_ids` (unless `inputs_embeds` is given), runs the
    encoder, and returns 2-D hidden states.
    """
    if inputs_embeds is None:
        # The attention module requires a leading batch dimension;
        # run with batch size 1 and strip it again after the encoder.
        batched_ids = input_ids.unsqueeze(0)
        hidden_states = self.embeddings(input_ids=batched_ids,
                                        position_ids=positions)
    else:
        hidden_states = inputs_embeds
    return self.encoder(hidden_states).squeeze(0)
|
||||
42
vllm_br/model_executor/models/bert_with_rope.py
Normal file
42
vllm_br/model_executor/models/bert_with_rope.py
Normal file
@@ -0,0 +1,42 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.models.bert_with_rope import BertWithRope
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
@patch_to(BertWithRope)
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Patched BertWithRope.forward for this backend.

    Raises:
        ValueError: when neither `inputs_embeds` nor `input_ids` is given.
    """
    if inputs_embeds is None:
        if input_ids is None:
            raise ValueError("input_ids must be provided.")
        # Add a batch dim of 1 for the embedding/attention modules;
        # it is stripped again after the encoder.
        hidden_states = self.embeddings(input_ids=input_ids.unsqueeze(0),
                                        token_type_ids=token_type_ids)
    else:
        hidden_states = inputs_embeds
    return self.encoder(positions, hidden_states).squeeze(0)
|
||||
48
vllm_br/model_executor/models/br_utils.py
Normal file
48
vllm_br/model_executor/models/br_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch_br
|
||||
|
||||
|
||||
def convBB(input_tensor):
    """Copy `input_tensor` into a fresh SUPA tensor created with sbp="BB".

    The new tensor keeps the shape, dtype, device and (lower-cased) layout
    reported by the SUPA debug API for the input.
    """
    tensor_info = torch_br.supa._debug.get_tensor_info(input_tensor)[0]
    layout = tensor_info["layout"].lower()
    result = torch_br._empty_ut_only(
        size=input_tensor.shape,
        dtype=input_tensor.dtype,
        is_numa=False,
        device=input_tensor.device,
        tensor_type=layout,
        sbp="BB",
    )
    result.copy_(input_tensor)
    return result
|
||||
|
||||
|
||||
def convSB(input_tensor, axis: int):
    """Copy `input_tensor` into a fresh SUPA tensor created with sbp="SB".

    `axis` is forwarded to the allocator as the split axis.  The new tensor
    keeps the shape, dtype, device and (lower-cased) layout reported by the
    SUPA debug API for the input.
    """
    tensor_info = torch_br.supa._debug.get_tensor_info(input_tensor)[0]
    layout = tensor_info["layout"].lower()
    result = torch_br._empty_ut_only(
        size=input_tensor.shape,
        dtype=input_tensor.dtype,
        is_numa=False,
        device=input_tensor.device,
        tensor_type=layout,
        sbp="SB",
        axis=axis,
    )
    result.copy_(input_tensor)
    return result
|
||||
195
vllm_br/model_executor/models/chatglm.py
Normal file
195
vllm_br/model_executor/models/chatglm.py
Normal file
@@ -0,0 +1,195 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/THUDM/ChatGLM2-6B
|
||||
"""Inference-only ChatGLM model compatible with THUDM weights."""
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
import vllm
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.models.chatglm import GLMMLP
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
|
||||
|
||||
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Replacement ChatGLMModel.forward.

    On the first pipeline rank the hidden states come from the embeddings
    (or from `inputs_embeds`); later ranks read them from
    `intermediate_tensors`.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = self.get_input_embeddings(input_ids)
    # The RMSNorm op expects a leading batch dimension.
    hidden_states = self.encoder(
        hidden_states=hidden_states.unsqueeze(0),
        position_ids=positions,
    )
    # Back to the 2-D token-major shape.
    return hidden_states.squeeze(0)
|
||||
|
||||
|
||||
class GLMAttention_fit(nn.Module):
    """Drop-in replacement for vLLM's GLMAttention used by this backend.

    Same construction and forward contract; the rotary embedding is built
    with the backend-specific "Chatglm2" op type.
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        tp_world_size = get_tensor_model_parallel_world_size()

        self.hidden_size = config.hidden_size
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_world_size == 0
        self.num_heads = self.total_num_heads // tp_world_size

        self.multi_query_attention = config.multi_query_attention
        if config.multi_query_attention:
            self.total_num_kv_heads = config.multi_query_group_num
        else:
            self.total_num_kv_heads = config.num_attention_heads
        if self.total_num_kv_heads >= tp_world_size:
            # Enough KV heads: partition them across the TP group.
            assert self.total_num_kv_heads % tp_world_size == 0
        else:
            # Fewer KV heads than ranks: replicate them instead.
            assert tp_world_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)

        self.head_dim = config.hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )

        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
            op_type="Chatglm2",
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Fused QKV projection -> RoPE -> attention -> output projection."""
        qkv, _ = self.query_key_value(hidden_states)
        query, key, value = qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1)
        query, key = self.rotary_emb(position_ids, query, key)
        context = self.attn(query, key, value)
        output, _ = self.dense(context)
        return output
|
||||
|
||||
|
||||
def GLMMLP__init__(
    self,
    config: ChatGLMConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Replacement constructor for vLLM's GLMMLP (installed below).

    Builds the fused h->4h projection, the activation, and the 4h->h
    projection with the same submodule names as upstream so weight
    loading is unaffected.
    """
    super(GLMMLP, self).__init__()

    self.add_bias = config.add_bias_linear

    # Fused gate/up projection: hidden_size -> 2 * ffn_hidden_size.
    self.dense_h_to_4h = MergedColumnParallelLinear(
        config.hidden_size,
        [config.ffn_hidden_size] * 2,
        bias=config.add_bias_linear,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_h_to_4h",
    )
    # NOTE(review): presumably read by the backend's linear kernels to
    # keep the activation out of the fused GEMM — confirm.
    self.dense_h_to_4h.no_fuse_act = True
    self.activation_func = SiluAndMul()

    # Projection back down to hidden_size.
    self.dense_4h_to_h = RowParallelLinear(
        config.ffn_hidden_size,
        config.hidden_size,
        bias=config.add_bias_linear,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_4h_to_h",
    )
|
||||
|
||||
|
||||
def GLMMLP__forward(self, hidden_states):
    """Replacement GLMMLP.forward: h->4h projection then 4h->h projection.

    NOTE(review): `self.activation_func` is not called here — presumably
    the activation is handled by the backend's linear kernels (see the
    `no_fuse_act` flag set in the constructor); confirm.
    """
    projected, _ = self.dense_h_to_4h(hidden_states)  # [s, b, 4hp]
    output, _ = self.dense_4h_to_h(projected)  # [s, b, h]
    return output
|
||||
|
||||
|
||||
# Install the backend-specific replacements into vLLM's ChatGLM module.
vllm.model_executor.models.chatglm.GLMMLP.__init__ = GLMMLP__init__
vllm.model_executor.models.chatglm.GLMMLP.forward = GLMMLP__forward
vllm.model_executor.models.chatglm.ChatGLMModel.forward = model_forward
vllm.model_executor.models.chatglm.GLMAttention = GLMAttention_fit
|
||||
65
vllm_br/model_executor/models/clip.py
Normal file
65
vllm_br/model_executor/models/clip.py
Normal file
@@ -0,0 +1,65 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""Minimal implementation of CLIPVisionModel intended to be only used
|
||||
within a vision language model."""
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
|
||||
from vllm.model_executor.models.clip import CLIPVisionEmbeddings
|
||||
|
||||
|
||||
def clip_vision_embeddings_forward(self,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
    """Replacement CLIPVisionEmbeddings.forward for the Biren backend.

    For patch_size == 14 a dedicated SUPA conv2d kernel is used (with the
    zero-workspace debug toggles temporarily enabled around it); otherwise
    the stock patch embedding runs.  `position_ids` is lazily migrated off
    the CPU on first use.
    """
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    if self.patch_size == 14:
        import torch_br.supa._debug as supa_debug

        supa_debug.set_disable_zero_ws(False)
        supa_debug.set_disable_zero_output_uma(False)
        supa_debug.set_disable_zero_output_numa(False)
        supa_debug.set_disable_reorder_zero(False)

        #TODO(shouqing): this op need to do internal clear_zeros operation
        patch_embeds = torch_br.supa_conv2d_knxn_snxn_p0x0_fwd(
            pixel_values.to(dtype=target_dtype), self.patch_embedding.weight,
            self.patch_size, self.patch_size, 0)

        supa_debug.set_disable_zero_ws(True)
        supa_debug.set_disable_zero_output_uma(True)
        supa_debug.set_disable_zero_output_numa(True)
        supa_debug.set_disable_reorder_zero(True)
    else:
        # shape = [*, width, grid, grid]
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    if self.position_ids.device == torch.device('cpu'):
        # Move the cached position ids onto the current SUPA device once.
        self.position_ids = self.position_ids.to(torch.supa.current_device())
    embeddings = embeddings + self.position_embedding(self.position_ids)

    return embeddings
|
||||
|
||||
|
||||
# Patch CLIPVisionEmbeddings to use the backend-specific forward above.
CLIPVisionEmbeddings.forward = clip_vision_embeddings_forward
|
||||
51
vllm_br/model_executor/models/config.py
Normal file
51
vllm_br/model_executor/models/config.py
Normal file
@@ -0,0 +1,51 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.config import DeepseekV32ForCausalLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@patch_to(DeepseekV32ForCausalLM)
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
    """
    Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
    """
    hf_config = vllm_config.model_config.hf_config

    # Mirror the check in vllm/model_executor/models/deepseek_v2.py
    assert hasattr(hf_config, "index_topk")

    cache_config = vllm_config.cache_config
    requested_dtype = cache_config.cache_dtype
    if requested_dtype.startswith("fp8"):
        # DeepSeekV3.2 uses a custom fp8 kv-cache layout.
        cache_config.cache_dtype = "fp8_ds_mla"
        logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
    elif requested_dtype == "bfloat16":
        cache_config.cache_dtype = "auto"
        logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
|
||||
99
vllm_br/model_executor/models/deepseek_mtp.py
Normal file
99
vllm_br/model_executor/models/deepseek_mtp.py
Normal file
@@ -0,0 +1,99 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.models.deepseek_mtp import (
|
||||
DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead)
|
||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
||||
|
||||
# from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
# from vllm.model_executor.layers.sampler import get_sampler
|
||||
|
||||
|
||||
@patch_to(DeepSeekMultiTokenPredictorLayer)
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Replacement constructor for DeepSeekMultiTokenPredictorLayer.

    Builds the two RMSNorms, the embedding/hidden fusion projection, the
    shared head and the MTP decoder block; for V3.2 (configs that define
    `index_topk`) it also pre-allocates the top-k indices buffer.
    """
    super(DeepSeekMultiTokenPredictorLayer, self).__init__()

    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config

    self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # Projects the concatenated [embedding ; previous hidden] pair back
    # down to hidden_size.
    self.eh_proj = nn.Linear(config.hidden_size * 2,
                             config.hidden_size,
                             bias=False)

    self.is_v32 = hasattr(config, "index_topk")
    topk_indices_buffer = None
    if self.is_v32:
        # NOTE(review): device is hard-coded to "cuda" in this backend
        # port — confirm it maps to the intended device here.
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            config.index_topk,
            dtype=torch.int32,
            device="cuda")
    self.shared_head = SharedHead(config=config, quant_config=quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
|
||||
|
||||
|
||||
@patch_to(DeepSeekMultiTokenPredictorLayer)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    spec_step_index: int = 0,
) -> torch.Tensor:
    """Replacement MTP-layer forward: fuse the (normalized) token
    embeddings with the previous hidden states and run the MTP block."""
    assert inputs_embeds is not None
    # The embedding at position 0 is not needed by MTP; zero it out.
    zero_mask = (positions == 0).unsqueeze(-1)
    inputs_embeds = torch.where(zero_mask, torch.zeros_like(inputs_embeds),
                                inputs_embeds)
    # unsqueeze(0): the norm ops expect a leading batch dimension.
    normed_embeds = self.enorm(inputs_embeds.unsqueeze(0))
    normed_prev = self.hnorm(previous_hidden_states.unsqueeze(0))

    hidden_states = self.eh_proj(
        torch.cat([normed_embeds, normed_prev], dim=-1))
    hidden_states, residual = self.mtp_block(positions=positions,
                                             hidden_states=hidden_states,
                                             residual=None)
    return (residual + hidden_states).squeeze(0)
|
||||
|
||||
|
||||
@patch_to(DeepSeekMultiTokenPredictor)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    spec_step_idx: int = 0,
) -> torch.Tensor:
    """Replacement compute_logits: pick the MTP layer for this speculative
    step and run its shared head (batched via unsqueeze) through the
    logits processor."""
    step = spec_step_idx % self.num_mtp_layers
    mtp_layer = self.layers[str(self.mtp_start_layer_idx + step)]
    head_out = mtp_layer.shared_head(
        hidden_states.unsqueeze(0)).squeeze(0).contiguous()
    return self.logits_processor(mtp_layer.shared_head.head, head_out)
|
||||
924
vllm_br/model_executor/models/deepseek_v2.py
Normal file
924
vllm_br/model_executor/models/deepseek_v2.py
Normal file
@@ -0,0 +1,924 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
from torch import nn
|
||||
from transformers import DeepseekV2Config, DeepseekV3Config
|
||||
|
||||
import vllm
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
per_token_group_quant_fp8)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.models.deepseek_v2 import (
|
||||
DeepseekV2ForCausalLM, DeepseekV2Model, FusedMoE, Indexer, PPMissingLayer,
|
||||
default_weight_loader, get_spec_layer_idx_from_weight_name,
|
||||
is_pp_missing_parameter, maybe_prefix, maybe_remap_kv_scale_name,
|
||||
yarn_get_mscale)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
|
||||
from vllm_br.v1.attention.backends.mla.indexer import (
|
||||
SupaDeepseekV32IndexerBackend)
|
||||
from .supa_module import (DeepseekV2MoE, MergedGateUpMLPSiluL2, SupaMLAModules,
|
||||
SupaMultiHeadLatentAttention)
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.deepseek_v2.DeepseekV32IndexerCache)
def get_attn_backend(self) -> AttentionBackend:
    """Select the SUPA indexer backend for the V3.2 indexer cache."""
    return SupaDeepseekV32IndexerBackend
|
||||
|
||||
|
||||
class SupaDeepseekV2MLAAttention(nn.Module):
    """Multi-head Latent Attention (MLA) for DeepSeek-V2/V3 on SUPA (Biren).

    Drop-in replacement for vllm's ``DeepseekV2MLAAttention``: builds the
    same low-rank q/kv projection layout, then delegates the attention math
    to :class:`SupaMultiHeadLatentAttention`.  When the checkpoint is
    DeepSeek-V3.2 (detected via ``config.index_topk``) a sparse
    :class:`SupaIndexer` is attached and the q/kv down-projections are
    fused into a single GEMM.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        topk_indices_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        # DeepSeek-V3.2 checkpoints expose `index_topk`; that selects the
        # sparse-indexer ("v32") construction paths below.
        self.is_v32 = hasattr(config, "index_topk")

        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head q/k dim = non-rotary ("nope") part + rotary part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Only the projections matching (is_v32, q_lora_rank) are built;
        # the rest stay None so SupaMLAModules can pick the populated ones.
        self.fused_qkv_a_proj = None
        self.kv_a_proj_with_mqa = None
        self.q_a_proj = None
        self.q_a_layernorm = None
        self.q_b_proj = None
        self.q_proj = None
        if self.is_v32:
            if self.q_lora_rank is not None:
                # V3.2 + q-LoRA: fuse the q and kv down-projections into a
                # single GEMM (disable_tp -> weights replicated per rank).
                self.fused_qkv_a_proj = MergedColumnParallelLinear(
                    self.hidden_size, [
                        self.q_lora_rank,
                        self.kv_lora_rank + self.qk_rope_head_dim
                    ],
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.fused_qkv_a_proj",
                    disable_tp=True)
                # NOTE(review): SUPA-specific flag; presumably skips a
                # cross-device layout pass for this layer -- confirm.
                self.fused_qkv_a_proj.no_need_cross = True

            else:
                self.kv_a_proj_with_mqa = ReplicatedLinear(
                    self.hidden_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.kv_a_proj_with_mqa")

        else:
            if self.q_lora_rank is not None:
                self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                                 self.q_lora_rank,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_a_proj")

            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa")

        if self.q_lora_rank is not None:
            # Low-rank q path: down-project -> RMSNorm -> up-project to heads.
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            # Full-rank q projection straight from the hidden states.
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        # Up-projects the kv latent to per-head nope-k and v.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        if rope_scaling:
            # V3.2 keeps the stock deepseek yarn scaling; earlier models
            # use the SUPA variant registered under 'deepseek_yarn_supa'.
            if self.is_v32:
                rope_scaling["rope_type"] = 'deepseek_yarn'
            else:
                rope_scaling["rope_type"] = 'deepseek_yarn_supa'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            # Yarn rescales attention scores: fold mscale^2 into scaling.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        if self.is_v32:
            # Sparse top-k token indexer ("lightning indexer") for V3.2.
            self.indexer: Optional[SupaIndexer] = SupaIndexer(
                vllm_config, config, hidden_size, q_lora_rank, quant_config,
                cache_config, topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer: Optional[SupaIndexer] = None

        # Bundle every submodule the fused attention implementation needs.
        mla_modules = SupaMLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm,
            q_b_proj=self.q_b_proj,
            q_proj=self.q_proj,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
            q_a_proj=self.q_a_proj,
        )

        self.mla_attn = SupaMultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA over ``hidden_states``; delegates to the SUPA kernel."""
        return self.mla_attn(positions, hidden_states, is_ds_v32=self.is_v32)
|
||||
|
||||
|
||||
def indexer_k_cache(
    k: torch.Tensor,  # [num_tokens, head_dim]
    kv_cache: torch.Tensor,  # [1, num_blocks, block_size, cache_stride]
    slot_mapping: torch.Tensor,  # [num_tokens] flat slot id per token
) -> None:
    """Scatter per-token indexer keys into the paged key cache, in place.

    Each flat slot id encodes (block, offset-within-block). Only the
    leading ``head_dim`` entries of a cache row are written; the cache
    stride may be longer (see TODO in the original layout notes).
    """
    head_dim = k.shape[1]
    # [TODO] kv_cache shape is not aligned with nv
    block_size = kv_cache.shape[-2]
    paged_cache = kv_cache[0]

    for token_idx, slot in enumerate(slot_mapping):
        # Decompose the flat slot into (block, in-block offset).
        block_id = slot // block_size
        offset = slot % block_size
        # [TODO] kv cache stride is longer than head_dim
        paged_cache[block_id][offset][:head_dim] = k[token_idx]
|
||||
|
||||
|
||||
def bf16_mqa_logits(
    q: torch.Tensor,  # [seq_len_q, num_heads, head_dim]
    kv: torch.Tensor,  # [seq_len_kv, head_dim] (single shared key, MQA)
    weights: torch.Tensor,  # [seq_len_q, num_heads]
    cu_seqlen_ks: torch.Tensor,  # [seq_len_q] inclusive kv-window start
    cu_seqlen_ke: torch.Tensor,  # [seq_len_q] exclusive kv-window end
):
    """Reference MQA indexer logits (bf16 inputs, fp32 math).

    For every (query, kv) pair computes ReLU(q . k), weights it per head,
    sums over heads, and masks kv positions outside [ks, ke) to -inf.

    Returns:
        [seq_len_q, seq_len_kv] float32 logits.
    """
    seq_len_kv = kv.shape[0]

    # Compute in fp32 for a numerically faithful reference.
    q = q.float()
    k = kv.float()

    # Valid kv window per query: ks <= pos < ke.
    # Fix: build the position index on the input's device instead of the
    # previously hard-coded "cuda", which broke CPU / non-CUDA (SUPA) runs.
    # Also build the arange once instead of twice.
    kv_pos = torch.arange(0, seq_len_kv, device=kv.device)[None, :]
    mask_lo = kv_pos >= cu_seqlen_ks[:, None]
    mask_hi = kv_pos < cu_seqlen_ke[:, None]
    mask = mask_lo & mask_hi

    # score[h, m, n] = q[m, h, :] . k[n, :]
    score = torch.einsum("mhd,nd->hmn", q, k)
    # Weighted ReLU-sum over heads -> [m, n].
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))

    return logits
|
||||
|
||||
|
||||
def _ref_fp8_paged_mqa_logits(
|
||||
q: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
max_model_len: int,
|
||||
):
|
||||
batch_size, next_n, _, _ = q.size()
|
||||
_, num_block, block_size, unkonw_size, head_dim = kv_cache.size(
|
||||
) # [1, num_block, block_size, _]
|
||||
num_block = num_block * 16
|
||||
block_size = block_size // 16
|
||||
kv_cache = kv_cache.view(num_block, block_size, unkonw_size, head_dim)
|
||||
logits = torch.full(
|
||||
[batch_size * next_n, max_model_len],
|
||||
float("-inf"),
|
||||
device=q.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
context_lens_list = context_lens.tolist()
|
||||
for i in range(batch_size):
|
||||
context_len = context_lens_list[i]
|
||||
q_offsets = torch.arange(context_len - next_n,
|
||||
context_len,
|
||||
device="cuda")
|
||||
weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
|
||||
0, 1).contiguous())
|
||||
for block_rk in range(cdiv(context_len, block_size)):
|
||||
block_idx = block_tables[i][block_rk]
|
||||
qx, kx = q[i], kv_cache[block_idx]
|
||||
k_offsets = torch.arange(
|
||||
block_rk * block_size,
|
||||
(block_rk + 1) * block_size,
|
||||
device="cuda",
|
||||
)
|
||||
mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
|
||||
<= q_offsets[:, None])
|
||||
s = torch.where(
|
||||
mask[None, :, :],
|
||||
(qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
|
||||
logits.dtype),
|
||||
float("-inf"),
|
||||
)
|
||||
s = torch.relu(s) * weight_slice[..., None]
|
||||
s = s.sum(dim=0)
|
||||
logits[
|
||||
i * next_n:(i + 1) * next_n,
|
||||
block_rk * block_size:(block_rk + 1) * block_size,
|
||||
] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
|
||||
float("-inf"))
|
||||
return logits
|
||||
|
||||
|
||||
def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [1, num_blocks, block_size, head_dim + 1]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    dst_scale,  # [cu_seq_lens[-1], 4] (unused: fp8 scale path disabled)
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
    """Gather the paged indexer key cache into one contiguous bf16 tensor.

    For each request ``b`` the first ``cu_seq_lens[b+1] - cu_seq_lens[b]``
    cached keys are copied out of the paged cache into ``dst_value`` in
    request order.  The gather runs through CPU because tensor indexing is
    not yet supported on the BR device (see TODO below).  All commented-out
    lines are the fp8 per-token scale path, currently disabled on SUPA.
    """
    _, num_blocks, block_size, _ = kv_cache.shape
    # align to nv
    # SUPA packs 16 logical blocks per physical block; unpack counts.
    num_blocks = num_blocks * 16
    block_size = block_size // 16
    head_dim = dst_value.shape[-1]
    kv_cache = kv_cache.view(num_blocks, -1)

    expected_value = []
    # expected_scale = []
    for b in range(batch_size):
        # Number of cached tokens belonging to request b.
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]

        value = []
        # All fully-filled blocks; the (possibly partial) last block is
        # handled separately below.
        full_block = torch.arange(tot - 1,
                                  device=kv_cache.device,
                                  dtype=torch.int32)
        # [TODO] not support index in tensor on br, run in cpu now
        non_remaining_value = kv_cache.cpu()[
            blocks.cpu()[full_block.cpu()], :block_size * head_dim].view(
                -1, head_dim)
        # non_remaining_scale = kv_cache[blocks[full_block],
        #                                block_size * head_dim:].view(-1, 4)

        # Tokens sitting in the last, partially-filled block.
        remaining = s - (tot - 1) * block_size

        # NOTE(review): unlike the full-block gather above, `blocks[-1]`
        # is NOT moved to CPU before indexing the CPU tensor -- confirm
        # this is intended / works on this torch build.
        value = torch.cat([
            non_remaining_value,
            kv_cache.cpu()[blocks[-1], :remaining * head_dim].view(
                -1, head_dim)
        ],
                          dim=0)
        # scale = torch.cat([
        #     non_remaining_scale,
        #     kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
        #              remaining * 4].view(-1, 4)
        # ],
        #                   dim=0)

        expected_value.append(value)
        # expected_scale.append(scale)

    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    # The cache rows hold raw bf16 bits; reinterpret before the copy back
    # to the destination device.
    gather_value = gather_value.view(torch.bfloat16).to(dst_value.device)
    # gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    # dst_scale.copy_(gather_scale)
|
||||
|
||||
|
||||
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    """Fake (profile-run) path of the sparse indexer.

    Allocates the largest flattened-KV scratch the real path could ever
    need, so that profile runs observe a realistic peak memory footprint,
    then returns the pre-allocated top-k index buffer untouched.
    """
    support_fp8 = False
    # Same scratch shape either way; only the element type differs
    # (uint8 bytes for fp8+scale vs. bf16 values).
    scratch_dtype = torch.uint8 if support_fp8 else torch.bfloat16
    _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                device=k.device,
                                dtype=scratch_dtype)
    if support_fp8:
        # Leading head_dim bytes: fp8 keys; trailing 4 bytes: fp32 scale.
        _k_fp8 = _flattened_kv[..., :head_dim].view(
            torch.float8_e4m3fn).contiguous()
        _k_scale = _flattened_kv[...,
                                 head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer
|
||||
|
||||
|
||||
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Select the top-k kv positions each query token may attend to.

    Writes the new keys ``k`` into the paged indexer cache, computes
    indexer logits for the prefill chunks (bf16 reference path) and the
    decode batch (paged reference path), and stores the selected kv
    indices into ``topk_indices_buffer`` (-1 marks unused slots).

    Returns ``topk_indices_buffer``.
    """

    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # Dummy/profile run: no metadata available -- fall back to the
        # fake path that only mimics peak memory usage.
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )

    assert topk_indices_buffer is not None

    # Per-layer metadata is keyed by the indexer cache prefix.
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Persist this step's keys into the paged indexer cache first.
    indexer_k_cache(
        k,
        kv_cache,
        slot_mapping,
    )

    # Reset the buffer; -1 means "no kv selected" for a slot.
    # NOTE(review): indexes with hidden_states.shape[1] -- hidden states
    # appear to be 3D ([1, num_tokens, ...]) on SUPA; confirm.
    topk_indices_buffer[:hidden_states.shape[1]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            # Gather the chunk's cached keys into a dense bf16 tensor.
            k_bf16 = torch.empty([chunk.total_seq_lens, head_dim],
                                 device=k.device,
                                 dtype=torch.bfloat16)
            k_scale = None
            cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_bf16,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
                chunk.num_reqs,
            )

            # NOTE: despite the name, q_fp8 holds bf16 queries on the
            # SUPA path (fp8 quant is disabled in SupaIndexer.forward).
            logits = bf16_mqa_logits(
                q_fp8[chunk.token_start:chunk.token_end],
                k_bf16,
                weights[chunk.token_start:chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )

            # [TODO] topk is not aligned with cpu if elements are -inf
            # Device topk unsupported -> round-trip through CPU.
            topk_indices = logits.cpu().topk(min(topk_tokens,
                                                 logits.shape[-1]),
                                             dim=-1)[1].supa()
            # Rebase indices relative to each query's kv-window start.
            topk_indices -= chunk.cu_seqlen_ks[:, None]
            mask_lo = topk_indices >= 0
            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
                                      chunk.cu_seqlen_ks)[:, None] < 0
            # NOTE: this full_like is immediately overwritten below
            # (dead store kept as-is).
            mask = torch.full_like(topk_indices,
                                   False,
                                   dtype=torch.bool,
                                   device=topk_indices.device)
            mask = mask_lo & mask_hi
            # Out-of-window picks (rows padded with -inf) become -1.
            topk_indices = topk_indices.masked_fill(~mask, -1)
            topk_indices_buffer[
                chunk.token_start:chunk.token_end, :topk_indices.
                shape[-1]] = topk_indices.to(dtype=torch.int32)

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens)
        else:
            # Uniform decode lens: a plain reshape to [B, N, ...] suffices.
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = _ref_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len=max_model_len,
        )
        # padded query len
        current_device = padded_q_fp8_decode_tokens.device
        padded_num_tokens = batch_size * next_n
        positions = torch.arange(max_model_len,
                                 device=current_device).unsqueeze(0).expand(
                                     batch_size * next_n, -1)
        row_indices = torch.arange(padded_num_tokens,
                                   device=current_device) // next_n
        next_n_offset = torch.arange(
            padded_num_tokens,
            device=padded_q_fp8_decode_tokens.device) % next_n
        # Last kv position each padded query token may attend to.
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
                         next_n_offset).unsqueeze(1)
        # index_end_pos: [B * N, 1]
        mask = positions <= index_end_pos
        # mask: [B * N, L]
        logits = logits.masked_fill(~mask, float('-inf'))
        # [TODO] topk is not supported
        # Device topk unsupported -> round-trip through CPU again.
        device = logits.device
        logits = logits.to('cpu')
        topk_indices = logits.topk(topk_tokens,
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
        topk_indices = topk_indices.to(device)
        # ensure we don't set indices for the top k
        # that is out of range(masked already)
        # this will happen if context length is shorter than K
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens)
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)

    return topk_indices_buffer
|
||||
|
||||
|
||||
class SupaIndexer(Indexer):
    """DeepSeek-V3.2 lightning indexer for SUPA (Biren) devices.

    Extends vllm's ``Indexer`` with an unquantized weights projection and
    a bf16 (non-fp8) key cache, and routes scoring through the
    CPU-assisted reference path in :func:`sparse_attn_indexer`.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: Optional[int],
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor] = None,
                 prefix: str = "") -> None:
        super().__init__(
            vllm_config=vllm_config,
            config=config,
            hidden_size=hidden_size,
            q_lora_rank=q_lora_rank,
            quant_config=quant_config,
            cache_config=cache_config,
            topk_indices_buffer=topk_indices_buffer,
            prefix=prefix,
        )
        self.n_head = config.index_n_heads  # 64
        # Re-create weights_proj unquantized (quant_config=None),
        # overriding whatever the base class built.
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             bias=False,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        # SUPA path stores indexer keys in bf16 rather than fp8.
        self.k_cache.dtype = torch.bfloat16
        self.k_cache.head_dim = config.index_head_dim
        self.topk_indices_buffer.fill_(0)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
                rotary_emb) -> torch.Tensor:
        """Compute per-token top-k kv indices.

        Args:
            hidden_states: layer input tokens.
            qr: low-rank query latent (input to ``wq_b``).
            positions: rope position ids.
            rotary_emb: rotary embedding module shared with the attention.
        """
        # Up-project the query latent to per-head indexer queries.
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        # Single shared indexer key per token (MQA style).
        k, _ = self.wk(hidden_states)
        k = k.view(-1, self.head_dim)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)

        # Rotate only the rope part, then re-concatenate.
        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        # fp8 quantization is disabled on SUPA for now.
        support_fp8 = False
        if support_fp8:
            q_fp8, q_scale = per_token_group_quant_fp8(
                q,
                self.quant_block_size,
                column_major_scales=False,
                use_ue8m0=self.scale_fmt is not None)
            q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
            q_scale = q_scale.view(-1, self.n_head, 1)

            # Fold the per-token q scale into the head weights.
            weights, _ = self.weights_proj(hidden_states)
            weights = weights.unsqueeze(
                -1) * q_scale * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)

            return torch.ops.vllm.sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            q = q.view(-1, self.n_head, self.head_dim)
            weights, _ = self.weights_proj(hidden_states)
            weights = weights.view(-1, self.n_head)
            weights = weights.unsqueeze(
                -1) * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)

            # Direct python call (not the registered custom op): the bf16
            # reference path partially runs on CPU.
            return sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
|
||||
|
||||
|
||||
@patch_to(DeepseekV2Model)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Monkey-patched DeepseekV2Model.forward for SUPA.

    Same control flow as upstream, except hidden states (and the PP
    residual) are carried through the decoder layers with an extra leading
    dim because the SUPA kernels expect 3D activations; the dim is
    squeezed away again before returning / handing off to the next rank.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        # Later pipeline rank: activations come from the previous rank.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        residual = residual.unsqueeze(0)  # NOTE: SUPA wants 3D input

    hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        # Strip the SUPA batch dim before shipping to the next PP rank.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)
|
||||
|
||||
|
||||
@patch_to(DeepseekV2ForCausalLM)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Monkey-patched DeepseekV2ForCausalLM.__init__ for SUPA.

    Follows the upstream constructor but additionally tags the model
    config with ``use_ds_mla`` (and ``use_ds_mla_sparse`` for V3.2) so
    the SUPA backends select the MLA / sparse-indexer kernels.
    """
    super(DeepseekV2ForCausalLM, self).__init__()
    config = vllm_config.model_config.hf_config
    model_config = vllm_config.model_config
    # SUPA-side switches consumed during backend selection.
    model_config.use_ds_mla = True
    is_v32 = hasattr(config, "index_topk")
    if is_v32:
        model_config.use_ds_mla_sparse = True
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config

    # `packed_modules_mapping` needs to be modified before
    # initializing DeepseekV2Model, as it is passed inplace to
    # quantization config init and may be used to select the
    # quant_method for relevant layers during initialization.
    self.fuse_qkv_a_proj = hasattr(
        config, "q_lora_rank") and config.q_lora_rank is not None
    if self.fuse_qkv_a_proj:
        self.packed_modules_mapping["fused_qkv_a_proj"] = [
            "q_a_proj",
            "kv_a_proj_with_mqa",
        ]

    self.model = DeepseekV2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
    else:
        # Non-last PP ranks carry no LM head.
        self.lm_head = PPMissingLayer()
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
|
||||
|
||||
|
||||
@patch_to(DeepseekV2ForCausalLM)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Monkey-patched checkpoint loader for SUPA.

    Mirrors upstream DeepseekV2 weight loading (stacked projections,
    MoE expert mapping, kv-scale remap) with SUPA additions: norm and
    e_score_correction_bias params are upcast to fp32 after loading, and
    the SUPA allocator cache is flushed per such param.

    Returns the set of parameter names that were actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts)

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name_mapped = name.replace(weight_name, param_name)

            # QKV fusion is optional, fall back to normal
            # weight loading if it's not enabled
            # if go with fusion option, then update name
            if ((param_name == "fused_qkv_a_proj")
                    and name_mapped not in params_dict):
                continue
            else:
                name = name_mapped

            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            if name not in params_dict:
                # logger.debug(f'skip {name}')
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            # SUPA: norm / bias-correction params must live in fp32;
            # flush the device cache after the upcast copy.
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
            break
        else:
            # Not a stacked param: try the MoE expert mapping next.
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                    torch.supa.empty_cache()
                break
            else:
                # Plain (non-stacked, non-expert) parameter.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                    torch.supa.empty_cache()
        loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Import-time side effects: replace vllm's stock DeepSeek-V2 building blocks
# with the SUPA (Biren) variants defined above / in .supa_module.  Importing
# this module is sufficient to activate the patches.
# ---------------------------------------------------------------------------
vllm.model_executor.models.deepseek_v2.DeepseekV2MLP = MergedGateUpMLPSiluL2
logger.debug('[Patch] patch DeepSeekV2 MLP with MergedGateUpMLPSiluL2')
vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = DeepseekV2MoE
logger.debug('[Patch] patch DeepSeekV2 MoE with DeepseekV2MoE')
vllm.model_executor.models.deepseek_v2.DeepseekV2MLAAttention = SupaDeepseekV2MLAAttention
logger.debug('[Patch] patch DeepSeekV2 MLA with SupaDeepseekV2MLAAttention')
vllm.model_executor.models.deepseek_v2.Indexer = SupaIndexer
logger.debug('[Patch] patch DeepSeekV2 Indexer with SupaIndexer')
vllm.model_executor.models.deepseek_v2.MultiHeadLatentAttention = SupaMultiHeadLatentAttention
logger.debug(
    '[Patch] patch DeepSeekV2 MultiHeadLatentAttention with SupaMultiHeadLatentAttention'
)

# vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.packed_modules_mapping = {
#     "gate_up_proj": ["gate_proj", "up_proj"],
#     # "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
# }
# logger.debug(
#     '[Patch] patch DeepseekV2ForCausalLM with SupportsQuant packed_modules_mapping'
# )
|
||||
299
vllm_br/model_executor/models/glm4.py
Normal file
299
vllm_br/model_executor/models/glm4.py
Normal file
@@ -0,0 +1,299 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Copyright 2025 The Zhipu AI team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
|
||||
#from typing import Any, Optional
|
||||
#
|
||||
#import torch
|
||||
#from fastcore.basics import patch_to
|
||||
#from transformers import Glm4Config
|
||||
#
|
||||
#import vllm
|
||||
#from vllm.attention import Attention, AttentionType
|
||||
#from vllm.config import CacheConfig
|
||||
#from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
#from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
# RowParallelLinear)
|
||||
#from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
#from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
|
||||
# RotaryEmbedding)
|
||||
#from vllm.model_executor.models.glm4 import Glm4Attention
|
||||
#
|
||||
#_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
|
||||
|
||||
#def get_rope_0_9_2(
|
||||
# head_size: int,
|
||||
# rotary_dim: int,
|
||||
# max_position: int,
|
||||
# base: float,
|
||||
# is_neox_style: bool = True,
|
||||
# rope_scaling: Optional[dict[str, Any]] = None,
|
||||
# dtype: Optional[torch.dtype] = None,
|
||||
# partial_rotary_factor: float = 1.0,
|
||||
# dual_chunk_attention_config: Optional[dict[str, Any]] = None,
|
||||
#) -> RotaryEmbedding:
|
||||
#
|
||||
# if dtype is None:
|
||||
# dtype = torch.get_default_dtype()
|
||||
# if rope_scaling is not None:
|
||||
# # Transforms every value that is a list into a tuple for caching calls
|
||||
# rope_scaling_tuple = {
|
||||
# k: tuple(v) if isinstance(v, list) else v
|
||||
# for k, v in rope_scaling.items()
|
||||
# }
|
||||
# rope_scaling_args = tuple(rope_scaling_tuple.items())
|
||||
# else:
|
||||
# rope_scaling_args = None
|
||||
#
|
||||
# if dual_chunk_attention_config is not None:
|
||||
# dual_chunk_attention_tuple = {
|
||||
# k: tuple(v) if isinstance(v, list) else v
|
||||
# for k, v in dual_chunk_attention_config.items()
|
||||
# if k != "sparse_attention_config"
|
||||
# }
|
||||
# dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
|
||||
# else:
|
||||
# dual_chunk_attention_args = None
|
||||
#
|
||||
# if partial_rotary_factor < 1.0:
|
||||
# rotary_dim = int(rotary_dim * partial_rotary_factor)
|
||||
# key = (head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
# rope_scaling_args, dual_chunk_attention_args, dtype)
|
||||
# if key in _ROPE_DICT:
|
||||
# return _ROPE_DICT[key]
|
||||
#
|
||||
# if dual_chunk_attention_config is not None:
|
||||
# extra_kwargs = {
|
||||
# k: v
|
||||
# for k, v in dual_chunk_attention_config.items()
|
||||
# if k in ("chunk_size", "local_size")
|
||||
# }
|
||||
# rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
|
||||
# max_position, base,
|
||||
# is_neox_style, dtype,
|
||||
# **extra_kwargs)
|
||||
# elif not rope_scaling:
|
||||
# rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
# is_neox_style, dtype)
|
||||
# else:
|
||||
# scaling_type = rope_scaling["rope_type"]
|
||||
#
|
||||
# if scaling_type == "llama3":
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# low_freq_factor = rope_scaling["low_freq_factor"]
|
||||
# high_freq_factor = rope_scaling["high_freq_factor"]
|
||||
# original_max_position = rope_scaling[
|
||||
# "original_max_position_embeddings"]
|
||||
# rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
|
||||
# max_position, base,
|
||||
# is_neox_style, dtype,
|
||||
# scaling_factor, low_freq_factor,
|
||||
# high_freq_factor,
|
||||
# original_max_position)
|
||||
# elif scaling_type == "mllama4":
|
||||
# rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
|
||||
# max_position, base,
|
||||
# is_neox_style, dtype)
|
||||
# elif scaling_type == "default":
|
||||
# if "mrope_section" in rope_scaling:
|
||||
# rotary_emb = MRotaryEmbedding(
|
||||
# head_size,
|
||||
# rotary_dim,
|
||||
# max_position,
|
||||
# base,
|
||||
# is_neox_style,
|
||||
# dtype,
|
||||
# mrope_section=rope_scaling["mrope_section"],
|
||||
# )
|
||||
# else:
|
||||
# rotary_emb = RotaryEmbedding(
|
||||
# head_size,
|
||||
# rotary_dim,
|
||||
# max_position,
|
||||
# base,
|
||||
# is_neox_style,
|
||||
# dtype,
|
||||
# )
|
||||
# elif scaling_type == "linear":
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
# max_position, base,
|
||||
# is_neox_style,
|
||||
# scaling_factor, dtype)
|
||||
# elif scaling_type == "ntk":
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# mixed_b = rope_scaling.get('mixed_b', None)
|
||||
# rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
# max_position, base,
|
||||
# is_neox_style,
|
||||
# scaling_factor, dtype,
|
||||
# mixed_b)
|
||||
# elif scaling_type == "dynamic":
|
||||
# if "alpha" in rope_scaling:
|
||||
# scaling_alpha = rope_scaling["alpha"]
|
||||
# rotary_emb = DynamicNTKAlphaRotaryEmbedding(
|
||||
# head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
# scaling_alpha, dtype)
|
||||
# elif "factor" in rope_scaling:
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
# head_size, rotary_dim, max_position, base, is_neox_style,
|
||||
# scaling_factor, dtype)
|
||||
# else:
|
||||
# raise ValueError("Dynamic rope scaling must contain either "
|
||||
# "'alpha' or 'factor' field")
|
||||
# elif scaling_type == "yarn":
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# original_max_position = rope_scaling[
|
||||
# "original_max_position_embeddings"]
|
||||
# extra_kwargs = {
|
||||
# k: v
|
||||
# for k, v in rope_scaling.items()
|
||||
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
|
||||
# "beta_slow")
|
||||
# }
|
||||
# rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
|
||||
# original_max_position,
|
||||
# base, is_neox_style,
|
||||
# scaling_factor, dtype,
|
||||
# **extra_kwargs)
|
||||
# elif scaling_type == "deepseek_yarn":
|
||||
# scaling_factor = rope_scaling["factor"]
|
||||
# original_max_position = rope_scaling[
|
||||
# "original_max_position_embeddings"]
|
||||
# # assert max_position == original_max_position * scaling_factor
|
||||
# extra_kwargs = {
|
||||
# k: v
|
||||
# for k, v in rope_scaling.items()
|
||||
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
|
||||
# "beta_slow", "mscale", "mscale_all_dim")
|
||||
# }
|
||||
# rotary_emb = DeepseekScalingRotaryEmbedding(
|
||||
# head_size, rotary_dim, original_max_position, base,
|
||||
# is_neox_style, scaling_factor, dtype, **extra_kwargs)
|
||||
# elif scaling_type == "longrope":
|
||||
# short_factor = rope_scaling["short_factor"]
|
||||
# long_factor = rope_scaling["long_factor"]
|
||||
# original_max_position = rope_scaling[
|
||||
# "original_max_position_embeddings"]
|
||||
# extra_kwargs = {
|
||||
# k: v
|
||||
# for k, v in rope_scaling.items()
|
||||
# if k in ("short_mscale", "long_mscale")
|
||||
# }
|
||||
# rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
|
||||
# head_size, rotary_dim, max_position, original_max_position,
|
||||
# base, is_neox_style, dtype, short_factor, long_factor,
|
||||
# **extra_kwargs)
|
||||
# else:
|
||||
# raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
||||
# _ROPE_DICT[key] = rotary_emb
|
||||
# return rotary_emb
|
||||
#
|
||||
#
|
||||
#@patch_to(vllm.model_executor.models.glm4.Glm4Attention)
|
||||
#def __init__(self,
|
||||
# config: Glm4Config,
|
||||
# hidden_size: int,
|
||||
# num_heads: int,
|
||||
# num_kv_heads: int,
|
||||
# max_position: int = 4096 * 32,
|
||||
# head_dim: Optional[int] = None,
|
||||
# qkv_bias: bool = False,
|
||||
# rope_theta: float = 10000,
|
||||
# cache_config: Optional[CacheConfig] = None,
|
||||
# quant_config: Optional[QuantizationConfig] = None,
|
||||
# rope_scaling: Optional[tuple] = None,
|
||||
# prefix: str = "",
|
||||
# attn_type: str = AttentionType.DECODER) -> None:
|
||||
# super(Glm4Attention, self).__init__()
|
||||
# self.hidden_size = hidden_size
|
||||
# tp_size = get_tensor_model_parallel_world_size()
|
||||
# self.total_num_heads = num_heads
|
||||
# assert self.total_num_heads % tp_size == 0
|
||||
# self.num_heads = self.total_num_heads // tp_size
|
||||
# self.total_num_kv_heads = num_kv_heads
|
||||
# if self.total_num_kv_heads >= tp_size:
|
||||
# # Number of KV heads is greater than TP size, so we partition
|
||||
# # the KV heads across multiple tensor parallel GPUs.
|
||||
# assert self.total_num_kv_heads % tp_size == 0
|
||||
# else:
|
||||
# # Number of KV heads is less than TP size, so we replicate
|
||||
# # the KV heads across multiple tensor parallel GPUs.
|
||||
# assert tp_size % self.total_num_kv_heads == 0
|
||||
# partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
|
||||
# self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
# self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
# self.rotary_dim = self.head_dim
|
||||
# self.q_size = self.num_heads * self.head_dim
|
||||
# self.kv_size = self.num_kv_heads * self.head_dim
|
||||
# self.scaling = self.head_dim**-0.5
|
||||
# self.rope_theta = rope_theta
|
||||
# self.qkv_proj = QKVParallelLinear(
|
||||
# hidden_size,
|
||||
# self.head_dim,
|
||||
# self.total_num_heads,
|
||||
# self.total_num_kv_heads,
|
||||
# bias=qkv_bias,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.qkv_proj",
|
||||
# )
|
||||
# self.o_proj = RowParallelLinear(
|
||||
# self.total_num_heads * self.head_dim,
|
||||
# hidden_size,
|
||||
# bias=False,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.o_proj",
|
||||
# )
|
||||
# self.rotary_emb = get_rope_0_9_2(
|
||||
# self.head_dim,
|
||||
# rotary_dim=self.rotary_dim,
|
||||
# max_position=max_position,
|
||||
# base=self.rope_theta,
|
||||
# rope_scaling=rope_scaling,
|
||||
# partial_rotary_factor=partial_rotary_factor,
|
||||
# is_neox_style=False,
|
||||
# )
|
||||
# self.attn = Attention(self.num_heads,
|
||||
# self.head_dim,
|
||||
# self.scaling,
|
||||
# num_kv_heads=self.num_kv_heads,
|
||||
# cache_config=cache_config,
|
||||
# quant_config=quant_config,
|
||||
# prefix=f"{prefix}.attn",
|
||||
# attn_type=attn_type)
|
||||
795
vllm_br/model_executor/models/glm4_1v.py
Normal file
795
vllm_br/model_executor/models/glm4_1v.py
Normal file
@@ -0,0 +1,795 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 The ZhipuAI Team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
from einops import rearrange, repeat
|
||||
from torch_br.contrib import SueagerScaledDotProductAttention
|
||||
|
||||
import vllm
|
||||
import vllm.model_executor.models.glm4
|
||||
import vllm.model_executor.models.llama
|
||||
import vllm.model_executor.models.qwen2_vl
|
||||
import vllm_br.envs as br_envs
|
||||
from vllm.attention.layer import check_upstream_fa_availability
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
parallel_state)
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.glm4_1v import (Glm4vForConditionalGeneration,
|
||||
Glm4vVisionBlock,
|
||||
Glm4vVisionMLP,
|
||||
Glm4vVisionTransformer)
|
||||
from vllm.model_executor.models.utils import (init_vllm_registered_model,
|
||||
is_pp_missing_parameter,
|
||||
maybe_prefix)
|
||||
from vllm.model_executor.models.vision import get_vit_attn_backend
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from ..layers.activation import SiluAndMul
|
||||
from ..layers.br_utils import is_br166_device
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def Glm4vVisionMLP_init_fit(self,
                            in_features: int,
                            hidden_features: int,
                            bias: bool = False,
                            quant_config: Optional[QuantizationConfig] = None,
                            prefix: str = "",
                            use_data_parallel: bool = False):
    """Replacement ``__init__`` for ``Glm4vVisionMLP``.

    Builds a fused gate/up column-parallel projection, a row-parallel down
    projection, and the SiLU-and-mul activation module.

    NOTE(review): ``use_data_parallel`` is accepted for signature
    compatibility but never read here — confirm that is intended.
    """
    super(Glm4vVisionMLP, self).__init__()
    # gate_proj and up_proj are fused into one column-parallel matmul.
    self.gate_up_proj = MergedColumnParallelLinear(
        input_size=in_features,
        output_sizes=[hidden_features, hidden_features],
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj")
    self.down_proj = RowParallelLinear(input_size=hidden_features,
                                       output_size=in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj")
    self.act_fn = SiluAndMul()
|
||||
|
||||
|
||||
def Glm4vVisionMLP_forward_fit(self, x: torch.Tensor):
|
||||
x, _ = self.gate_up_proj(x)
|
||||
|
||||
#x = self.act_fn(x)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather ``local_tensor`` across the TP group and re-interleave.

    Each rank contributes slices of width ``hidden_size // tp_size`` along
    the last dim; after the gather, the per-rank sub-slices are re-ordered
    round-robin so the concatenated result restores the interleaved layout.
    """
    import torch.distributed as dist

    shards = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        shards,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )

    # Split every gathered shard into per-rank sub-slices, then interleave
    # them before the final concatenation.
    chunk = hidden_size // tp_size
    per_rank_splits = [torch.split(shard, chunk, -1) for shard in shards]
    interleaved = []
    for group in zip(*per_rank_splits):
        interleaved.extend(group)
    return torch.cat(interleaved, dim=-1)
|
||||
|
||||
|
||||
class Glm4vVisionAttention_fit(nn.Module):
    """BR-fitted replacement for the GLM-4V vision attention module.

    Differences from upstream: ``qkv``/``proj`` are plain ``nn.Linear``
    layers (the tensor-parallel originals are kept commented out for
    reference), and the TORCH_SDPA path is routed through Biren's
    ``torch_br.sueager_scaled_dot_product_attention_fwd`` kernel.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_rank())
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)

        # Tensor-parallel originals kept for reference; replaced below by
        # plain nn.Linear layers for the BR fit.
        #self.qkv = QKVParallelLinear(
        #    hidden_size=embed_dim,
        #    head_size=self.hidden_size_per_attention_head,
        #    total_num_heads=num_heads,
        #    total_num_kv_heads=num_heads,
        #    bias=False,
        #    quant_config=quant_config,
        #    prefix=f"{prefix}.qkv",
        #)
        #self.proj = RowParallelLinear(
        #    input_size=projection_size,
        #    output_size=embed_dim,
        #    quant_config=quant_config,
        #    prefix=f"{prefix}.proj",
        #    bias=False,
        #)
        # q + k + v each use num_heads heads of head_dim width.
        qkv_output_size = (num_heads +
                           2 * num_heads) * self.hidden_size_per_attention_head
        self.qkv = nn.Linear(embed_dim, qkv_output_size, bias=False)
        self.proj = nn.Linear(projection_size, embed_dim, bias=False)
        self.sueager_attention = SueagerScaledDotProductAttention()

        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        # self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        self.use_upstream_fa = False
        # Prefer upstream flash-attn when the selected backend is not
        # FLASH_ATTN but a compatible flash-attn install is available.
        if self.attn_backend != _Backend.FLASH_ATTN and \
            check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True

        if self.attn_backend not in {
                _Backend.FLASH_ATTN,
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
        }:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now.")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a packed qkv tensor into per-partition q, k, v.

        Input is [s, b, 3 * head * head_dim]; returns three tensors of
        shape [s, b, heads_per_partition, head_dim].
        """
        # [s, b, 3 * head * head_dim]
        seq_len, bs, _ = qkv.shape
        if self.tp_size > 1:
            # NOTE(review): ``nn.Linear`` has no ``hidden_size`` attribute
            # (the commented-out QKVParallelLinear did) — this tp_size > 1
            # path would raise AttributeError; presumably unreachable in
            # the BR configuration. Confirm.
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
                                        self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=2)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(
                dist_utils.split_tensor_along_last_dim,
                num_partitions=self.tp_size,
            )
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (
            seq_len,
            bs,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
        )
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
    ) -> torch.Tensor:
        """Run vision self-attention over variable-length packed sequences.

        ``cu_seqlens`` holds cumulative sequence boundaries; rotary
        embeddings are applied to q/k before the backend-specific kernel.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        # x, _ = self.qkv(x)
        x = self.qkv(x)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        batch_size = q.shape[1]

        # Move batch in front: [s, b, ...] -> [b, s, ...].
        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))

        if rotary_pos_emb is not None:
            q = glm_apply_rotary_pos_emb_vision(q, rotary_pos_emb)
            k = glm_apply_rotary_pos_emb_vision(k, rotary_pos_emb)

        if self.attn_backend == _Backend.FLASH_ATTN:
            from flash_attn import flash_attn_varlen_func

            # Flatten batch into the sequence dim for the varlen kernel.
            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

            output = flash_attn_varlen_func(
                q,
                k,
                v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                dropout_p=0,
                causal=False,
            )

            context_layer = rearrange(output,
                                      "(b s) ... -> b s ...",
                                      b=batch_size)
        elif self.attn_backend == _Backend.TORCH_SDPA:
            # Execute attention entry by entry for speed & less VRAM.
            outputs = []

            for i in range(1, len(cu_seqlens)):
                start_idx = cu_seqlens[i - 1]
                end_idx = cu_seqlens[i]
                q_i = q[:, start_idx:end_idx]
                k_i = k[:, start_idx:end_idx]
                v_i = v[:, start_idx:end_idx]
                q_i, k_i, v_i = (rearrange(x, "b s h d -> s h b d")
                                 for x in [q_i, k_i, v_i])
                # Biren fused SDPA; returns (output, ...) so take [0].
                output_i = torch_br.sueager_scaled_dot_product_attention_fwd(
                    q_i.squeeze(),
                    k_i.squeeze(),
                    v_i.squeeze(),
                    mask=None,
                    dropout_prob=0.0,
                    is_causal=False,
                    scale=1 / math.sqrt(q_i.shape[-1]),
                    algorithm="FMHA",
                )[0]
                output_i = output_i.unsqueeze(0)
                if is_br166_device():
                    # Re-materialize into a column-major UT tensor;
                    # presumably required by downstream BR166 kernels.
                    output_tmp = torch_br._empty_ut_only(output_i.shape,
                                                         "COLMAJOR",
                                                         is_numa=False,
                                                         sbp="BB",
                                                         axis=0,
                                                         dtype=torch.bfloat16)
                    output_tmp.copy_(output_i)
                    output_i = output_tmp
                # NOTE(review): this rearrange labels dims "b s h d" on a
                # tensor produced via "s h b d" inputs — layout bookkeeping
                # differs from upstream; assumed intentional for the BR
                # kernel's output layout. Confirm against torch_br docs.
                output_i = rearrange(output_i, "b s h d -> h b s d")
                outputs.append(output_i)
            context_layer = torch.cat(outputs, dim=1)
        elif self.attn_backend == _Backend.XFORMERS:
            from xformers import ops as xops
            from xformers.ops.fmha.attn_bias import BlockDiagonalMask

            attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
                                                       kv_seqlen=None,
                                                       device=q.device)

            context_layer = xops.memory_efficient_attention_forward(
                q, k, v, attn_bias=attn_bias, p=0, scale=None)

        # Back to [s, b, hidden] for the output projection.
        context_layer = rearrange(context_layer,
                                  "b s h d -> s b (h d)").contiguous()

        # output, _ = self.proj(context_layer)
        output = self.proj(context_layer)
        return output
|
||||
|
||||
|
||||
def Glm4vVisionBlock_init_fit(
    self,
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    norm_layer: Optional[Callable[[int], nn.Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """Replacement ``__init__`` for ``Glm4vVisionBlock`` that wires in the
    BR-fitted attention implementation."""
    super(Glm4vVisionBlock, self).__init__()
    # Default norm: LayerNorm with the eps used by the HF reference model.
    make_norm = norm_layer if norm_layer is not None else partial(
        nn.LayerNorm, eps=1e-6)
    self.norm1 = make_norm(dim)
    self.norm2 = make_norm(dim)
    self.attn = Glm4vVisionAttention_fit(
        embed_dim=dim,
        num_heads=num_heads,
        projection_size=dim,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        use_data_parallel=use_data_parallel,
    )
    self.mlp = Glm4vVisionMLP(
        dim,
        mlp_hidden_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
        use_data_parallel=use_data_parallel,
    )
|
||||
|
||||
|
||||
def Glm4vVisionBlock_forward_fit(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    """Pre-norm transformer block: attention then MLP, each wrapped in a
    residual connection."""
    device = torch.supa.current_device()
    attn_out = self.attn(
        self.norm1(x),
        cu_seqlens=cu_seqlens,
        # NOTE(review): the explicit .to(current_device) suggests the
        # rotary table may arrive on a different device — confirm.
        rotary_pos_emb=rotary_pos_emb.to(device),
        max_seqlen=max_seqlen,
        seqlens=seqlens,
    )
    x = x + attn_out
    return x + self.mlp(self.norm2(x))
|
||||
|
||||
|
||||
def Llama_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into the model, remapping stacked params.

    Handles three cases per weight, in order: (1) checkpoint shards that
    map onto a fused parameter via ``stacked_params_mapping``; (2) a fused
    checkpoint tensor that must be split back into separate gate/up
    parameters (``split_params_mapping``); (3) a direct 1:1 copy.
    Returns the set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]

    # For checkpoints that store gate_up fused while the model keeps
    # gate_proj / up_proj as separate parameters: (fused, gate, up).
    split_params_mapping = [
        (".gate_up_proj", ".gate_proj", ".up_proj"),
    ]

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Rotary caches are recomputed at runtime; never load them.
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Scales may be stored per-tensor or as a 1-element vector.
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        if "scale" in name:
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue

        # Case (1): checkpoint shard -> fused model parameter.
        do_mapping_flag = False
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            do_mapping_flag = True
            loaded_params.add(name)
            break

        # Case (2): fused checkpoint tensor -> separate gate/up params.
        if not do_mapping_flag:
            for gate_up, gate, up in split_params_mapping:
                if gate_up not in name:
                    continue
                gate_name = name.replace(gate_up, gate)
                up_name = name.replace(gate_up, up)
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param_gate = params_dict[gate_name]
                param_up = params_dict[up_name]
                # The fused tensor is the row-wise concat [gate; up].
                assert loaded_weight.shape[0] == param_gate.shape[
                    0] + param_up.shape[0], "gate up shape is not match"

                weight_loader_gate = param_gate.weight_loader
                weight_loader_gate(param_gate, loaded_weight[
                    :param_gate.shape[0],
                ])

                weight_loader_up = param_up.weight_loader
                weight_loader_up(param_up, loaded_weight[
                    param_gate.shape[0]:,
                ])

                do_mapping_flag = True
                loaded_params.add(gate_name)
                loaded_params.add(up_name)
                break

        # Case (3): direct 1:1 load.
        if not do_mapping_flag:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate the last dimension of ``x`` for rotary position embedding.

    Non-interleaved (GPT-NeoX style): splits the last dim in half and
    returns ``[-x2, x1]``. Interleaved (GPT-J style): pairs adjacent
    elements ``(x[2i], x[2i+1])`` and maps each to ``(-x[2i+1], x[2i])``.

    The interleaved branch previously used einops ``rearrange`` for a
    plain flatten of the trailing pair dimension; ``Tensor.flatten(-2)``
    does the same without the third-party dependency.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    # Stack (-odd, even) pairs along a new trailing axis, then flatten the
    # last two dims: [..., d, 2] -> [..., 2d], i.e. "... d two -> ... (d two)".
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
def glm_apply_rotary_emb_torch(x: torch.Tensor,
                               cos: torch.Tensor,
                               sin: torch.Tensor,
                               interleaved: bool = False) -> torch.Tensor:
    """Apply rotary position embedding to ``x`` in pure torch.

    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Only the first ``rotary_dim`` channels of ``x`` are rotated; the
    remainder is passed through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Duplicate each cos/sin entry across the full rotary_dim and insert a
    # broadcast axis for the heads dimension.
    cos = repeat(
        cos,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(
        sin,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    # NOTE(review): the repeat above already inserts one singleton axis;
    # this extra unsqueeze adds a second — presumably needed to broadcast
    # against the (b, s, h, d) layout used by the caller. Confirm shapes.
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)

    # Standard RoPE: rotated = x*cos + rotate_half(x)*sin on the rotary
    # channels, with the tail channels concatenated back untouched.
    res = torch.cat(
        [
            x[..., :ro_dim] * cos +
            rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
        ],
        dim=-1,
    )
    return res
|
||||
|
||||
|
||||
def glm_apply_rotary_pos_emb_vision(t: torch.Tensor,
                                    freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to the vision tensor ``t``.

    Computes in fp32 and casts back to ``t``'s dtype. On CUDA the fused
    kernel from vllm_flash_attn is used; elsewhere the pure-torch fallback
    (``glm_apply_rotary_emb_torch``) runs.
    """
    rotary_fn = glm_apply_rotary_emb_torch
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
        rotary_fn = apply_rotary_emb
    cos, sin = freqs.cos(), freqs.sin()
    return rotary_fn(t.float(), cos, sin).type_as(t)
|
||||
|
||||
|
||||
def LlamaMLP_glm4_1v_forward(self, x):
    """Replacement LlamaMLP forward for GLM-4.1V on BR devices.

    NOTE(review): no activation is applied between the projections —
    presumably fused into a projection kernel on BR hardware; confirm
    before reusing this path elsewhere.
    """
    gate_up, _ = self.gate_up_proj(x)
    out, _ = self.down_proj(gate_up)
    return out
|
||||
|
||||
|
||||
def Glm4Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
):
    """GLM-4 attention forward with BR166-specific tensor re-materialization.

    On BR166 devices q and k are copied into freshly-allocated UT
    ("_empty_ut_only") tensors around the rotary-embedding call. The sbp
    arguments suggest a layout/distribution conversion (SB -> BB before
    rotary, back to SB after); presumably the rotary kernel and the
    attention kernel each require a specific UT layout — confirm against
    torch_br documentation.

    NOTE(review): the nesting of the second (BB) re-materialization block
    relative to ``is_br166_device()`` was reconstructed from a mangled
    source; verify it matches the original file.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    if is_br166_device():
        # Materialize split views into column-major UT tensors sharded
        # along the last dim (sbp="SB", axis=2).
        q_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.q_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        k_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.kv_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)

        q_tmp.copy_(q)
        k_tmp.copy_(k)
        q = q_tmp
        k = k_tmp
        # Redistribute to broadcast layout (sbp="BB", axis=0) ahead of the
        # rotary-embedding kernel.
        q_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.q_size),
            "COLMAJOR",
            is_numa=False,
            sbp="BB",
            axis=0,
            dtype=torch.bfloat16)
        k_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.kv_size),
            "COLMAJOR",
            is_numa=False,
            sbp="BB",
            axis=0,
            dtype=torch.bfloat16)
        q_tmp.copy_(q)
        k_tmp.copy_(k)
        q = q_tmp
        k = k_tmp
    q, k = self.rotary_emb(positions, q, k)
    if is_br166_device():
        # Convert back to the sharded (SB) layout for the attention kernel.
        q_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.q_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        k_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.kv_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        q_tmp.copy_(q)
        k_tmp.copy_(k)
        q = q_tmp
        k = k_tmp
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
|
||||
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
max_image_tokens = self.get_max_image_tokens()
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
max_video_tokens = self.get_num_video_tokens(image_width=target_width,
|
||||
image_height=target_height,
|
||||
num_frames=1)
|
||||
return {"image": max_image_tokens, "video": max_video_tokens}
|
||||
|
||||
|
||||
def glm4v_init(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Replacement ``__init__`` for ``Glm4vForConditionalGeneration``.

    Builds the vision tower and language model like upstream vLLM, then
    enables the Biren 0.9.2 M-RoPE path via ``br_envs``.

    NOTE: ``super(Glm4vForConditionalGeneration, self)`` is spelled out
    explicitly because this function is monkey-patched onto the class, so
    zero-argument ``super()`` would not resolve.
    """
    super(Glm4vForConditionalGeneration, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config

    self.config = config
    self.multimodal_config = multimodal_config
    # Data-parallel vision encoder when the mm-encoder TP mode is "data".
    self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

    self.visual = Glm4vVisionTransformer(
        config.vision_config,
        norm_eps=getattr(config, "rms_norm_eps", 1e-5),
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
        use_data_parallel=self.use_data_parallel,
    )

    # Pick the text backbone architecture from the multimodal model type;
    # fall back to registry auto-detection when the type is unknown.
    if config.model_type == "glm4v":
        architectures = ["Glm4ForCausalLM"]
    elif config.model_type == "glm4v_moe":
        architectures = ["Glm4MoeForCausalLM"]
    else:
        architectures = None

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=architectures)

    # Delegate pipeline-parallel placeholder creation to the text backbone.
    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)

    # Switch the Biren backend to the 0.9.2-compatible M-RoPE implementation.
    br_envs.VLLM_BR_USE_MROPE_0_9_2 = True
|
||||
|
||||
|
||||
def Glm4vPatchMerger_forward(self, x: torch.Tensor):
    """Replacement forward for ``Glm4vPatchMerger`` on Biren hardware.

    Projects merged vision patches, then runs norm + activation and the
    gate/up + down MLP. On BR166 devices the projection output is first
    copied into a freshly laid-out tensor (COLMAJOR, sbp="BB" on axis 0)
    so downstream kernels see the expected layout.
    """
    x, _ = self.proj(x)
    if is_br166_device():
        # Re-materialise the activation with the device-specific layout.
        output_tmp = torch_br._empty_ut_only(x.shape,
                                             "COLMAJOR",
                                             is_numa=False,
                                             sbp="BB",
                                             axis=0,
                                             dtype=torch.bfloat16)
        output_tmp.copy_(x)
        x = output_tmp
    x = self.extra_activation_func(self.post_projection_norm(x))
    gate_up, _ = self.gate_up_proj(x)
    # NOTE(review): the explicit activation is intentionally skipped here —
    # presumably the act_fn is fused into gate_up_proj/down_proj on this
    # backend. Confirm against the Biren kernel before "fixing" this.
    # x = self.act_fn(gate_up)
    x = gate_up
    x, _ = self.down_proj(x)
    return x
|
||||
|
||||
|
||||
def Glm4vVisionEmbeddings_forward(self, embeddings, lengths, image_shapes,
                                  h_coords, w_coords) -> torch.Tensor:
    """Replacement forward for ``Glm4vVisionEmbeddings``.

    Adapts the learned square position-embedding table to each image's
    actual patch grid via bicubic ``grid_sample`` and adds the result to
    ``embeddings``.

    Args:
        embeddings: patch embeddings, last dim equal to the position
            embedding hidden size.
        lengths: per-image patch counts (list or 1-D long tensor).
        image_shapes: per-image shape rows; column 1 is the patch-grid
            height and column 2 the width (column 0 unused here).
        h_coords / w_coords: per-patch grid coordinates, length equal to
            the total number of patches across all images.
    """
    pos_embed_weight = self.position_embedding.weight
    hidden_size = pos_embed_weight.shape[1]
    total_seq = h_coords.shape[0]
    device = pos_embed_weight.device

    # Move coordinates to the embedding table's device.
    h_coords, w_coords = h_coords.to(device), w_coords.to(device)

    # Handle empty sequence case (no patches at all).
    if total_seq == 0:
        adapted_pos_embed = torch.empty(0,
                                        hidden_size,
                                        device=device,
                                        dtype=pos_embed_weight.dtype)
    else:
        # Convert inputs to tensors if needed.
        if isinstance(lengths, list):
            lengths = torch.tensor(lengths, device=device, dtype=torch.long)
        if not isinstance(image_shapes, torch.Tensor):
            image_shapes = torch.tensor(image_shapes,
                                        device=device,
                                        dtype=torch.long)

        # Reshape the flat (S*S, H) table into a (1, H, S, S) image for
        # grid_sample; the table is assumed square (S = sqrt(rows)).
        orig_size_sq = pos_embed_weight.shape[0]
        orig_size = int(orig_size_sq**0.5)
        pos_embed_2d = (pos_embed_weight.view(orig_size,
                                              orig_size, hidden_size).permute(
                                                  2, 0, 1).unsqueeze(0))
        # Interpolate in fp32 for numerical stability.
        pos_embed_2d = pos_embed_2d.to(torch.float32)

        # Per-patch target grid dimensions (height from col 1, width from
        # col 2), repeated once per patch of the owning image.
        # Add bounds checking for data parallel mode.
        if len(lengths) > image_shapes.shape[0]:
            # In data parallel mode, some GPUs might not have all
            # image shapes.
            # Use available image shapes, cycling if necessary.
            target_h_list = []
            target_w_list = []
            for i in range(len(lengths)):
                # Cycle through available shapes.
                shape_idx = i % image_shapes.shape[0]
                target_h_list.append(image_shapes[shape_idx,
                                                  1].repeat(lengths[i]))
                target_w_list.append(image_shapes[shape_idx,
                                                  2].repeat(lengths[i]))
            target_h = torch.cat(target_h_list).to(device=device,
                                                   dtype=torch.float32)
            target_w = torch.cat(target_w_list).to(device=device,
                                                   dtype=torch.float32)
        else:
            target_h = torch.cat([
                image_shapes[i, 1].repeat(lengths[i])
                for i in range(len(lengths))
            ]).to(device=device, dtype=torch.float32)
            target_w = torch.cat([
                image_shapes[i, 2].repeat(lengths[i])
                for i in range(len(lengths))
            ]).to(device=device, dtype=torch.float32)

        # Normalize patch-center coordinates to [-1, 1] for grid_sample.
        h_coords = h_coords.to(device=device, dtype=torch.float32)
        w_coords = w_coords.to(device=device, dtype=torch.float32)
        norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
        norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

        # Sampling grid of shape (1, N, 1, 2) — one output row per patch.
        grid = (torch.stack((norm_w, norm_h),
                            dim=-1).unsqueeze(0).unsqueeze(2))

        # Bicubic interpolation of the embedding table at each patch center.
        interpolated_embed_fp32 = F.grid_sample(
            pos_embed_2d,
            grid,
            mode="bicubic",
            align_corners=False,
            padding_mode="border",
        )

        # (1, H, N, 1) -> (N, H), then back to the table dtype and the
        # embeddings' device.
        adapted_pos_embed_fp32 = (
            interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0))
        adapted_pos_embed = adapted_pos_embed_fp32.to(
            pos_embed_weight.dtype).to(embeddings.device)

    # Add adapted position encoding to embeddings.
    embeddings = embeddings + adapted_pos_embed
    return embeddings
|
||||
|
||||
|
||||
# Monkey-patch the upstream vLLM GLM-4.1V / GLM-4 classes with the
# Biren-specific implementations defined above. Order does not matter;
# each assignment independently replaces one attribute on the upstream
# class object.
#LlamaModel.load_weights = Llama_load_weights
vllm.model_executor.models.llama.LlamaMLP.forward = LlamaMLP_glm4_1v_forward
vllm.model_executor.models.glm4.Glm4Attention.forward = Glm4Attention_forward
#vllm.model_executor.models.glm4_1v.Glm4vVisionAttention = Glm4vVisionAttention_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionBlock.__init__ = Glm4vVisionBlock_init_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionBlock.forward = Glm4vVisionBlock_forward_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionMLP.forward = Glm4vVisionMLP_forward_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionMLP.__init__ = Glm4vVisionMLP_init_fit
vllm.model_executor.models.glm4_1v.Glm4vProcessingInfo.get_mm_max_tokens_per_item = get_mm_max_tokens_per_item
vllm.model_executor.models.glm4_1v.Glm4vForConditionalGeneration.__init__ = glm4v_init
vllm.model_executor.models.glm4_1v.Glm4vPatchMerger.forward = Glm4vPatchMerger_forward
vllm.model_executor.models.glm4_1v.Glm4vVisionEmbeddings.forward = Glm4vVisionEmbeddings_forward
|
||||
475
vllm_br/model_executor/models/glm4_moe.py
Normal file
475
vllm_br/model_executor/models/glm4_moe.py
Normal file
@@ -0,0 +1,475 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright 2025 The ZhipuAI Team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
|
||||
import typing
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from torch import nn
|
||||
from transformers.models.glm4_moe import Glm4MoeConfig
|
||||
|
||||
import vllm
|
||||
import vllm.model_executor.models.glm4_moe
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.distributed import (get_ep_group, get_pp_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.glm4_moe import (
|
||||
Glm4MoeAttention, Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name)
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_br.v1.attention.backends.attention_v1 import (
|
||||
SUPAFlashAttentionMetadata)
|
||||
from .supa_module import MergedGateUpMLPSiluL2
|
||||
|
||||
|
||||
class Glm4MoE(nn.Module):
    """Biren replacement for vLLM's ``Glm4MoE`` mixture-of-experts block.

    Differences from upstream (as visible here): the router ``gate`` is not
    called directly — its weights are packed together with the shared
    experts' weights and handed to ``FusedMoE`` as the ``router_logits``
    argument, so routing + shared-expert compute happen inside the fused
    kernel.
    """

    def __init__(
        self,
        config: Glm4MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor

        # Expert-parallel topology for this rank.
        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        # Router kept in fp32; quantization deliberately disabled for it.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     params_dtype=torch.float32,
                                     prefix=f"{prefix}.gate")
        # Per-expert score-correction bias used by the sigmoid scoring func.
        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts, dtype=torch.float32))

        # Load balancing settings (EPLB: expert-parallel load balancing).
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        # Physical experts = logical experts + redundant replicas, sliced
        # evenly across the expert-parallel group.
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)

        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func="sigmoid",
            # we do scaling outside, set factor to 1.0 to avoid double mul
            routed_scaling_factor=1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = MergedGateUpMLPSiluL2(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # reduce_results=self.experts.must_reduce_shared_expert_outputs(
                # ),
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run fused routed + shared experts over flattened tokens.

        Accepts any leading shape; tokens are flattened to 2-D for the
        fused kernel and the original shape is restored on return.
        """
        orig_shape = hidden_states.shape
        assert self.n_shared_experts is not None, 'n_shared_experts must be set'
        # NOTE: gate has been fused with shared_experts, no more single gate call
        # and we packed router weights, shared_experts weights and down weights in a tuple
        tuple_router_shared_expert_weight = (
            self.gate.weight, self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight)
        hidden_states = hidden_states.view(-1, orig_shape[-1])

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=tuple_router_shared_expert_weight)

        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))
        return final_hidden_states.view(orig_shape)


# Replace the upstream MoE block with the fused Biren variant.
vllm.model_executor.models.glm4_moe.Glm4MoE = Glm4MoE
|
||||
|
||||
|
||||
def Glm4MoeAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement forward for ``Glm4MoeAttention`` using Biren kernels.

    Dispatches on sequence length: short sequences (<= 512, i.e. the
    decode-like regime) use the fully fused QKV + RMSNorm + RoPE + cache
    kernel ``br_qwen3_prefix_attn_infer``; longer (prefill) sequences run
    the normal qkv projection followed by the fused split/norm/rope kernel.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        ## for dummy run (profiling / shape warm-up): skip attention entirely
        return hidden_states

    seq_len = hidden_states.shape[-2]
    # Threshold separating the fused decode path from the prefill path.
    decode_seql = 512

    if seq_len <= decode_seql:
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.attn.layer_name]
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        if kv_cache is not None:
            # Pick the qkv weight/scale tensors according to how the layer
            # was quantized (GPTQ-style qweight, compressed weight_packed,
            # or plain unquantized weight).
            if hasattr(self.qkv_proj, "qweight"):
                qkv_weight = self.qkv_proj.qweight.data
                qkv_scales = self.qkv_proj.scales.data
            elif hasattr(self.qkv_proj, "weight_packed"):
                qkv_weight = self.qkv_proj.weight_packed.data
                qkv_scales = self.qkv_proj.weight_scale.data
            else:
                qkv_weight = self.qkv_proj.weight
                qkv_scales = None
            # Fused: qkv matmul + split + q/k RMSNorm + RoPE + KV-cache write.
            q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                hidden_states,
                qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                self.head_dim,
                self.q_norm.variance_epsilon,
                self.q_norm.weight,
                self.k_norm.weight,
                self.rotary_emb.sin_cache,
                self.rotary_emb.cos_cache,
                kv_cache,
                positions,
                attn_metadata.slot_mapping,
                rotary_dim=self.rotary_emb.rotary_dim,
                bias=self.qkv_proj.bias,
                scales=qkv_scales)
            if hasattr(attn_metadata, 'do_cache'):
                # The fused kernel already wrote the KV cache above.
                attn_metadata.do_cache = False
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
        else:
            # No KV cache allocated for this virtual engine: nothing to do.
            return hidden_states
    else:
        # Prefill path: standard qkv projection, then fused
        # split + RMSNorm + RoPE (no cache write inside the kernel).
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = torch_br.br_fused_split_rms_rope_infer(
            qkv, [self.q_size, self.kv_size, self.kv_size],
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.sin_cache,
            self.rotary_emb.cos_cache,
            positions,
            rotary_dim=self.rotary_emb.rotary_dim)
        if hasattr(attn_metadata, 'do_cache'):
            # Let the attention backend write the KV cache on this path.
            attn_metadata.do_cache = True
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


vllm.model_executor.models.glm4_moe.Glm4MoeAttention.forward = Glm4MoeAttention_forward
|
||||
|
||||
|
||||
def Glm4MoeDecoderLayer__init__(
    self,
    config: Glm4MoeConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
) -> None:
    """Replacement ``__init__`` for ``Glm4MoeDecoderLayer``.

    Mirrors upstream construction but swaps the dense MLP for the fused
    ``MergedGateUpMLPSiluL2`` and uses the (also patched) ``Glm4MoE`` for
    MoE layers. ``super(Glm4MoeDecoderLayer, self)`` is explicit because
    this function is monkey-patched onto the class.
    """
    super(Glm4MoeDecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      131072)
    # DecoderLayers are created with `make_layers` which passes the prefix
    # with the layer's index.
    layer_idx = int(prefix.split(sep='.')[-1])
    self.layer_idx = layer_idx

    self.self_attn = Glm4MoeAttention(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=rope_theta,
        rope_scaling=rope_scaling,
        max_position_embeddings=max_position_embeddings,
        head_dim=config.head_dim,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=config.attention_bias,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
        use_qk_norm=config.use_qk_norm,
    )

    # The first `first_k_dense_replace` layers stay dense; later layers
    # use the routed MoE block.
    if (config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace):
        self.mlp = Glm4MoE(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            enable_eplb=enable_eplb,
        )
    else:
        self.mlp = MergedGateUpMLPSiluL2(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp")

    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
    self.routed_scaling_factor = config.routed_scaling_factor


vllm.model_executor.models.glm4_moe.Glm4MoeDecoderLayer.__init__ = Glm4MoeDecoderLayer__init__
|
||||
|
||||
|
||||
def Glm4MoeModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Replacement forward for ``Glm4MoeModel``.

    Identical to upstream except that activations (and the residual, when
    received from a previous pipeline stage) are lifted to 3-D with a
    leading batch dim before running the layers — the SUPA kernels expect
    3-D inputs — and squeezed back to 2-D before handing off or returning.
    """
    if get_pp_group().is_first_rank:
        residual = None
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        # NOTE: SUPA wants 3D input — restore the leading batch dim.
        residual = intermediate_tensors["residual"].unsqueeze(0)

    hidden_states = hidden_states.unsqueeze(0)
    for layer_idx in range(self.start_layer, self.end_layer):
        hidden_states, residual = self.layers[layer_idx](positions,
                                                         hidden_states,
                                                         residual)

    if not get_pp_group().is_last_rank:
        # Hand 2-D tensors to the next pipeline stage.
        handoff = {
            "hidden_states": (hidden_states.squeeze(0)
                              if hidden_states is not None else hidden_states),
            "residual": (residual.squeeze(0)
                         if residual is not None else residual),
        }
        return IntermediateTensors(handoff)

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)


vllm.model_executor.models.glm4_moe.Glm4MoeModel.forward = Glm4MoeModel_forward
|
||||
|
||||
|
||||
def Glm4MoeModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Replacement ``load_weights`` for ``Glm4MoeModel``.

    Follows the upstream three-way dispatch (stacked qkv/gate_up params,
    expert params, plain params) and additionally, after each load:
      * promotes norm weights and ``e_score_correction_bias`` to fp32,
      * casts ``gate.weight`` to bf16 on the plain path,
      * calls ``torch.supa.empty_cache()`` to release device memory.

    Returns the set of parameter names that were actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    expert_params_mapping = self.get_expert_mapping()
    for name, loaded_weight in weights:
        # Skip speculative-decoding (MTP) layers; they are loaded elsewhere.
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue

            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer: keep norms and score-correction in fp32.
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
            torch.supa.empty_cache()
            break
        else:
            # Not a stacked param — try the expert mapping next.
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True

                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name_mapped, self):
                    continue

                if name_mapped not in params_dict:
                    continue
                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(Callable[..., bool],
                                            param.weight_loader)
                success = weight_loader(param,
                                        loaded_weight,
                                        name_mapped,
                                        shard_id=shard_id,
                                        expert_id=expert_id,
                                        return_success=True)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                # Router gate runs in bf16 on this backend.
                if name.find("gate.weight") != -1:
                    param.data = param.data.to(torch.bfloat16)
                torch.supa.empty_cache()
        loaded_params.add(name)

    return loaded_params


vllm.model_executor.models.glm4_moe.Glm4MoeModel.load_weights = Glm4MoeModel_load_weights
|
||||
358
vllm_br/model_executor/models/gpt_oss.py
Normal file
358
vllm_br/model_executor/models/gpt_oss.py
Normal file
@@ -0,0 +1,358 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_br
|
||||
from torch import nn
|
||||
from transformers import GptOssConfig
|
||||
|
||||
import vllm
|
||||
import vllm.model_executor.models.gpt_oss
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import (extract_layer_index,
|
||||
is_pp_missing_parameter)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import cdiv
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
class OAIAttention(nn.Module):
    """Biren replacement for vLLM's GPT-OSS ``OAIAttention``.

    Same structure as upstream: YaRN-scaled RoPE, learned per-head
    attention sinks, QKV/output projections, and sliding-window attention
    on every other layer. The forward pass differs only in using the
    Biren fused split kernel when the device SPC count exceeds 16.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size

        # YaRN rope scaling, parameters taken from the HF config.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type":
                "yarn",
                "factor":
                config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast":
                config.rope_scaling["beta_fast"],
                "beta_slow":
                config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )

        tp_size = get_tensor_model_parallel_world_size()

        # Learned attention sinks, one per local head, kept in fp32.
        attention_sink_dtype = torch.float32
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size,
                        dtype=attention_sink_dtype,
                        requires_grad=False))

        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta

        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size

        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Project to q/k/v, apply RoPE, attend, and project the output."""
        qkv, _ = self.qkv(hidden_states)
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # High-SPC devices use the fused sbp-aware split kernel.
            q, k, v = torch_br.split_w_sbp_infer(
                qkv, [self.q_size, self.kv_size, self.kv_size])
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # split() returns a view; the attention kernel needs contiguous v.
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


vllm.model_executor.models.gpt_oss.OAIAttention = OAIAttention
|
||||
|
||||
|
||||
class MLPBlock(torch.nn.Module):
    """gpt-oss MoE MLP block (Biren replacement for vLLM's ``MLPBlock``).

    Keeps the router as a plain ``torch.nn.Linear`` in bfloat16 and routes
    tokens through a ``FusedMoE`` expert layer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        # Fall back to world size 1 when torch.distributed is not initialized
        # (e.g. single-device runs).
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
        # Each rank must own an equal slice of the expert intermediate dim.
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False,
                                has_bias=True,
                                activation="swigluoai",
                                is_sequence_parallel=self.is_sequence_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): the router *weight* (not precomputed logits) is passed
        # as ``router_logits`` -- presumably the Biren FusedMoE kernel performs
        # the router GEMM internally. Confirm against the kernel implementation.
        final_hidden_states = self.experts(hidden_states=x.squeeze(0),
                                           router_logits=self.router.weight)

        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            # The fused kernel did not reduce; do the TP all-reduce here.
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states
|
||||
|
||||
|
||||
# Monkeypatch: route gpt-oss MoE layers through the Biren MLPBlock.
vllm.model_executor.models.gpt_oss.MLPBlock = MLPBlock
|
||||
|
||||
|
||||
def GptOssModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Biren replacement for ``GptOssModel.forward``.

    Same contract as the stock implementation, but carries hidden states
    through the decoder layers with a leading batch-of-1 dimension, which is
    stripped again before results leave this function.
    """
    if get_pp_group().is_first_rank:
        # First pipeline stage: start from (possibly precomputed) embeddings.
        if inputs_embeds is None:
            x = self.get_input_embeddings(input_ids)
        else:
            x = inputs_embeds
        residual = None
    else:
        # Later stages resume from the previous pipeline rank's hand-off.
        assert intermediate_tensors is not None
        x = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        residual = residual.unsqueeze(0)

    x = x.unsqueeze(0)
    aux_hidden_states = []
    for layer_idx in range(self.start_layer, self.end_layer):
        if layer_idx in self.aux_hidden_state_layers:
            # Capture the pre-layer hidden state with the residual folded in.
            aux_hidden_states.append(x if residual is None else x + residual)
        x, residual = self.layers[layer_idx](x, positions, residual)

    if not get_pp_group().is_last_rank:
        # Hand off to the next pipeline stage, dropping the batch dim.
        return IntermediateTensors({
            "hidden_states": x.squeeze(0),
            "residual": residual.squeeze(0) if residual is not None else None,
        })

    x, _ = self.norm(x, residual)

    if aux_hidden_states:
        return x, aux_hidden_states
    return x.squeeze(0)
|
||||
|
||||
|
||||
# Monkeypatch: install the Biren forward pass on GptOssModel.
vllm.model_executor.models.gpt_oss.GptOssModel.forward = GptOssModel_forward
|
||||
|
||||
|
||||
def GptOssModel_load_weights_other(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Load gpt-oss MoE expert weights and attention sinks onto this rank.

    Each checkpoint tensor is sliced for the current expert-parallel (EP)
    or tensor-parallel (TP) rank before being copied into the matching
    model parameter. Returns the set of parameter names that were loaded.

    NOTE(review): the ``ep_rank_end`` parameter precedes ``ep_rank_start``
    in the signature -- this mirrors the caller; do not reorder.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()

    use_ep = self.parallel_config.enable_expert_parallel

    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()

    intermediate_size = self.config.intermediate_size
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)

    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue

        if ".w13_weight" in name:
            # Handle MLP gate and up projection weights
            # Extract gate and up projection parts
            if use_ep:
                # EP: take this rank's contiguous slice of experts.
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # TP: gate and up are interleaved along the last dim, so the
                # per-rank slice is doubled.
                narrow_weight = weight[:, :, 2 * tp_rank_start:2 * tp_rank_end]

            # Transpose to the (experts, out, in) layout the kernel expects.
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]

            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_weight" in name:
            # Handle MLP down projection weights
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]

            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w13_bias" in name:
            # Handle MLP gate and up projection biases
            # Extract gate and up projection bias parts
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end]

            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_bias" in name:
            # Handle MLP down projection bias
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # (only load on rank 0 to avoid duplication)
                # NOTE(review): zeroing mutates the loaded checkpoint tensor
                # in place -- confirm the source tensor is not shared/mmapped.
                if tp_rank != 0:
                    weight.zero_()
            param = params_dict[name]
            param.copy_(weight)
            loaded_params.add(name)
            continue
        elif "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        # Fall through: stacked (fused) parameters, then everything else.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if weight_loader == default_weight_loader:
                # default_weight_loader takes no shard id.
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
        loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
# Monkeypatch: install the Biren weight loader on GptOssModel.
vllm.model_executor.models.gpt_oss.GptOssModel._load_weights_other = GptOssModel_load_weights_other
|
||||
242
vllm_br/model_executor/models/intern_vit.py
Normal file
242
vllm_br/model_executor/models/intern_vit.py
Normal file
@@ -0,0 +1,242 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
# isort: off
|
||||
from vllm.model_executor.models.intern_vit import (InternMLP,
|
||||
InternVisionEmbeddings,
|
||||
InternVisionModel,
|
||||
InternVisionEncoder)
|
||||
from vllm.model_executor.models.intern_vit import InternParallelAttention
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
# isort: on
|
||||
|
||||
|
||||
@patch_to(InternVisionModel)
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    num_dummy_heads: int = 0,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """
    [Patch] enable data parallelism for InternVisionModel
    """
    super(InternVisionModel, self).__init__()

    self.config = config
    self.use_data_parallel = use_data_parallel

    self.embeddings = InternVisionEmbeddings(config)
    # NOTE(review): quant_config is deliberately not forwarded to the encoder
    # (it is built unquantized). Confirm this is intended before relying on
    # quantized vision-encoder weights.
    self.encoder = InternVisionEncoder(
        config=config,
        quant_config=None,
        num_hidden_layers_override=num_hidden_layers_override,
        num_dummy_heads=num_dummy_heads,
        prefix=f"{prefix}.encoder",
        use_data_parallel=use_data_parallel,
    )
|
||||
|
||||
|
||||
@patch_to(InternVisionEmbeddings)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Patch-embedding forward using a Biren conv kernel for 14x14 patches.

    Falls back to the stock ``nn.Conv2d`` path for other patch sizes.
    """
    target_dtype = self.patch_embedding.weight.dtype
    if self.patch_size == 14:
        import torch_br.supa._debug as supa_debug

        # Toggle the supa zero-handling debug switches off around the custom
        # conv kernel, then restore them below. The set/restore order matters;
        # do not reorder these calls.
        supa_debug.set_disable_zero_ws(False)
        supa_debug.set_disable_zero_output_uma(False)
        supa_debug.set_disable_zero_output_numa(False)
        supa_debug.set_disable_reorder_zero(False)

        patch_embeds = torch_br.supa_conv2d_knxn_snxn_p0x0_fwd(
            pixel_values.to(dtype=target_dtype), self.patch_embedding.weight,
            self.patch_size, self.patch_size, 0)
        if self.patch_embedding.bias is not None:
            # The custom kernel does not apply the conv bias; add it manually.
            patch_embeds += self.patch_embedding.bias[None, :, None, None]
        supa_debug.set_disable_zero_ws(True)
        supa_debug.set_disable_zero_output_uma(True)
        supa_debug.set_disable_zero_output_numa(True)
        supa_debug.set_disable_reorder_zero(True)
    else:
        patch_embeds = self.patch_embedding(pixel_values.to(
            target_dtype))  # shape = [*, channel, width, height]

    batch_size, _, height, width = patch_embeds.shape
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1,
                                               -1).to(target_dtype)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    # NOTE(review): the position-embedding path is selected by whether the
    # patch embedding has a bias -- presumably a proxy for the model variant;
    # confirm this matches the upstream intent.
    if self.patch_embedding.bias is None:
        position_embedding = self._get_position_embedding(height, width)
    else:
        # Keep the class-token position embedding, interpolate the rest to
        # the current spatial resolution.
        position_embedding = torch.cat([
            self.position_embedding[:, :1, :],
            self._get_pos_embed(self.position_embedding[:, 1:, :], height,
                                width)
        ],
                                       dim=1)
    embeddings = embeddings + position_embedding.to(target_dtype)
    return embeddings
|
||||
|
||||
|
||||
@patch_to(InternParallelAttention)
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_dummy_heads: int = 0,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """[Patch] data-parallel-only attention for the InternViT encoder.

    Replaces the tensor-parallel linears with plain ``torch.nn.Linear`` and
    asserts ``tp_size == 1`` -- the vision encoder runs data-parallel on
    every rank under this patch.
    """
    super(InternParallelAttention, self).__init__()

    # [Patch] enable data parallelism
    # NOTE(review): the ``use_data_parallel`` argument is ignored here and DP
    # is forced on unconditionally -- confirm this is intended for all call
    # sites.
    self.use_data_parallel = True

    self.config = config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads
    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(f'embed_dim must be divisible by num_heads '
                         f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                         f' {self.num_heads}).')

    self.tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
    self.tp_rank = (0
                    if use_data_parallel else get_tensor_model_parallel_rank())

    # Additional dummy heads are used to enable TP for common GPU counts.
    self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
    self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                          self.tp_size)
    # This patch only supports the data-parallel (tp_size == 1) layout.
    assert self.tp_size == 1
    self.scale = self.head_dim**-0.5
    # Original TP implementation, kept for reference:
    # self.qkv = QKVParallelLinear(
    #     self.embed_dim,
    #     self.head_dim,
    #     num_dummy_heads + self.num_heads,
    #     bias=config.qkv_bias,
    #     quant_config=quant_config,
    #     prefix=f"{prefix}.qkv",
    #     disable_tp=use_data_parallel,
    # )
    self.qkv = torch.nn.Linear(self.embed_dim,
                               3 * self.dummy_dim,
                               bias=config.qkv_bias)

    self.qk_normalization = config.qk_normalization

    if self.qk_normalization:
        self.q_norm = RMSNorm(self.dummy_dim,
                              eps=config.layer_norm_eps,
                              var_hidden_size=self.embed_dim)
        self.k_norm = RMSNorm(self.dummy_dim,
                              eps=config.layer_norm_eps,
                              var_hidden_size=self.embed_dim)

    # Original TP implementation, kept for reference:
    # self.proj = RowParallelLinear(
    #     self.dummy_dim,
    #     self.embed_dim,
    #     quant_config=quant_config,
    #     prefix=f"{prefix}.proj",
    #     disable_tp=use_data_parallel,
    # )
    self.proj = torch.nn.Linear(self.dummy_dim, self.embed_dim)

    # Attention is computed manually in forward (see the patched forward);
    # the fused MultiHeadAttention module is intentionally not constructed:
    # self.attn = MultiHeadAttention(self.num_heads_per_partition,
    #                                self.head_dim, self.scale)
|
||||
|
||||
|
||||
@patch_to(InternParallelAttention)
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Manual multi-head attention, processed one image at a time.

    Processes each batch element separately (reduces peak memory on the
    device) and splits the attention-weighted sum over the key axis.
    """
    B, N, C = x.shape
    x_tmp = []
    for i in range(B):
        # Per-sample QKV projection: (1, N, 3, heads, head_dim).
        qkv = self.qkv(x[i:i + 1, :]).reshape(1, N, 3, self.num_heads,
                                              C // self.num_heads)
        q, k, v = qkv.unbind(
            2)  # make torchscript happy (cannot use tensor as tuple)

        if self.qk_normalization:
            # RMS-normalize over the flattened head dims, then restore shape.
            q = self.q_norm(q.flatten(-2, -1)).view(1, N, self.num_heads,
                                                    qkv.shape[4])
            k = self.k_norm(k.flatten(-2, -1)).view(1, N, self.num_heads,
                                                    qkv.shape[4])

        # (1, N, heads, head_dim) -> (1, heads, N, head_dim)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)

        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # NOTE(review): the attn @ v product is split at key index 512 --
        # presumably a Biren kernel-size workaround. When N <= 512 the second
        # matmul is empty (zeros) and the result matches a single matmul;
        # confirm the constant against the kernel limits.
        x0 = attn[:, :, :, :512] @ v[:, :, :512, :]
        x1 = attn[:, :, :, 512:] @ v[:, :, 512:, :]

        x_tmp.append((x0 + x1).transpose(1, 2).reshape(1, N, C))
    x = torch.cat(x_tmp, dim=0)
    x = self.proj(x)
    return x
|
||||
|
||||
|
||||
@patch_to(InternMLP)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """InternViT MLP forward, batched per sample on the Biren backend.

    For batch size > 1, samples are processed one at a time into a
    preallocated device buffer; batch size 1 uses the straight-through path.
    """
    if hidden_states.shape[0] > 1:
        # Preallocate the output in the Biren column-major layout.
        output = torch_br._empty_ut_only(hidden_states.shape,
                                         "COLMAJOR",
                                         is_numa=False,
                                         sbp="BB",
                                         axis=0,
                                         dtype=torch.bfloat16)
        for i in range(hidden_states.shape[0]):
            hidden_states_tmp, _ = self.fc1(hidden_states[i:i + 1, :, :])
            hidden_states_tmp = self.activation_fn(hidden_states_tmp)
            hidden_states_tmp, _ = self.fc2(hidden_states_tmp)
            # NOTE(review): fc2's bias is added manually here but NOT in the
            # batch-size-1 branch below -- presumably fc2 skips bias in this
            # path (skip_bias_add); confirm the two branches agree.
            hidden_states_tmp += self.fc2.bias[None, None, :]
            output[i] = hidden_states_tmp[0]
        return output
    else:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states
|
||||
140
vllm_br/model_executor/models/internlm2.py
Normal file
140
vllm_br/model_executor/models/internlm2.py
Normal file
@@ -0,0 +1,140 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import (get_pp_group, split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_gather)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
|
||||
InternLM2MLP, InternLM2Model)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
def internlm2_attention_split_qkv(self, qkv: torch.Tensor):
    """Split InternLM2's grouped QKV projection into q, k, v tensors.

    InternLM2 stores QKV grouped per KV head. Under tensor parallelism the
    partial projections are all-gathered, regrouped into contiguous q/k/v
    blocks, and the rank's shard is re-split at the end.
    """
    # qkv is (1, seq_len, proj_dim); seq_len is the middle dim.
    seq_len = qkv.shape[1]
    if self.tp_size > 1:
        # Gather all ranks' projections, then reorder the interleaved
        # [q, k, v] * tp_size chunks into q-chunks, k-chunks, v-chunks.
        qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
        qkv = tensor_model_parallel_all_gather(qkv)
        qkv = torch.split(qkv, qkv_map, dim=-1)
        qkv = qkv[::3] + qkv[1::3] + qkv[2::3]
        qkv = torch.cat(qkv, dim=-1)

    # Ungroup: each KV head carries key_value_groups query heads + 1 key
    # head + 1 value head.
    qkv = qkv.view(seq_len, self.total_num_kv_heads, self.key_value_groups + 2,
                   self.head_dim)
    q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
    q = q.reshape(seq_len, self.q_size * self.tp_size).unsqueeze(0)
    k = k.reshape(seq_len, self.kv_size * self.tp_size).unsqueeze(0)
    v = v.reshape(seq_len, self.kv_size * self.tp_size).unsqueeze(0)

    if self.tp_size > 1:
        # Take this rank's shard of the regrouped projections.
        splitter = partial(split_tensor_along_last_dim,
                           num_partitions=self.tp_size)
        q = splitter(q)[self.tp_rank]
        k = splitter(k)[self.tp_rank]
        v = splitter(v)[self.tp_rank]
    return q, k, v
|
||||
|
||||
|
||||
def internlm2_attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """InternLM2 attention: project, split, apply RoPE, attend, merge.

    Behavior is identical to the stock implementation.
    """
    projected, _ = self.wqkv(hidden_states)
    query, key, value = self.split_qkv(projected)
    query, key = self.rotary_emb(positions, query, key)

    context = self.attn(query, key, value)
    output, _ = self.wo(context)
    return output
|
||||
|
||||
|
||||
def internlm2_model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """InternLM2 model forward carrying a leading batch-of-1 dimension.

    Hidden states are unsqueezed to (1, seq, hidden) for the decoder layers
    and squeezed back before returning or handing off between pipeline ranks.
    """
    if get_pp_group().is_first_rank:
        # First pipeline stage: start from (possibly precomputed) embeddings.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        # Later stages resume from the previous rank's hand-off.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        # NOTE(review): unlike the llama/gpt_oss counterparts in this file,
        # ``residual`` is NOT unsqueezed here while ``hidden_states`` is --
        # presumably broadcasting makes the shapes compatible inside the
        # layer; confirm this is intentional.

    hidden_states = hidden_states.unsqueeze(0)

    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not get_pp_group().is_last_rank:
        # Strip the batch dim before handing off to the next stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)
|
||||
|
||||
|
||||
def internlm2_mlp_init(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Replacement ``InternLM2MLP.__init__``: fused gate/up + down projection.

    Raises:
        ValueError: if ``hidden_act`` is not ``"silu"``.
    """
    super(InternLM2MLP, self).__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    # NOTE(review): ``no_need_cross`` is a backend-specific flag read
    # elsewhere in vllm_br -- presumably it skips a cross-rank layout step
    # for this fused projection; confirm against the linear layer patches.
    self.gate_up_proj.no_need_cross = True
    self.w2 = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.w2",
    )
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()
|
||||
|
||||
|
||||
# Monkeypatch the Biren implementations onto vLLM's InternLM2 classes.
InternLM2Attention.split_qkv = internlm2_attention_split_qkv
InternLM2Attention.forward = internlm2_attention_forward
InternLM2Model.forward = internlm2_model_forward
InternLM2MLP.__init__ = internlm2_mlp_init
|
||||
367
vllm_br/model_executor/models/llama.py
Normal file
367
vllm_br/model_executor/models/llama.py
Normal file
@@ -0,0 +1,367 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from transformers import LlamaConfig
|
||||
|
||||
import vllm.model_executor.models.llama
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.llama import (LlamaAttention,
|
||||
LlamaDecoderLayer,
|
||||
LlamaForCausalLM, LlamaModel)
|
||||
from vllm.model_executor.models.utils import (extract_layer_index,
|
||||
is_pp_missing_parameter)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_br import envs
|
||||
from ..layers.quantization.compressed_tensors.utils import (
|
||||
get_compressed_tensors_cache_scale)
|
||||
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2
|
||||
|
||||
|
||||
def LlamaDecoderLayer__init__(self,
                              vllm_config: VllmConfig,
                              prefix: str = "",
                              config: Optional[LlamaConfig] = None) -> None:
    """Replacement ``LlamaDecoderLayer.__init__`` for the Biren backend.

    Chooses between the Biren ``AttentionSplit`` module and the stock
    ``LlamaAttention`` based on the device/weight-granularity heuristic
    below, and always uses the fused ``MergedGateUpMLPSiluL2`` MLP.
    """
    super(LlamaDecoderLayer, self).__init__()
    config = config or vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None):
        rope_scaling["original_max_position_embeddings"] = (
            config.original_max_position_embeddings)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    # Support abacusai/Smaug-72B-v0.1 with attention_bias
    # Support internlm/internlm-7b with bias
    attention_bias = getattr(config, "attention_bias", False) or getattr(
        config, "bias", False)

    tp_size = get_tensor_model_parallel_world_size()
    # Number of compute units on the local Biren device.
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # determine whether use qkv merge weights
    # Minimum per-unit weight granularity for the split path.
    min_w_gran = 32
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    # NOTE: current br166 don't support s(2)b split, so br166 can only use AttentionSplit
    if is_166 or (config.num_key_value_heads *
                  (self.hidden_size // config.num_attention_heads)
                  >= tp_size * spc_num * min_w_gran):
        # KV projection is large enough (or the device requires it): use the
        # Biren split-attention module.
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
    else:
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )

    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        bias=getattr(config, "mlp_bias", False),
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]]) -> set[str]:
    """Load llama checkpoint weights, handling fused/stacked parameters.

    Maps per-head checkpoint tensors (q/k/v, gate/up) onto the fused model
    parameters when those exist, loads everything else directly, and applies
    the Biren weight-layout inference trick (``param.data + 0``) after each
    copy. Returns the set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    # Collect loaded names in a set from the start: matches the declared
    # return type and the sibling loaders in this file, and avoids
    # duplicates plus a final list->set conversion.
    loaded_params: set[str] = set()
    # determine whether is qkv merge weights
    qkv_merge = False
    for key in params_dict:
        if "qkv_proj" in key:
            qkv_merge = True
            break
    if not qkv_merge and len(stacked_params_mapping) >= 3:
        # No fused qkv parameter exists: drop the q/k/v mapping entries.
        stacked_params_mapping = stacked_params_mapping[3:]

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            # Rotary caches are recomputed at runtime; skip.
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        if scale_name := get_compressed_tensors_cache_scale(name):
            # Loading kv cache scales for compressed-tensors quantization
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue

            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            param.data = param.data + 0
            loaded_params.add(name)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue

            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            # weight layout infer
            param.data = param.data + 0
            if name.find("norm.weight") != -1:
                # Norm weights are kept in fp32 on this backend.
                param.data = param.data.to(torch.float32)
            loaded_params.add(name)

    return loaded_params
|
||||
|
||||
|
||||
def llamamodel_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                    list[torch.Tensor]]]:
    """LlamaModel forward carrying a leading batch-of-1 dimension.

    Hidden states (and residuals) are unsqueezed to (1, seq, hidden) for the
    decoder layers and squeezed back before returning or handing off between
    pipeline ranks.
    """
    if get_pp_group().is_first_rank:
        # First pipeline stage: start from (possibly precomputed) embeddings.
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        hidden_states = hidden_states.unsqueeze(0)
    else:
        # Later stages resume from the previous rank's hand-off.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        hidden_states = hidden_states.unsqueeze(0)
        residual = residual.unsqueeze(0)

    aux_hidden_states = []
    for idx, layer in enumerate(self.layers[self.start_layer:self.end_layer]):
        if idx in self.aux_hidden_state_layers:
            # NOTE(review): unlike the gpt_oss counterpart, this add is not
            # guarded for residual=None -- if layer index 0 is an aux layer
            # on the first pipeline stage this raises. Confirm aux layer
            # indices never include the first layer, or add the guard.
            aux_hidden_states.append(hidden_states + residual)

        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        # Strip the batch dim before handing off to the next stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)

    if len(aux_hidden_states) > 0:
        return hidden_states, aux_hidden_states
    return hidden_states.squeeze(0)
|
||||
|
||||
|
||||
def LlamaAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement for ``LlamaAttention.forward``.

    Identical to upstream except that on br166-class devices
    (``VLLM_BR_DEVICE_SPC_NUM > 16``) the fused QKV output is split with the
    layout-aware ``torch_br.split_w_sbp_infer`` instead of ``Tensor.split``.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    if envs.VLLM_BR_DEVICE_SPC_NUM <= 16:
        q, k, v = qkv.split(split_sizes, dim=-1)
    else:
        q, k, v = torch_br.split_w_sbp_infer(qkv, split_sizes)
    q, k = self.rotary_emb(positions, q, k)
    attn_out = self.attn(q, k, v)
    out, _ = self.o_proj(attn_out)
    return out
|
||||
|
||||
|
||||
@patch_to(LlamaAttention)
def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None) -> None:
    """Patched ``LlamaAttention.__init__``.

    Behaves like upstream vLLM, except that quantization is only applied to
    the QKV projection when ``quant_config.qkv_quantized`` is set.
    """
    super(LlamaAttention, self).__init__()
    tp_size = get_tensor_model_parallel_world_size()

    # --- head bookkeeping -------------------------------------------------
    self.hidden_size = hidden_size
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size

    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # More KV heads than TP ranks: partition them across ranks.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than TP ranks: replicate them across ranks.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

    # MistralConfig has an optional head_dim introduced by Mistral-Nemo.
    self.head_dim = getattr(config, "head_dim",
                            self.hidden_size // self.total_num_heads)
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings

    # Only quantize the QKV projection when the config opts in.
    qkv_quant_config = None
    if quant_config is not None and quant_config.qkv_quantized:
        qkv_quant_config = quant_config

    # --- submodules -------------------------------------------------------
    self.qkv_proj = QKVParallelLinear(
        hidden_size=hidden_size,
        head_size=self.head_dim,
        total_num_heads=self.total_num_heads,
        total_num_kv_heads=self.total_num_kv_heads,
        bias=bias,
        quant_config=qkv_quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        input_size=self.total_num_heads * self.head_dim,
        output_size=hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )

    # Extra kwargs are only forwarded when dual-chunk attention is enabled.
    extra_attn_kwargs: dict[str, Any] = {}
    if dual_chunk_attention_config:
        extra_attn_kwargs = {
            "layer_idx": extract_layer_index(prefix),
            "dual_chunk_attention_config": dual_chunk_attention_config,
        }
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        attn_type=attn_type,
        prefix=f"{prefix}.attn",
        **extra_attn_kwargs)
|
||||
|
||||
|
||||
# Install the Biren-specific replacements onto upstream vLLM's Llama
# implementation (monkey patching; takes effect when this module is
# imported).
vllm.model_executor.models.llama.LlamaDecoderLayer.__init__ = LlamaDecoderLayer__init__
LlamaForCausalLM.load_weights = load_weights
LlamaModel.forward = llamamodel_forward
LlamaAttention.forward = LlamaAttention_forward
|
||||
349
vllm_br/model_executor/models/qwen2.py
Normal file
349
vllm_br/model_executor/models/qwen2.py
Normal file
@@ -0,0 +1,349 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import gc
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from transformers import Qwen2Config
|
||||
|
||||
import vllm.model_executor.models.qwen2
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.qwen2 import (Qwen2Attention,
|
||||
Qwen2DecoderLayer, Qwen2Model)
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
#import vllm.envs as envs
|
||||
from vllm_br import envs
|
||||
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2
|
||||
|
||||
|
||||
def Qwen2DecoderLayer__init__(
    self,
    config: Qwen2Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Replacement for ``Qwen2DecoderLayer.__init__``.

    Selects between the Biren ``AttentionSplit`` module and the stock
    ``Qwen2Attention`` based on device generation and KV-weight granularity,
    and always uses the fused ``MergedGateUpMLPSiluL2`` MLP.
    """
    super(Qwen2DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    # Requires transformers > 4.32.0
    rope_theta = getattr(config, "rope_theta", 1000000)
    rope_scaling = getattr(config, "rope_scaling", None)
    dual_chunk_attention_config = getattr(config,
                                          "dual_chunk_attention_config", None)

    # By default, Qwen2 uses causal attention as it is a decoder-only model.
    # You can override the HF config with `is_causal=False` to enable
    # bidirectional attention, which is used in some embedding models
    # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
    if getattr(config, "is_causal", True):
        attn_type = AttentionType.DECODER
    else:
        attn_type = AttentionType.ENCODER_ONLY

    # NOTE(review): both getattr defaults are True, so attention_bias is True
    # unless the config explicitly sets attention_bias=False AND carries an
    # explicit falsy `bias`; upstream vLLM defaults these to False — confirm
    # the True defaults are intentional for this port.
    attention_bias = getattr(config, "attention_bias", True) or getattr(
        config, "bias", True)

    tp_size = get_tensor_model_parallel_world_size()
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # determine whether use qkv merge weights
    min_w_gran = 32  # granularity threshold used in the heuristic below
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    # NOTE: current br166 don't support s(2)b split, so br166 can only use
    # AttentionSplit
    if is_166 or (config.num_key_value_heads *
                  (self.hidden_size // config.num_attention_heads)
                  >= tp_size * spc_num * min_w_gran):
        # NOTE(review): AttentionSplit is not given attn_type or
        # dual_chunk_attention_config — confirm it is decoder-only.
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            bias=attention_bias,
        )
        logger.debug('[Patch] Use AttentionSplit instead of Qwen2Attention')
    else:
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into the patched Qwen2 model.

    Besides the standard vLLM stacked-parameter handling, this applies
    Biren-specific post-processing: forcing device weight-layout inference
    (the ``param.data + 0`` round-trips), casting norm weights to float32,
    and re-materializing norm / embedding / lm_head / rope caches as
    ``torch_br`` "ut" tensors on br166-class devices
    (``VLLM_BR_DEVICE_SPC_NUM > 16``).

    Returns:
        The set of parameter names that were loaded.
    """
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # platform 1 => br166-class device (> 16 SPCs); platform 0 => older parts.
    self.platform = 0
    if spc_num > 16:
        self.platform = 1

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    logger.info('[Patch] Qwen2 MLP do not merge up/gate weight')

    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    # Whether this model has a merged qkv_proj parameter is fixed for the
    # whole load, so detect it once up front instead of rescanning
    # params_dict for every checkpoint tensor (was O(#weights * #params)).
    qkv_merge = any("qkv_proj" in key for key in params_dict)
    if not qkv_merge:
        # No merged qkv parameter: drop the q/k/v entries, keep gate/up.
        stacked_params_mapping = stacked_params_mapping[3:]

    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # ---- Biren weight-layout inference --------------------------------
        if self.platform == 0:
            # `+ 0` forces the device runtime to (re)infer the weight layout.
            param.data = param.data + 0
        if name.find("norm.weight") != -1:
            if self.platform == 1:
                # Rebuild norm weights as an fp32 torch_br "ut" tensor.
                w_cpu = param.data.to(torch.float32).cpu()
                w_supa = torch_br._empty_ut_only(w_cpu.shape,
                                                 dtype=w_cpu.dtype,
                                                 is_numa=False,
                                                 device=param.data.device,
                                                 tensor_type="linear_bias",
                                                 axis=0,
                                                 sbp="BB")
                w_supa.copy_(w_cpu)
                param.data = w_supa
            else:
                param.data = param.data.to(torch.float32)

        if name.find("embed_tokens.weight") != -1 and self.platform == 1:
            # Embedding table is re-laid-out column-major, replicated ("BB").
            w_shape = param.data.shape
            w_supa = torch_br._empty_ut_only(size=(w_shape[0], w_shape[1]),
                                             dtype=param.data.dtype,
                                             is_numa=False,
                                             device=param.data.device,
                                             tensor_type="colmajor",
                                             axis=0,
                                             sbp="BB")
            w_supa.copy_(param.data.cpu())
            param.data = w_supa

        if name.find("lm_head.weight") != -1 and self.platform == 1:
            # lm_head is column-major, sharded on axis 0 ("SB").
            w_shape = param.data.shape
            w_supa = torch_br._empty_ut_only(size=(w_shape[0], w_shape[1]),
                                             dtype=param.data.dtype,
                                             is_numa=False,
                                             device=param.data.device,
                                             tensor_type="colmajor",
                                             axis=0,
                                             sbp="SB")
            w_supa.copy_(param.data.cpu())
            param.data = w_supa

        loaded_params.add(name)

    # inference rope sin_cos layout: re-materialize every rotary embedding
    # cache with the device layout (or just force layout inference via + 0).
    for _, module in self.named_modules():
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is not None:
            if self.platform == 1:
                if isinstance(rotary_emb, MRotaryEmbedding):
                    w_shape = rotary_emb.cos_sin_cache.shape
                    cos_sin_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.cos_sin_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.cos_sin_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    cos_sin_supa.copy_(rotary_emb.cos_sin_cache.cpu())
                    rotary_emb.cos_sin_cache = cos_sin_supa
                else:
                    w_shape = rotary_emb.sin_cache.shape
                    sin_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.sin_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.sin_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    sin_supa.copy_(rotary_emb.sin_cache.cpu())
                    rotary_emb.sin_cache = sin_supa

                    cos_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.cos_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.cos_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    cos_supa.copy_(rotary_emb.cos_cache.cpu())
                    rotary_emb.cos_cache = cos_supa
            else:
                if isinstance(rotary_emb, MRotaryEmbedding):
                    rotary_emb.cos_sin_cache = rotary_emb.cos_sin_cache + 0
                else:
                    rotary_emb.sin_cache = rotary_emb.sin_cache + 0
                    rotary_emb.cos_cache = rotary_emb.cos_cache + 0

    # Drop the CPU staging copies made above before returning.
    torch.supa.synchronize()
    gc.collect()
    torch.supa.empty_cache()

    return loaded_params
|
||||
|
||||
|
||||
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Replacement for ``Qwen2Model.forward`` running a 3-D activation layout.

    Hidden states are promoted to (1, seq, hidden) before the decoder layers
    and converted back to 2-D before being returned or forwarded to the next
    pipeline stage.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    # NOTE: supa wants 3d shape for llm
    if len(hidden_states.shape) == 2:
        hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )
    if not get_pp_group().is_last_rank:
        # Hand 2-D tensors to the next pipeline stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)

    # NOTE: convert back to 2D. squeeze() drops ALL size-1 dims, so a
    # single-token input collapses to 1-D; the dim()==1 guard restores it.
    hidden_states = hidden_states.squeeze()
    if hidden_states.dim() == 1:
        hidden_states = hidden_states.unsqueeze(0)

    return hidden_states
|
||||
|
||||
|
||||
def Qwen2Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Replacement for ``Qwen2Attention.forward``.

    Identical to upstream except that on br166-class devices
    (``VLLM_BR_DEVICE_SPC_NUM > 16``) the fused QKV output is split with the
    layout-aware ``torch_br.split_w_sbp_infer`` instead of ``Tensor.split``.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    sizes = [self.q_size, self.kv_size, self.kv_size]
    use_sbp_split = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    if use_sbp_split:
        q, k, v = torch_br.split_w_sbp_infer(qkv, sizes)
    else:
        q, k, v = qkv.split(sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    ctx = self.attn(q, k, v)
    out, _ = self.o_proj(ctx)
    return out
|
||||
|
||||
|
||||
# Install the Biren-specific replacements onto upstream vLLM's Qwen2
# implementation (monkey patching; takes effect when this module is
# imported).
vllm.model_executor.models.qwen2.Qwen2DecoderLayer.__init__ = Qwen2DecoderLayer__init__
logger.debug('[Patch] patch Qwen2 MLP with LlaMA_MLP_SiLU_3L')
Qwen2Model.load_weights = load_weights
Qwen2Model.forward = model_forward
Qwen2Attention.forward = Qwen2Attention_forward
|
||||
530
vllm_br/model_executor/models/qwen2_5_vl.py
Normal file
530
vllm_br/model_executor/models/qwen2_5_vl.py
Normal file
@@ -0,0 +1,530 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 The Qwen Team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_br
|
||||
from einops import rearrange
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm
|
||||
from vllm.distributed import parallel_state
|
||||
from vllm.distributed import utils as dist_utils
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionBlock,
|
||||
Qwen2_5_VisionMLP,
|
||||
Qwen2_5_VisionPatchMerger,
|
||||
Qwen2_5_VisionTransformer)
|
||||
from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision
|
||||
from vllm.model_executor.models.utils import cast_overflow_tensors
|
||||
from vllm.platforms import _Backend
|
||||
from vllm_br import envs
|
||||
from .br_utils import convBB, convSB
|
||||
|
||||
|
||||
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group.

    Each rank holds a shard of the last dimension; after the collective, the
    per-rank shards are split into ``hidden_size // tp_size`` chunks and
    re-ordered round-robin so chunk j of every rank precedes chunk j+1.

    Args:
        local_tensor: this rank's shard (same shape on all ranks).
        hidden_size: full (unsharded) size of the hidden dimension.
        tp_size: tensor-parallel world size.

    Returns:
        The gathered tensor with the last dimension re-interleaved.
    """
    import torch.distributed as dist

    # all_gather overwrites every output slot, so uninitialized buffers
    # suffice (zeros_like paid for a redundant fill).
    gathered_tensors = [torch.empty_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(gathered_tensors,
                    local_tensor,
                    group=parallel_state.get_tp_group().device_group)

    gathered_tensors_split = [
        torch.split(tensor, hidden_size // tp_size, -1)
        for tensor in gathered_tensors
    ]
    # Round-robin across ranks: chunk 0 of each rank, then chunk 1, ...
    ordered_tensors = [
        tensor for pair in zip(*gathered_tensors_split, strict=False)
        for tensor in pair
    ]
    return torch.cat(ordered_tensors, dim=-1)
|
||||
|
||||
|
||||
class Qwen2_5_VisionAttention_fit(nn.Module):
    """Biren replacement for the Qwen2.5-VL vision attention block.

    Mirrors the upstream ``Qwen2_5_VisionAttention`` interface but routes the
    attention computation through the ``torch_br`` fused SDPA kernel and
    applies device-specific layout conversions (``convBB``/``convSB``) on
    br166-class devices (``VLLM_BR_DEVICE_SPC_NUM > 16``).
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused [s, b, 3 * head * head_dim] projection into q, k, v,
        each shaped [s, b, head, head_dim] for this TP partition."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, width = qkv.shape
        qkv = qkv.reshape(-1, width)
        if self.tp_size > 1:
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
                                        self.tp_size)

        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=-1)

        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]

        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def transform_qkv_shape(self,
                            qkv_layer,
                            cur_qkv_shape_state,
                            obj_qkv_shape_state,
                            obj_shape=None):
        """Convert a q/k/v tensor between the layouts named by the state
        strings: "b_s_n_h", "b_n_s_h" and "bn_s_h" (batch*heads merged).

        Args:
            qkv_layer: tensor currently laid out as ``cur_qkv_shape_state``.
            cur_qkv_shape_state: current layout name.
            obj_qkv_shape_state: target layout name.
            obj_shape: original 4-D shape; required when un-merging "bn_s_h".

        Raises:
            AssertionError: if the (cur, obj) combination is unsupported.
        """
        if obj_qkv_shape_state == "bn_s_h":
            if cur_qkv_shape_state == "bn_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                # [b, sq, np or nkvp, hn] --> [b, np or nkvp, sq, hn]
                #   --> [b*(np or nkvp), sq, hn]
                qkv_layer = qkv_layer.permute(0, 2, 1, 3)
                # view 4d matrix to 3d matrix, TODO: use fused_split_view here
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3)).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3))
                return qkv_layer

        if obj_qkv_shape_state == "b_n_s_h":
            if cur_qkv_shape_state == "b_n_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer

        if obj_qkv_shape_state == "b_s_n_h":
            if cur_qkv_shape_state == "b_s_n_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer

        # BUG FIX: the exception was previously constructed but never raised,
        # so unsupported combinations silently returned None.
        raise AssertionError(
            f"unsupported shape transform, ori:{cur_qkv_shape_state} obj:{obj_qkv_shape_state}"
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)

        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # br166: convert layout and regroup the fused projection.
            x = convBB(x)
            seql = x.shape[-2]
            x = x.reshape(seql, 2, 3,
                          -1).permute(0, 2, 1,
                                      3).contiguous().reshape(1, seql, -1)

        # Move to seq-major [s, b, ...] when a batch-major singleton arrives.
        if x.shape[0] == 1:
            x = x.permute(1, 0, 2).contiguous()
        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)

        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
                   for x in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(
                q,
                rotary_pos_emb,
            )
            k = apply_rotary_pos_emb_vision(
                k,
                rotary_pos_emb,
            )

        # q, k, v: [b, s, n, h] -> reshape: [b, n, s, h] -> reshape: [b*n, s, h]
        q = q.permute(0, 2, 1, 3).contiguous()
        k = k.permute(0, 2, 1, 3).contiguous()
        v = v.permute(0, 2, 1, 3).contiguous()

        q = self.transform_qkv_shape(q, "b_n_s_h", "bn_s_h")
        k = self.transform_qkv_shape(k, "b_n_s_h", "bn_s_h")
        v = self.transform_qkv_shape(v, "b_n_s_h", "bn_s_h")
        # TODO(qingqi): workaround for a sueager kernel bug at these sequence
        # lengths; remove once the sueager op is fixed.
        if q.shape[1] == 8192 or q.shape[1] == 8424 or q.shape[1] == 8464:
            mask = mask.to(torch.bfloat16)
        context_layer, _ = torch_br.sueager_scaled_dot_product_attention_fwd(
            query=q,
            key=k,
            value=v,
            mask=mask,
            dropout_prob=0.0,
            is_causal=False,
            scale=1 / self.norm_factor,
            algorithm="FMHA",
        )
        # reshape attn out: [b*n, s, h] -> [s, b, h*n]
        context_layer = torch_br.supa_shape_transform_qkv(
            context_layer, 1, context_layer.shape[-2],
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head, False, False, None)
        if context_layer.shape[0] != 1:
            context_layer = context_layer.permute(1, 0, 2).contiguous()

        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            context_layer = convSB(context_layer, -1)
        output, _ = self.proj(context_layer)
        return output
|
||||
|
||||
|
||||
def vision_block_forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
    mask: torch.Tensor = None,
) -> torch.Tensor:
    """Pre-norm residual vision block: attention then MLP, each applied to a
    normed input and added back to the running activations.

    Inputs whose leading dim is not 1 are permuted to the (1, s, c)-style
    layout before the block runs.
    """
    if x.shape[0] != 1:
        x = x.permute(1, 0, 2).contiguous()

    attn_out = self.attn(self.norm1(x),
                         cu_seqlens=cu_seqlens,
                         rotary_pos_emb=rotary_pos_emb,
                         max_seqlen=max_seqlen,
                         seqlens=seqlens,
                         mask=mask)
    x = x + attn_out
    x = x + self.mlp(self.norm2(x))
    return x
|
||||
|
||||
|
||||
class Qwen2_5_VisionPatchEmbed_fit(nn.Module):
    """Patch-embedding replacement that projects flattened vision patches to
    ``hidden_size`` with a gather-output ``ColumnParallelLinear``."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        # One flattened patch = channels x temporal x spatial x spatial.
        in_features = (in_channels * temporal_patch_size * patch_size *
                       patch_size)
        self.proj = ColumnParallelLinear(in_features,
                                         hidden_size,
                                         bias=False,
                                         gather_output=True,
                                         quant_config=quant_config,
                                         prefix="")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project [num_patches, patch_dim] -> [num_patches, hidden_size]."""
        x = x.unsqueeze(0)
        num_patches = x.shape[-2]
        projected, _ = self.proj(x)
        return projected.view(num_patches, self.hidden_size)
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer)
def gen_normal_mask(self, cu_seqlens, grid_thw, device):
    """Build a dense [1, S, S] int32 attention mask from cumulative seqlens.

    Entries are 1 everywhere except the block-diagonal squares spanning each
    sequence (``cu_seqlens[i-1]:cu_seqlens[i]``), which are set to 0.
    ``grid_thw`` is accepted for interface compatibility but unused.

    NOTE: for mask-mock-pack, the mask is precomputed here and stored in
    PackedSeqParams.
    """
    total_len = max(cu_seqlens)
    attention_mask = torch.full([1, total_len, total_len],
                                1,
                                dtype=torch.int32,
                                device=device)
    # Zero out one square per sequence along the diagonal.
    for start, stop in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        attention_mask[..., start:stop, start:stop] = 0

    return attention_mask
|
||||
|
||||
|
||||
def vision_transformer_forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    """Replacement Qwen2.5-VL vision-tower forward.

    Embeds patches, builds per-image rotary embeddings and cumulative
    sequence lengths (full-attention and windowed variants), permutes
    tokens into window order, runs the transformer blocks with a
    precomputed block-diagonal mask, and finally merges/un-permutes the
    tokens back to their original order.

    Args:
        x: flattened patch tensor; assumed (seq_len, patch_features) —
            TODO confirm against patch_embed.
        grid_thw: per-image (temporal, height, width) patch grids.

    Returns:
        Merged vision features in the original token order.
    """
    # patchify
    seq_len, _ = x.size()
    rotary_pos_emb_list = []
    window_index_list: list = []
    # Seed with a zero so the later cumulative window offsets start at 0.
    cu_window_seqlens_list: list = [
        torch.tensor([0], dtype=torch.int32, device="cpu")
    ]
    cu_seqlens_list: list = []

    hidden_states = x.to(device=self.device, dtype=self.dtype)
    hidden_states = self.patch_embed(hidden_states)

    window_index_id = 0
    cu_window_seqlens_last = 0
    for t, h, w in grid_thw:
        t, h, w = int(t), int(h), int(w)
        # Grid size after spatial merging (the LLM-facing token grid).
        llm_h = h // self.spatial_merge_size
        llm_w = w // self.spatial_merge_size

        (
            rotary_pos_emb_thw,
            window_index_thw,
            cu_seqlens_window_thw,
            cu_seqlens_thw,
        ) = self.get_rope_by_thw(t, h, w)

        # Shift per-image window indices into the global token space.
        window_index_list.append(window_index_thw + window_index_id)
        window_index_id += (t * llm_h * llm_w)

        # Accumulate window seqlens so they stay cumulative across images.
        cu_seqlens_window_thw = (cu_seqlens_window_thw +
                                 cu_window_seqlens_last)
        cu_window_seqlens_last = cu_seqlens_window_thw[-1]
        cu_window_seqlens_list.append(cu_seqlens_window_thw)

        rotary_pos_emb_list.append(rotary_pos_emb_thw)

        cu_seqlens_list.append(cu_seqlens_thw)

    rotary_pos_emb = torch.cat(rotary_pos_emb_list)
    window_index = torch.cat(window_index_list)
    cu_window_seqlens = torch.cat(cu_window_seqlens_list)
    # Image boundaries may duplicate the same offset; collapse them.
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    cu_seqlens = torch.cat(cu_seqlens_list)
    # Per-image lengths -> cumulative offsets, with a leading 0.
    cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

    # transformers
    # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
    max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
    max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
        cu_window_seqlens)

    cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
    cu_window_seqlens = cu_window_seqlens.to(device=self.device,
                                             non_blocking=True)
    rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True)
    window_index = window_index.to(device=hidden_states.device,
                                   non_blocking=True)

    # Reorder tokens into window order, operating on groups of
    # spatial_merge_unit tokens at a time.
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit,
                                          self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)

    hidden_states = hidden_states.unsqueeze(1)

    # NOTE(review): the mask is built from the *full-attention* cu_seqlens
    # but is also passed to windowed layers below — confirm the window
    # layers are meant to share this mask.
    attention_mask = self.gen_normal_mask(cu_seqlens, grid_thw, x.device)

    for layer_num, blk in enumerate(self.blocks):
        # A few designated layers attend over the full sequence; the rest
        # use windowed attention metadata.
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
            max_seqlen_now = max_seqlen_full
            seqlens_now = seqlens_full
        else:
            cu_seqlens_now = cu_window_seqlens
            max_seqlen_now = max_seqlen_window
            seqlens_now = seqlens_window

        hidden_states = blk(hidden_states,
                            cu_seqlens=cu_seqlens_now,
                            rotary_pos_emb=rotary_pos_emb,
                            max_seqlen=max_seqlen_now,
                            seqlens=seqlens_now,
                            mask=attention_mask)

    # For Qwen2.5-VL-3B, float16 will overflow at last block
    # for long visual tokens sequences.
    if hidden_states.dtype == torch.float16:
        hidden_states = cast_overflow_tensors(hidden_states)

    # adapter
    hidden_states = self.merger(hidden_states).squeeze(0)
    # Undo the window-order permutation applied above.
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]
    return hidden_states
|
||||
|
||||
|
||||
def vision_transformer_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load vision-tower checkpoint weights, fusing separate q/k/v shards
    into the stacked ``attn.qkv`` parameter.

    Returns the set of parameter names that were actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Checkpoint stores q/k/v separately; route each shard into
            # the fused qkv parameter via its weight_loader.
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Non-stacked parameter: load directly.
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if name == 'patch_embed.proj.weight':
                # The replacement patch embed uses a Linear, so flatten the
                # conv-style checkpoint weight to 2D.
                loaded_weight = loaded_weight.reshape(loaded_weight.shape[0],
                                                      -1).contiguous()
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
def Qwen2_5_VisionPatchMerger_forward_fit(self,
                                          x: torch.Tensor) -> torch.Tensor:
    """Normalize the tokens, reshape to (1, N, hidden_size), and project
    them through the merger MLP."""
    normed = self.ln_q(x)
    batched = normed.view(-1, self.hidden_size).unsqueeze(0)
    return self.mlp(batched)
|
||||
|
||||
|
||||
def Qwen2_5_VisionMLP__init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False):
    """Replacement constructor for the Qwen2.5-VL vision MLP.

    Builds the gate/up column-parallel projections and the down
    row-parallel projection used by ``Qwen2_5_VisionMLP_forward``.

    NOTE(review): the ``act_fn`` argument is accepted but ignored — the
    activation is hard-wired to ``F.silu`` at the end of this function.
    Confirm this is intentional (e.g. required by a fused SiLU kernel).
    NOTE(review): ``use_data_parallel`` only disables TP on ``down_proj``,
    not on ``gate_proj``/``up_proj`` — verify that asymmetry is deliberate.
    """
    super(Qwen2_5_VisionMLP, self).__init__()
    self.gate_proj = ColumnParallelLinear(in_features,
                                          hidden_features,
                                          bias=bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.gate_proj")
    self.up_proj = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        bias=bias,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.up_proj")

    self.down_proj = RowParallelLinear(hidden_features,
                                       in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj",
                                       disable_tp=use_data_parallel)
    # Hard-wired SiLU; see NOTE(review) above about the ignored act_fn.
    self.act_fn = F.silu
|
||||
|
||||
|
||||
def Qwen2_5_VisionMLP_forward(self, x: torch.Tensor):
    """Gated MLP: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``.

    Each parallel-linear layer returns an ``(output, bias)`` pair; only
    the output tensor is used here.
    """
    gate, _ = self.gate_proj(x)
    up, _ = self.up_proj(x)
    down, _ = self.down_proj(self.act_fn(gate) * up)
    return down
|
||||
|
||||
|
||||
# Install the Biren "_fit" replacement classes into the upstream vLLM
# qwen2_5_vl module so any code constructing these types picks them up.
# (Qwen2_5_VisionAttention_fit and vision_block_forward are defined earlier
# in this file.)
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionAttention = Qwen2_5_VisionAttention_fit
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionPatchEmbed = Qwen2_5_VisionPatchEmbed_fit
# Method-level monkey patches on the already-imported upstream classes.
Qwen2_5_VisionBlock.forward = vision_block_forward
Qwen2_5_VisionTransformer.forward = vision_transformer_forward
Qwen2_5_VisionTransformer.load_weights = vision_transformer_load_weights
Qwen2_5_VisionPatchMerger.forward = Qwen2_5_VisionPatchMerger_forward_fit
Qwen2_5_VisionMLP.__init__ = Qwen2_5_VisionMLP__init__
Qwen2_5_VisionMLP.forward = Qwen2_5_VisionMLP_forward
|
||||
47
vllm_br/model_executor/models/qwen2_vl.py
Normal file
47
vllm_br/model_executor/models/qwen2_vl.py
Normal file
@@ -0,0 +1,47 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from collections.abc import Mapping
|
||||
|
||||
from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLProcessingInfo)
|
||||
from vllm.multimodal.parse import ImageSize
|
||||
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
    """Return a fixed 240x240 probe size for memory profiling.

    Used by both Qwen2_VL and Qwen2_5_VL; patched here in qwen2_vl.py.
    """
    probe_edge = 240
    size, _ = self._get_vision_info(
        image_width=probe_edge,
        image_height=probe_edge,
        image_processor=None,
    )
    return size
|
||||
|
||||
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Return dummy prompt text containing exactly one image token.

    ``mm_counts`` is ignored: this patched builder always profiles with a
    single image and zero videos.
    """
    hf_processor = self.info.get_hf_processor()
    image_token: str = hf_processor.image_token
    video_token: str = hf_processor.video_token

    # One image token; the zero-repeat keeps the video token out entirely.
    return image_token * 1 + video_token * 0
|
||||
|
||||
|
||||
# Patch the upstream processing info / dummy-input builder so both
# Qwen2-VL and Qwen2.5-VL use the lightweight profiling stubs above.
Qwen2VLProcessingInfo.get_image_size_with_most_features = (
    get_image_size_with_most_features)
Qwen2VLDummyInputsBuilder.get_dummy_text = get_dummy_text
|
||||
254
vllm_br/model_executor/models/qwen3.py
Normal file
254
vllm_br/model_executor/models/qwen3.py
Normal file
@@ -0,0 +1,254 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from transformers import Qwen3Config
|
||||
|
||||
import vllm.model_executor.models.qwen3
|
||||
from vllm.attention import AttentionType
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.qwen3 import (Qwen3Attention,
|
||||
Qwen3DecoderLayer, Qwen3Model)
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm_br.v1.attention.backends.attention_v1 import (
|
||||
SUPAFlashAttentionMetadata)
|
||||
from .qwen2 import model_forward
|
||||
from .supa_module import MergedGateUpMLPSiluL2
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.qwen3.Qwen3Attention)
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Biren-fused replacement for ``Qwen3Attention.forward``.

    Dispatches to fused torch_br kernels based on sequence length:

    * ``seq_len <= 512``: a weight-level fused kernel
      (``br_qwen3[_vl]_prefix_attn_infer``) that takes the raw (possibly
      quantized) qkv weights and performs projection + q/k RMSNorm + RoPE
      + KV store in one call.
    * otherwise: the qkv projection is run normally, then a fused
      RMSNorm + RoPE + KV-store kernel is applied to the result.

    The MRoPE variants are used when ``rotary_emb`` is an
    ``MRotaryEmbedding`` (multimodal position encoding). When
    ``attn_metadata`` or the layer's KV cache is unavailable (dummy runs /
    profiling), the input is returned unchanged.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        ## for dummy run
        return hidden_states

    seq_len = hidden_states.shape[-2]
    # Threshold separating the two kernel regimes; presumably tuned for
    # decode/short-prefill vs long-prefill — TODO confirm.
    decode_seql = 512

    if isinstance(attn_metadata, dict):
        # Per-layer metadata dict (v1 engine); pick this layer's entry.
        attn_metadata = attn_metadata[self.attn.layer_name]
    kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
    if kv_cache is not None:
        if seq_len <= decode_seql:
            # Extract the raw weight (and scales, if quantized) so the
            # fused kernel can do the qkv projection itself.
            if hasattr(self.qkv_proj, "qweight"):
                qkv_weight = self.qkv_proj.qweight.data
                qkv_scales = self.qkv_proj.scales.data
            elif hasattr(self.qkv_proj, "weight_packed"):
                qkv_weight = self.qkv_proj.weight_packed.data
                qkv_scales = self.qkv_proj.weight_scale.data
            else:
                qkv_weight = self.qkv_proj.weight
                qkv_scales = None
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_qwen3_vl_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    self.rotary_emb.mrope_section[1],
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
            else:
                q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.sin_cache,
                    self.rotary_emb.cos_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
        else:
            # Long sequences: run the projection normally, then fuse
            # norm + rope + kv store on the projected qkv.
            qkv, _ = self.qkv_proj(hidden_states)
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_fused_rms_mrope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache, kv_cache, positions,
                    attn_metadata.slot_mapping, attn_metadata.block_table,
                    attn_metadata.query_start_loc, attn_metadata.context_lens,
                    self.rotary_emb.mrope_section[1])
            else:
                q, k, v = torch_br.br_fused_rms_rope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
                    kv_cache, positions, attn_metadata.slot_mapping,
                    attn_metadata.block_table, attn_metadata.query_start_loc,
                    attn_metadata.context_lens)
        if hasattr(attn_metadata, 'do_cache'):
            # KV was already stored by the fused kernel above; tell the
            # attention backend not to store it again.
            attn_metadata.do_cache = False
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
    else:
        # No KV cache bound yet (e.g. profiling run): pass through.
        return hidden_states
|
||||
|
||||
|
||||
def Qwen3DecoderLayer__init__(
    self,
    config: Qwen3Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Replacement constructor for ``Qwen3DecoderLayer``.

    Identical wiring to upstream except the MLP is swapped for the fused
    ``MergedGateUpMLPSiluL2`` implementation.
    """
    super(Qwen3DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    # Requires transformers > 4.32.0
    theta = getattr(config, "rope_theta", 1000000)
    scaling = getattr(config, "rope_scaling", None)

    # By default, Qwen3 uses causal attention as it is a decoder-only model.
    # You can override the HF config with `is_causal=False` to enable
    # bidirectional attention, which is used in some embedding models
    # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
    attn_type = (AttentionType.DECODER if getattr(config, "is_causal", True)
                 else AttentionType.ENCODER_ONLY)

    self.self_attn = Qwen3Attention(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        max_position=config.max_position_embeddings,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=theta,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=getattr(config, 'attention_bias', False),
        head_dim=getattr(config, 'head_dim', None),
        cache_config=cache_config,
        quant_config=quant_config,
        rope_scaling=scaling,
        prefix=f"{prefix}.self_attn",
        attn_type=attn_type,
    )
    # Fused gate+up MLP (the point of this patch).
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size,
                                   eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
|
||||
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load Qwen3 checkpoint weights into the model's fused parameters.

    Handles (in order): rotary inv_freq skipping, KV-cache quantization
    scales, stacked q/k/v and gate/up shards, and finally plain
    parameters. Norm weights are upcast to float32 after loading
    (Biren-specific; the fused kernels consume fp32 norm weights —
    TODO confirm).

    Returns the set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # inv_freq is recomputed at runtime; never load it.
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Scales may be stored as a 1-element tensor; take the scalar.
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Plain (non-stacked) parameter path.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            # Biren addition: keep all RMSNorm weights in float32.
            if name.find("norm.weight") != -1:
                param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
# Install the Biren replacements on the upstream Qwen3 classes.
# (model_forward is imported from .qwen2 above; the attention forward was
# patched via @patch_to earlier in this module.)
vllm.model_executor.models.qwen3.Qwen3DecoderLayer.__init__ = Qwen3DecoderLayer__init__
logger.debug('[Patch] patch Qwen3 MLP with MergedGateUpMLPSiluL2')
Qwen3Model.load_weights = load_weights
Qwen3Model.forward = model_forward
|
||||
300
vllm_br/model_executor/models/qwen3_moe.py
Normal file
300
vllm_br/model_executor/models/qwen3_moe.py
Normal file
@@ -0,0 +1,300 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm
|
||||
from vllm.distributed import get_pp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeModel
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_br.v1.attention.backends.attention_v1 import (
|
||||
SUPAFlashAttentionMetadata)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Replacement MoE block forward that tolerates a leading batch dim.

    Squeezes a 3D (1, seq, hidden) input to 2D for the fused experts,
    then views the result back to the original shape. Routing logits are
    passed as the raw gate weight so the fused kernel can compute them.
    """
    # NOTE: hidden_states can have either 1D or 2D shape.
    # NOTE(review): the comment above says 1D/2D but the code squeezes a
    # 3D (1, seq, hidden) input — confirm which shapes callers pass.
    orig_shape = hidden_states.shape
    if len(hidden_states.shape) == 3:
        hidden_states = hidden_states.squeeze(0)
    # router_logits is a (gate_weight, None, None) tuple: the fused expert
    # kernel computes the routing logits itself from the gate weight.
    final_hidden_states = self.experts(hidden_states=hidden_states,
                                       router_logits=(self.gate.weight, None,
                                                      None))
    if hasattr(final_hidden_states, 'all_reduced'):
        # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
        delattr(final_hidden_states, 'all_reduced')
    elif self.tp_size > 1:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)
    return final_hidden_states.view(orig_shape)
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.qwen3_moe.Qwen3MoeAttention)
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Biren-fused replacement for ``Qwen3MoeAttention.forward``.

    Mirrors the patched Qwen3Attention.forward in qwen3.py: short
    sequences (<= 512) use a weight-level fused kernel that performs the
    qkv projection + q/k RMSNorm + RoPE + KV store in one call; longer
    sequences project qkv normally and then apply the fused
    norm+rope+kvstore kernel. MRoPE kernels are selected when
    ``rotary_emb`` is an ``MRotaryEmbedding``. Dummy runs (no metadata or
    no KV cache) pass the input through unchanged.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        ## for dummy run
        return hidden_states

    seq_len = hidden_states.shape[-2]
    # Kernel-regime threshold; see the qwen3.py counterpart.
    decode_seql = 512

    if isinstance(attn_metadata, dict):
        # Per-layer metadata dict (v1 engine); pick this layer's entry.
        attn_metadata = attn_metadata[self.attn.layer_name]
    kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
    if kv_cache is not None:
        if seq_len <= decode_seql:
            # Raw weights (and quantization scales, if any) for the fused
            # projection kernel.
            if hasattr(self.qkv_proj, "qweight"):
                qkv_weight = self.qkv_proj.qweight.data
                qkv_scales = self.qkv_proj.scales.data
            elif hasattr(self.qkv_proj, "weight_packed"):
                qkv_weight = self.qkv_proj.weight_packed.data
                qkv_scales = self.qkv_proj.weight_scale.data
            else:
                qkv_weight = self.qkv_proj.weight
                qkv_scales = None
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_qwen3_vl_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    self.rotary_emb.mrope_section[1],
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
            else:
                q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.sin_cache,
                    self.rotary_emb.cos_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
        else:
            # Long sequences: normal projection, then fused post-processing.
            qkv, _ = self.qkv_proj(hidden_states)
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_fused_rms_mrope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache, kv_cache, positions,
                    attn_metadata.slot_mapping, attn_metadata.block_table,
                    attn_metadata.query_start_loc, attn_metadata.context_lens,
                    self.rotary_emb.mrope_section[1])
            else:
                q, k, v = torch_br.br_fused_rms_rope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
                    kv_cache, positions, attn_metadata.slot_mapping,
                    attn_metadata.block_table, attn_metadata.query_start_loc,
                    attn_metadata.context_lens)
        if hasattr(attn_metadata, 'do_cache'):
            # KV already stored by the fused kernel; prevent double-store.
            attn_metadata.do_cache = False
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
    else:
        # No KV cache bound yet (profiling run): pass through.
        return hidden_states
|
||||
|
||||
|
||||
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Pipeline-parallel model forward running the decoder stack on a
    3D ``(1, seq, hidden)`` view of the hidden states.

    First PP rank embeds ``input_ids`` (or takes ``inputs_embeds``);
    other ranks consume the previous stage's ``intermediate_tensors``.
    Non-last ranks hand off 2D activations; the last rank applies the
    final norm and returns 2D hidden states.

    Fix vs. original: the final "back to 2D" step used a bare
    ``squeeze()``, which drops *every* size-1 dimension — a degenerate
    sequence/hidden dim could collapse too far (the old ``dim() == 1``
    repair only covered one such case). ``squeeze(0)`` removes exactly
    the batch dim added below; results are unchanged for normal shapes.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        # Non-first ranks receive activations from the previous stage.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    # The patched layers expect a leading batch dim of 1.
    if len(hidden_states.shape) == 2:
        hidden_states = hidden_states.unsqueeze(0)
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not get_pp_group().is_last_rank:
        # Strip the batch dim before sending to the next PP stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)

    # NOTE: convert back to 2D — drop only the batch dim added above.
    hidden_states = hidden_states.squeeze(0)
    if hidden_states.dim() == 1:
        # Defensive: restore 2D if the input was already 2D with a
        # length-1 leading dim (so squeeze(0) removed a real dim).
        hidden_states = hidden_states.unsqueeze(0)

    return hidden_states
|
||||
|
||||
|
||||
# Install the 3D-aware, PP-capable forward on the stock Qwen3 MoE model.
Qwen3MoeModel.forward = model_forward
|
||||
|
||||
|
||||
def Qwen3MoeModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load Qwen3-MoE checkpoint weights.

    Resolution order per checkpoint tensor: stacked qkv / gate_up shards
    (non-expert layers), then per-expert parameters via
    ``expert_params_mapping``, then plain parameters (with FP8 kv-scale
    name remapping). Norm weights are upcast to float32 after loading
    (Biren-specific; the fused kernels consume fp32 norm weights —
    TODO confirm).

    Returns the set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.num_experts)

    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if "mlp.experts" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                # Expert loaders additionally need the original name and
                # the (expert_id, shard_id) routing information.
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                break
            else:
                # Plain (non-stacked, non-expert) parameter path.
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                            name,
                            remapped_kv_scale_name,
                        )
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # Biren addition: keep all RMSNorm weights in float32.
                if name.find("norm.weight") != -1:
                    param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params
|
||||
|
||||
|
||||
# Install the expert-aware loader (with the fp32 norm upcast) on the model.
Qwen3MoeModel.load_weights = Qwen3MoeModel_load_weights
|
||||
207
vllm_br/model_executor/models/qwen3_vl.py
Normal file
207
vllm_br/model_executor/models/qwen3_vl.py
Normal file
@@ -0,0 +1,207 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 The Qwen Team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock,
|
||||
Qwen3_VisionPatchEmbed,
|
||||
Qwen3_VisionTransformer,
|
||||
Qwen3LLMModel)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm_br import envs
|
||||
from .br_utils import convBB
|
||||
|
||||
|
||||
def Qwen3_VisionPatchEmbed__init__(
    self,
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    in_channels: int = 3,
    hidden_size: int = 1152,
) -> None:
    """Patched constructor for the vision patch embedding.

    Uses a ``ReplicatedLinear`` over the flattened patch vector
    (``in_channels * temporal_patch_size * patch_size**2``) as the
    projection to ``hidden_size``.
    """
    super(Qwen3_VisionPatchEmbed, self).__init__()
    self.patch_size = patch_size
    self.temporal_patch_size = temporal_patch_size
    self.hidden_size = hidden_size

    # Flattened size of a single (temporal) patch: C * T * P * P.
    patch_dim = in_channels * temporal_patch_size * patch_size * patch_size
    self.proj = ReplicatedLinear(patch_dim,
                                 hidden_size,
                                 bias=True,
                                 prefix="")


# Monkey-patch the upstream vLLM class in place.
Qwen3_VisionPatchEmbed.__init__ = Qwen3_VisionPatchEmbed__init__
|
||||
|
||||
|
||||
def Qwen3_VisionPatchEmbed_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Project flattened patches to the hidden size.

    Adds a leading batch dim for the linear layer, keeps only the output
    of the ``ReplicatedLinear`` call (it returns an (output, extra)
    pair), and reshapes back to ``(num_patches, hidden_size)``.
    """
    num_patches = x.shape[-2]
    projected, _ = self.proj(x.unsqueeze(0))
    out = projected.view(num_patches, self.hidden_size)

    # Device-specific layout conversion for large-SPC parts.
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        out = convBB(out)

    return out


# Monkey-patch the upstream vLLM class in place.
Qwen3_VisionPatchEmbed.forward = Qwen3_VisionPatchEmbed_forward
|
||||
|
||||
|
||||
def Qwen3_VisionBlock_forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    """Pre-norm vision transformer block: attention then MLP, each with a
    residual connection."""
    # Put the sequence dimension first when a batch-leading layout comes
    # in.  NOTE(review): assumes rank-3 input — confirm with callers.
    if x.shape[0] != 1:
        x = x.permute(1, 0, 2).contiguous()

    attn_out = self.attn(self.norm1(x),
                         cu_seqlens=cu_seqlens,
                         rotary_pos_emb=rotary_pos_emb,
                         max_seqlen=max_seqlen,
                         seqlens=seqlens)
    x = x + attn_out

    mlp_out = self.mlp(self.norm2(x))
    return x + mlp_out


# Monkey-patch the upstream vLLM class in place.
Qwen3_VisionBlock.forward = Qwen3_VisionBlock_forward
|
||||
|
||||
|
||||
def Qwen3_VisionTransformer_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load vision-tower weights.

    Fuses separate q/k/v checkpoint tensors into the stacked ``attn.qkv``
    parameter, flattens the patch-embed projection weight to match the
    linear (non-conv) layout used by this backend, and casts norm weights
    to float32.  Returns the set of parameter names actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()

    for name, loaded_weight in weights:
        # First try the stacked q/k/v mapping; `break` on a hit so the
        # for-else fallback below handles everything else.
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Patch embed was converted from a conv-style weight to a
            # linear weight: flatten all trailing dims into one.
            if name == 'patch_embed.proj.weight':
                loaded_weight = loaded_weight.reshape(loaded_weight.shape[0],
                                                      -1).contiguous()
            weight_loader(param, loaded_weight)
            # Keep norm weights in float32 for numerical stability on
            # this backend.
            if name.find("norm.weight") != -1:
                param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params


# Monkey-patch the upstream vLLM class in place.
Qwen3_VisionTransformer.load_weights = Qwen3_VisionTransformer_load_weights
|
||||
|
||||
|
||||
def Qwen3LLMModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    # args for deepstack
    deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this rank's slice of decoder layers with a leading batch dim.

    On the first PP rank the input embeddings are computed (or taken from
    ``inputs_embeds``); other ranks consume ``intermediate_tensors`` from
    the previous stage.  Deepstack embeddings, when provided, are added to
    the hidden states of the first ``len(deepstack_input_embeds)`` layers.
    Returns ``IntermediateTensors`` on non-last ranks, otherwise the final
    normed hidden states without the batch dim.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    # BR attention kernels expect a leading batch dimension.
    hidden_states = hidden_states.unsqueeze(0)
    residual = residual.unsqueeze(0) if residual is not None else None

    for layer_idx, layer in enumerate(
            self.layers[self.start_layer:self.end_layer]):
        # Convert the local slice index to the global layer index.
        layer_idx = layer_idx + self.start_layer
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )

        # Inject deepstack visual embeddings into the early layers.
        if deepstack_input_embeds is not None and \
                layer_idx in range(0, len(deepstack_input_embeds)):
            hidden_states = hidden_states + deepstack_input_embeds[
                f"deepstack_input_embeds_{layer_idx}"].to(
                    hidden_states.device).unsqueeze(0)

    if not get_pp_group().is_last_rank:
        # NOTE(review): hidden_states already carries the batch dim added
        # above, and the receiving rank unsqueezes again on entry — the
        # extra unsqueeze(0) here looks like it grows the rank across PP
        # stages; confirm against the PP send/recv path.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.unsqueeze(0),
            "residual":
            residual.unsqueeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # Drop the batch dim again for the caller.
    return hidden_states.squeeze(0)


# Monkey-patch the upstream vLLM class in place.
Qwen3LLMModel.forward = Qwen3LLMModel_forward
|
||||
258
vllm_br/model_executor/models/qwen3_vl_moe.py
Normal file
258
vllm_br/model_executor/models/qwen3_vl_moe.py
Normal file
@@ -0,0 +1,258 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 The Qwen Team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-VL MOE model compatible with HuggingFace weights."""
|
||||
import typing
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_pp_group
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader, maybe_remap_kv_scale_name)
|
||||
from vllm.model_executor.models.qwen3_vl_moe import Qwen3MoeLLMModel
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
|
||||
def Qwen3MoeLLMModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Run this rank's slice of MoE decoder layers with a leading batch dim.

    Mirrors the dense ``Qwen3LLMModel`` forward: first PP rank computes or
    receives embeddings, other ranks consume ``intermediate_tensors``;
    deepstack embeddings are added to the first
    ``len(deepstack_input_embeds)`` layers.  Returns
    ``IntermediateTensors`` on non-last ranks, otherwise the final normed
    hidden states without the batch dim.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]

    # BR attention kernels expect a leading batch dimension.
    hidden_states = hidden_states.unsqueeze(0)
    residual = residual.unsqueeze(0) if residual is not None else None

    for layer_idx, layer in enumerate(
            self.layers[self.start_layer:self.end_layer]):
        # Convert the local slice index to the global layer index.
        layer_idx = layer_idx + self.start_layer

        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )

        # Inject deepstack visual embeddings into the early layers.
        if deepstack_input_embeds is not None and \
                layer_idx in range(0, len(deepstack_input_embeds)):
            hidden_states = hidden_states + deepstack_input_embeds[
                f"deepstack_input_embeds_{layer_idx}"].to(
                    hidden_states.device).unsqueeze(0)

    if not get_pp_group().is_last_rank:
        # NOTE(review): hidden_states already carries the batch dim added
        # above, and the receiving rank unsqueezes again on entry — the
        # extra unsqueeze(0) here looks like it grows the rank across PP
        # stages; confirm against the PP send/recv path.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.unsqueeze(0),
            "residual":
            residual.unsqueeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # Drop the batch dim again for the caller.
    return hidden_states.squeeze(0)


# Monkey-patch the upstream vLLM class in place.
Qwen3MoeLLMModel.forward = Qwen3MoeLLMModel_forward
|
||||
|
||||
|
||||
def Qwen3MoeLLMModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load Qwen3-VL-MoE language-model weights.

    Handles three checkpoint layouts:
      1. stacked dense projections (q/k/v -> qkv_proj, gate/up ->
         gate_up_proj),
      2. per-expert or fused expert tensors (``experts.gate_up_proj`` /
         ``experts.down_proj``), and
      3. everything else via the default loader, with FP8 kv-scale
         remapping, a patch-embed flatten, and float32 norm weights.
    Returns the set of parameter names actually loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Skip loading extra parameters for GPTQ/modelopt models.
    ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", ".v_scale",
                       "_v_scale", ".weight_scale", "_weight_scale",
                       ".input_scale", "_input_scale")
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    expert_params_mapping = self.get_expert_mapping()
    is_fused_expert = False
    fused_expert_params_mapping = [
        ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
        ("experts.w2_weight", "experts.down_proj", 0, "w2"),
    ]
    num_experts = self.config.num_experts
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Detect fused-expert checkpoints and switch the expert
            # mapping accordingly (sticky for the rest of the load).
            if ("experts.gate_up_proj" in name or "experts.down_proj" in name):
                is_fused_expert = True
                expert_params_mapping = fused_expert_params_mapping

            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if "mlp.experts" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra parameters for GPTQ/modelopt models.
            if name.endswith(ignore_suffixes) and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            if name.endswith("scale"):
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Sharded params carry a loader that also needs the shard id.
            if weight_loader == default_weight_loader:
                weight_loader(param, loaded_weight)
            else:
                weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True
                name_mapped = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name_mapped, self):
                    continue
                if is_fused_expert:
                    loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
                    if "experts.gate_up_proj" in name:
                        # Fused gate+up: split into w1 (gate) and w3 (up).
                        loaded_weight = loaded_weight.chunk(2, dim=-2)
                        success_w1 = self.load_fused_expert_weights(
                            name_mapped, params_dict, loaded_weight[0], "w1",
                            num_experts)
                        success_w3 = self.load_fused_expert_weights(
                            name_mapped, params_dict, loaded_weight[1], "w3",
                            num_experts)
                        success = success_w1 and success_w3
                    else:
                        # down_proj
                        success = self.load_fused_expert_weights(
                            name_mapped, params_dict, loaded_weight, shard_id,
                            num_experts)
                else:
                    # Skip loading extra parameters for GPTQ/modelopt models
                    if name_mapped.endswith(
                            ignore_suffixes
                    ) and name_mapped not in params_dict:
                        continue
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or
                    # not here since otherwise we may skip experts with
                    # other available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue
                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        # No matching model parameter: kv-scale is not
                        # loaded.
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                # Patch embed was converted from a conv-style weight to a
                # linear weight: flatten all trailing dims into one.
                if name == 'patch_embed.proj.weight':
                    loaded_weight = loaded_weight.reshape(
                        loaded_weight.shape[0], -1).contiguous()
                weight_loader(param, loaded_weight)
                # Keep norm weights in float32 for numerical stability on
                # this backend.
                if name.find("norm.weight") != -1:
                    param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params


# Monkey-patch the upstream vLLM class in place.
Qwen3MoeLLMModel.load_weights = Qwen3MoeLLMModel_load_weights
|
||||
27
vllm_br/model_executor/models/registry.py
Normal file
27
vllm_br/model_executor/models/registry.py
Normal file
@@ -0,0 +1,27 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from vllm import ModelRegistry
from vllm.model_executor.models.registry import _MULTIMODAL_MODELS

# Redirect the GLM-4V architecture entry in vLLM's multimodal-model table.
# NOTE(review): the (module, class) tuple follows vLLM's registry
# convention — confirm "glm4_1v" resolves to the intended override module.
_MULTIMODAL_MODELS["Glm4vForConditionalGeneration"] = (
    "glm4_1v", "Glm4vForConditionalGeneration")

# Register the architecture with the fully-qualified path of the vllm_br
# implementation so model resolution picks up the local override.
ModelRegistry.register_model(
    "Glm4vForConditionalGeneration",
    "vllm_br.model_executor.models.glm4_1v:Glm4vForConditionalGeneration")
|
||||
89
vllm_br/model_executor/models/roberta.py
Normal file
89
vllm_br/model_executor/models/roberta.py
Normal file
@@ -0,0 +1,89 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
# Adapted from transformers
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
import vllm
|
||||
from vllm.model_executor.models.roberta import (
|
||||
create_position_ids_from_input_ids)
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.roberta.RobertaClassificationHead)
def forward(self, features, **kwargs):
    """Classification head over the first (<s>, equiv. to [CLS]) token."""
    # Add a batch dimension for the linear layers, then drop it again.
    cls_token = features[0, :].unsqueeze(0)
    hidden = torch.tanh(self.dense(cls_token))
    logits = self.out_proj(hidden)
    return logits.squeeze(0)
|
||||
|
||||
|
||||
@patch_to(vllm.model_executor.models.roberta.RobertaEmbedding)
def forward(
    self,
    input_ids: torch.Tensor,
    seq_lens: torch.Tensor,
    position_ids: torch.Tensor,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute RoBERTa embeddings for a packed batch of sequences.

    Rebuilds position ids per sequence (RoBERTa positions start at
    ``padding_idx + 1``), sums word/position/token-type embeddings, applies
    LayerNorm, and returns the result with a leading batch dimension.
    """
    input_ids = input_ids.squeeze(0)  # notice here input_ids is 2-dim tensor
    input_shape = input_ids.size()
    inputs_embeds = self.word_embeddings(input_ids)

    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    # Split the flat packed ids back into per-sequence chunks.
    pos_list = []
    token_list = []
    offset = 0
    for seq_len in seq_lens:
        pos_list.append(position_ids[offset:offset + seq_len])
        token_list.append(input_ids[offset:offset + seq_len])
        offset += seq_len

    new_pos_list = []
    for positions, tokens in zip(pos_list, token_list, strict=False):
        # Verify assumption that incoming position are
        # always a sequence from 0 to N.
        expected_pos = torch.arange(positions.size()[0],
                                    dtype=torch.long,
                                    device=inputs_embeds.device)
        assert torch.equal(positions, expected_pos)
        new_pos_list.append(
            create_position_ids_from_input_ids(tokens, self.padding_idx))
    position_ids = torch.cat(new_pos_list)

    # Position embeddings.
    position_embeddings = self.position_embeddings(position_ids)
    if token_type_ids is None:
        token_type_ids = torch.zeros(input_shape,
                                     dtype=torch.long,
                                     device=inputs_embeds.device)

    token_type_embeddings = self.token_type_embeddings(token_type_ids)
    embeddings = inputs_embeds + token_type_embeddings + position_embeddings
    embeddings = self.LayerNorm(embeddings)
    return embeddings.unsqueeze(0)  # add batch dimension for BR attention
|
||||
25
vllm_br/model_executor/models/supa_module/__init__.py
Normal file
25
vllm_br/model_executor/models/supa_module/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from .attention import AttentionSplit
|
||||
from .mla import SupaMLAModules, SupaMultiHeadLatentAttention
|
||||
from .mlp import LlamaMlpSiluL3, MergedGateUpMLPSiluL2
|
||||
from .moe import DeepseekV2MoE
|
||||
|
||||
__all__ = [
|
||||
'LlamaMlpSiluL3', 'AttentionSplit', 'MergedGateUpMLPSiluL2',
|
||||
'DeepseekV2MoE', 'SupaMLAModules', 'SupaMultiHeadLatentAttention'
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
206
vllm_br/model_executor/models/supa_module/attention.py
Normal file
206
vllm_br/model_executor/models/supa_module/attention.py
Normal file
@@ -0,0 +1,206 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from torch import nn
|
||||
from torch_br.supa.profiler_kineto import record_function
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
|
||||
get_rope)
|
||||
from vllm.model_executor.models.utils import extract_layer_index
|
||||
|
||||
|
||||
class AttentionSplit(nn.Module):
|
||||
|
||||
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: int = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        rope_scaling: Optional[Tuple] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        bias: bool = False,
    ) -> None:
        """Attention module with separate q/k/v projections (not fused).

        Builds three ``ColumnParallelLinear`` projections plus an output
        ``RowParallelLinear``, a RoPE embedding, and the backend
        ``Attention`` kernel.  Head counts are split across the tensor
        parallel group; KV heads are replicated when fewer than tp_size.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        # Only quantize q/k/v projections when the config says they are
        # quantized; o_proj always uses the full quant_config below.
        qconfig = None
        if quant_config is not None and quant_config.qkv_quantized:
            qconfig = quant_config

        self.q_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.q_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.q_proj")
        self.k_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.kv_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.k_proj")
        self.v_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.kv_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.v_proj")

        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        # Extra kwargs are only passed when dual-chunk attention is
        # configured.
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            } if dual_chunk_attention_config else {})
|
||||
|
||||
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Attention forward pass with a fused-decode fast path.

    Two paths are dispatched on weight layout and sequence length:

    * Fused decode path: a single ``torch_br.supa_qkv_rope_decode_infer``
      kernel performs the q/k/v projections, rotary embedding, and the
      KV-cache write. Taken when the projection weights are 3-D
      ("numa" layout), the rope is not MRoPE, and ``seq_len`` is small.
    * Generic path: separate q/k/v projections, ``rotary_emb``, and the
      attention backend (used for "uma" weights and qwen-vl/MRoPE).

    Args:
        positions: token position ids for rotary embedding.
        hidden_states: input activations; attention is applied over the
            second-to-last dimension (``shape[-2]`` is read as seq_len).

    Returns:
        The o_proj output, or ``hidden_states`` unchanged on the dummy-run
        / no-KV-cache paths.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if attn_metadata is None:
        ## for dummy run (profiling): no metadata, skip attention entirely
        return hidden_states

    seq_len = hidden_states.shape[-2]
    # Max sequence length admitted to the fused decode kernel.
    # NOTE(review): presumably the kernel's supported range — confirm
    # against torch_br.supa_qkv_rope_decode_infer.
    decode_seql = 512

    # numa weight and not use mrope (qwen-vl): 3-D (q)weight tensors mark
    # the fused-kernel layout; MRotaryEmbedding models must take the
    # generic path below.
    if ((hasattr(self.q_proj, "qweight")
         and len(self.q_proj.qweight.shape) == 3) or
        (hasattr(self.q_proj, "weight")
         and len(self.q_proj.weight.shape) == 3)) and not isinstance(
            self.rotary_emb, MRotaryEmbedding) and seq_len <= decode_seql:
        if isinstance(attn_metadata, dict):
            # Per-layer metadata mapping: pick this layer's entry.
            attn_metadata = attn_metadata[self.attn.layer_name]
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        if kv_cache is not None:
            with record_function('attention qkv_rope'):
                # int8 weight version: quantized weights live in
                # `qweight`/`scales`; otherwise fall back to fp weights.
                q_weight = self.q_proj.qweight if hasattr(
                    self.q_proj, "qweight") else self.q_proj.weight
                k_weight = self.k_proj.qweight if hasattr(
                    self.k_proj, "qweight") else self.k_proj.weight
                v_weight = self.v_proj.qweight if hasattr(
                    self.v_proj, "qweight") else self.v_proj.weight
                q_scale = self.q_proj.scales if hasattr(
                    self.q_proj, "scales") else None
                k_scale = self.k_proj.scales if hasattr(
                    self.k_proj, "scales") else None
                v_scale = self.v_proj.scales if hasattr(
                    self.v_proj, "scales") else None
                q_bias = self.q_proj.bias if hasattr(self.q_proj,
                                                     "bias") else None
                k_bias = self.k_proj.bias if hasattr(self.k_proj,
                                                     "bias") else None
                v_bias = self.v_proj.bias if hasattr(self.v_proj,
                                                     "bias") else None
                # Fused qkv-projection + rope + KV-cache write; the kernel
                # receives the cache and slot mapping directly.
                q, k, v = torch_br.supa_qkv_rope_decode_infer(
                    hidden_states,
                    q_weight,
                    k_weight,
                    v_weight,
                    self.rotary_emb.sin_cache,
                    self.rotary_emb.cos_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    self.rotary_emb.head_size,
                    self.q_size,
                    self.kv_size,
                    q_scale=q_scale,
                    k_scale=k_scale,
                    v_scale=v_scale,
                    q_bias=q_bias,
                    k_bias=k_bias,
                    v_bias=v_bias)

            if hasattr(attn_metadata, 'do_cache'):
                # NOTE(review): the fused kernel appears to have already
                # written K/V into the cache (it gets kv_cache and
                # slot_mapping above), so disable caching in the backend
                # — confirm against the attention backend's contract.
                attn_metadata.do_cache = False
            with record_function('attention'):
                attn_output = self.attn(q, k, v)
            with record_function('attention o_proj'):
                output, _ = self.o_proj(attn_output)

            return output
        else:
            # No KV cache bound for this virtual engine; pass through.
            # NOTE(review): presumably another warm-up/profiling case.
            return hidden_states
    else:
        # uma weight or use mrope (qwen-vl): generic unfused path.
        q, _ = self.q_proj(hidden_states)
        k, _ = self.k_proj(hidden_states)
        v, _ = self.v_proj(hidden_states)
        q, k = self.rotary_emb(positions, q, k)
        if hasattr(attn_metadata, 'do_cache'):
            # Backend is responsible for the KV-cache write on this path.
            attn_metadata.do_cache = True
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
|
||||
210
vllm_br/model_executor/models/supa_module/mla.py
Normal file
210
vllm_br/model_executor/models/supa_module/mla.py
Normal file
@@ -0,0 +1,210 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.mla import MLAModules
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
|
||||
@dataclass
class SupaMLAModules(MLAModules):
    """MLAModules extended with a standalone ``q_a_proj`` submodule.

    ``q_a_proj`` is forwarded to the fused MLA attention backend when
    ``q_lora_rank`` is set (see ``SupaMultiHeadLatentAttention``); it may
    be ``None`` for models without a low-rank q projection.
    """
    # Low-rank "a" projection for queries, kept separate from the fused
    # qkv_a projection provided by the base MLAModules.
    q_a_proj: Optional[torch.nn.Module]
|
||||
|
||||
@CustomOp.register("supa_multi_head_latent_attention")
class SupaMultiHeadLatentAttention(CustomOp):
    """Multi-head latent attention (MLA) for Biren SUPA devices.

    Wraps vLLM's MLA ``Attention`` backend and dispatches between:

    * ``forward_native`` — runs the low-rank q/kv projections, layernorms
      and rotary embedding explicitly in Python before calling the
      backend (used when ``is_ds_v32`` is truthy in ``forward_oot``), and
    * ``forward_supa`` — hands ``hidden_states`` straight to the backend,
      which performs all projections itself via the extra BIREN module
      arguments passed at construction time (non-sparse branch only).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the MLA attention layer from injected submodules.

        Args:
            mla_modules: bundle of model-specific projection/layernorm
                modules (typically a ``SupaMLAModules``), so this class
                stays model-agnostic.
            q_lora_rank: rank of the low-rank q projection, or ``None``
                when the model projects q directly from hidden states.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Per-head q/k dim = non-positional part + decoupled rope part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        # Unpack the injected submodules onto attributes.
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse
        self.q_a_proj = mla_modules.q_a_proj

        if self.indexer is not None:
            # Sparse-attention indexer (top-k token selection).
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
            self.topk_indices_buffer = mla_modules.topk_indices_buffer

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        # kv_lora_rank + qk_rope_head_dim == head_size
        if self.is_sparse:
            # Sparse path: plain MLA backend; projections/rope are done
            # in forward_native, not inside the backend.
            self.mla_attn = Attention(
                num_heads=self.num_heads,
                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                scale=scale,
                num_kv_heads=1,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                use_mla=True,
                use_sparse=mla_modules.is_sparse,
                # MLA Args
                q_lora_rank=self.q_lora_rank,
                kv_lora_rank=self.kv_lora_rank,
                qk_nope_head_dim=self.qk_nope_head_dim,
                qk_rope_head_dim=self.qk_rope_head_dim,
                qk_head_dim=self.qk_head_dim,
                v_head_dim=self.v_head_dim,
                kv_b_proj=self.kv_b_proj,
                indexer=self.indexer,
            )
        else:
            # Dense path: additionally hand the projection/layernorm
            # modules to the backend so forward_supa can delegate the
            # whole computation to one fused call.
            self.mla_attn = Attention(
                num_heads=self.num_heads,
                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                scale=scale,
                num_kv_heads=1,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                use_mla=True,
                use_sparse=mla_modules.is_sparse,
                # MLA Args
                q_lora_rank=self.q_lora_rank,
                kv_lora_rank=self.kv_lora_rank,
                qk_nope_head_dim=self.qk_nope_head_dim,
                qk_rope_head_dim=self.qk_rope_head_dim,
                qk_head_dim=self.qk_head_dim,
                v_head_dim=self.v_head_dim,
                kv_b_proj=self.kv_b_proj,
                indexer=self.indexer,
                # BIREN args for fused MLA
                rotary_emb=self.rotary_emb,
                # Without q_lora the direct q_proj is used; with q_lora the
                # backend needs the "b" projection (plus a_proj/layernorm).
                q_proj=self.q_proj
                if self.q_lora_rank is None else self.q_b_proj,
                o_proj=self.o_proj,
                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                kv_a_layernorm=self.kv_a_layernorm,
                q_a_proj=None if self.q_lora_rank is None else self.q_a_proj,
                q_a_layernorm=None
                if self.q_lora_rank is None else self.q_a_layernorm,
            )

        self.prefix = prefix
        # Assumes prefix like "...layers.<idx>.self_attn" — the
        # second-to-last dotted component is the layer index.
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Explicit MLA forward: projections + rope in Python.

        Returns the o_proj output with a leading batch dim of 1
        (``unsqueeze(0)``).
        """
        q_c = None
        kv_lora = None

        if self.q_lora_rank is not None:
            # Low-rank q: one fused projection yields q latent + kv latent.
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0].view(-1,
                                           self.num_heads * self.qk_head_dim)
        else:
            # Direct q projection; kv latent comes from its own projection.
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]

        kv_lora = kv_lora.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
        # Split latent kv into compressed kv and the decoupled rope part.
        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)

        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)

        # Rope is applied only to the trailing qk_rope_head_dim slice of q.
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)

        if self.indexer and self.is_sparse:
            # Indexer writes top-k indices into its shared buffer as a
            # side effect; the returned value is unused here.
            _topk_indices = self.indexer(hidden_states, q_c, positions,
                                         self.rotary_emb)

        # NOTE(review): assumes hidden_states is (batch, seq, hidden) so
        # shape[1] is the sequence length — confirm against callers.
        seq_len = hidden_states.shape[1]
        attn_out = self.mla_attn(q,
                                 kv_c_normed,
                                 k_pe,
                                 output_shape=(seq_len, self.num_heads *
                                               self.v_head_dim))
        return self.o_proj(attn_out)[0].unsqueeze(0)

    def forward_supa(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fully-fused MLA forward: the backend (constructed with the
        BIREN module args above) does projections, rope and o_proj."""
        # hidden_states is passed in both the q and kv positions; the
        # backend projects it itself.
        return self.mla_attn(hidden_states,
                             positions,
                             hidden_states,
                             output_shape=hidden_states.shape)

    def forward_oot(self, *args, is_ds_v32: Optional[int], **kwargs):
        """Out-of-tree dispatch: native path when ``is_ds_v32`` is truthy
        (NOTE(review): presumably flags DeepSeek-V3.2 sparse models —
        confirm with callers), fused SUPA path otherwise."""
        if is_ds_v32:
            return self.forward_native(*args, **kwargs)
        else:
            return self.forward_supa(*args, **kwargs)
|
||||
170
vllm_br/model_executor/models/supa_module/mlp.py
Normal file
170
vllm_br/model_executor/models/supa_module/mlp.py
Normal file
@@ -0,0 +1,170 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch_br
|
||||
from torch import nn
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import (get_pp_group,
|
||||
get_tensor_model_parallel_rank)
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm_br import envs
|
||||
from vllm_br.utils import get_grandparent_pid
|
||||
|
||||
|
||||
class LlamaMlpSiluL3(nn.Module):
    """Llama-style gated MLP whose activation step runs through the
    ``torch_br.supa_silumul`` kernel (gate/up fused multiply; mirrors
    ``SiluAndMul``) instead of calling ``self.act_fn``.

    Only ``hidden_act == "silu"`` is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # gate/up are column-parallel (output dim sharded across TP ranks);
        # down is row-parallel and reduces back to the full hidden size.
        self.gate_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_proj",
        )
        self.up_proj = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Kept for interface parity; forward() uses the fused kernel.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_states, _ = self.gate_proj(x)
        up_states, _ = self.up_proj(x)
        activated = torch_br.supa_silumul(gate_states, up_states)
        out, _ = self.down_proj(activated)
        return out
|
||||
|
||||
|
||||
class MergedGateUpMLPSiluL2(nn.Module):
    """Merged gate/up SiLU MLP executed with fused ``torch_br`` kernels.

    For dense / routed-expert use (prefix does NOT contain
    "shared_experts"), ``forward`` runs the merged gate_up matmul +
    swiglu and the down projection via ``torch_br`` fused ops, and picks
    one of three all-reduce strategies (fused linear+allreduce, CPU/PCIe
    allreduce, or the standard tensor-parallel allreduce).

    When this module is a DeepSeek "shared_experts" MLP, ``forward``
    instead returns the raw ``(gate_up_proj.weight, down_proj.weight)``
    tuple so the caller can pack the weights into the fused MoE kernel
    (see ``DeepseekV2MoE.forward`` in moe.py).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.intermediate_size = intermediate_size
        # prefix is inspected in forward() to detect shared_experts usage.
        self.prefix = prefix
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        # NOTE(review): flag consumed elsewhere (weight loading/layout?) —
        # confirm where has_cross_weight is read.
        self.gate_up_proj.has_cross_weight = True
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Kept for interface parity; forward() uses fused kernels instead.
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Lazily resolve the grandparent PID needed by the CPU/PCIe
        # allreduce kernel (done once per module instance).
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()
        if "shared_experts" not in self.prefix:
            # Quantized weights live in `qweight`/`scales`.
            quant_flag = hasattr(self.gate_up_proj, "qweight")
            hidden_size = x.shape[-1]
            seq_len = x.shape[-2]
            gu_weight = self.gate_up_proj.qweight if quant_flag else self.gate_up_proj.weight
            gu_scales = self.gate_up_proj.scales if quant_flag else None
            # Fused gate_up matmul + swiglu activation.
            gate_up_output = torch_br.br_fused_mlp_infer(
                x, [gu_weight],
                output_w=self.intermediate_size // self.tp_size,
                scales=[gu_scales] if gu_scales is not None else None,
                activation_mode="act_swiglu")

            down_weight = self.down_proj.qweight if quant_flag else self.down_proj.weight
            down_scales = self.down_proj.scales if quant_flag else None

            # bypass tp8 and tp4pp2 allreduce: fuse the down projection
            # and the allreduce into one kernel on supported
            # (device SPC count, tp_size) combinations for short decodes.
            pp_size = get_pp_group().world_size
            all_rank = self.tp_size * pp_size
            support_types = ((16, 4), (32, 2), (32, 4))
            if all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and \
                    (envs.VLLM_BR_DEVICE_SPC_NUM, self.tp_size) in support_types:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                # Sanity check: the kernel assumes TP ranks are laid out
                # contiguously inside the global rank space.
                rank_i = global_rank % self.tp_size
                assert rank_i == tp_rank
                down_output = torch_br.supa_fused_linear_allreduce_opt(
                    gate_up_output,
                    down_weight,
                    hidden_size,
                    tp_rank,
                    self.tp_size,
                    global_rank,
                    0,
                    scales=down_scales)

                return down_output
            else:
                # Separate fused down projection, then explicit allreduce.
                down_output = torch_br.br_fused_mlp_infer(
                    gate_up_output, [down_weight],
                    output_w=hidden_size,
                    scales=[down_scales] if down_scales is not None else None)

                if self.tp_size > 1:
                    out = down_output
                    # NOTE(review): shape[1] <= 32 looks like a
                    # small-decode-batch heuristic for the PCIe path —
                    # confirm the intended dimension.
                    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4 and out.shape[
                            1] <= 32:
                        tp_rank = get_tensor_model_parallel_rank()
                        output = torch_br.supa_allreduce_pcie_infer(
                            out, tp_rank, self.tp_size, self.grandparent_pid)
                    else:
                        output = tensor_model_parallel_all_reduce(out)
                    return output
                else:
                    return down_output
        else:
            # shared_experts mode: hand the raw weights to the fused MoE
            # kernel instead of computing anything here.
            return self.gate_up_proj.weight, self.down_proj.weight
|
||||
116
vllm_br/model_executor/models/supa_module/moe.py
Normal file
116
vllm_br/model_executor/models/supa_module/moe.py
Normal file
@@ -0,0 +1,116 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2MLP,
|
||||
ParallelConfig)
|
||||
from vllm_br import envs
|
||||
from vllm_br.utils import get_grandparent_pid
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2/V3 MoE layer variant for Biren SUPA devices.

    Differs from the upstream vLLM layer in that the router gate and the
    shared experts are fused into the FusedMoE kernel: ``forward`` passes
    the (gate weight, shared gate_up weight, shared down weight) tuple in
    place of precomputed router logits.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        parallel_config: ParallelConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # NOTE(review): hard-coded decode-length cap; appears unused in
        # this class — confirm whether consumers read it.
        self.static_moe_decoder_max_len = 512

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        # Router gate; replicated (not sharded) and never quantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            # Aux-loss-free routing bias (DeepSeek-V3); kept on CPU.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, device="cpu"))
        else:
            self.gate.e_score_correction_bias = None

        # reduce_results=False: the allreduce is handled in forward()
        # (or inside the fused kernel, signalled via `all_reduced`).
        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            # Shared experts compute is fused into FusedMoE; this module
            # mainly holds the weights (see forward()).
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the fused MoE kernel and allreduce.

        Returns a tensor of the same shape as ``hidden_states``.
        """
        # Lazily resolve the grandparent PID used by CPU allreduce paths.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()
        orig_shape = hidden_states.shape
        assert self.n_shared_experts is not None, 'n_shared_experts must be set'
        # NOTE: gate has been fused with shared_experts, no more single gate call
        # and we packed router weights, shared_experts weights and down weights in a tuple
        tuple_router_shared_expert_weight = (
            self.gate.weight, self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight)
        hidden_states = hidden_states.view(-1, orig_shape[-1])

        # `router_logits` carries the weight tuple; the fused kernel
        # computes routing internally.
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=tuple_router_shared_expert_weight)

        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(orig_shape)
|
||||
86
vllm_br/model_executor/models/utils.py
Normal file
86
vllm_br/model_executor/models/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
|
||||
# Copyright 2025 The vLLM team.
|
||||
# Copyright 2025 The Qwen Team.
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
|
||||
import torch
|
||||
|
||||
import vllm
|
||||
from vllm.model_executor.models.utils import (_embedding_count_expression,
|
||||
_flatten_embeddings)
|
||||
from vllm.multimodal import NestedTensors
|
||||
|
||||
|
||||
def _merge_multimodal_embeddings_fit(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """
    Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting
    the positions in ``inputs_embeds`` corresponding to placeholder tokens
    in ``input_ids``.

    Note:
        This updates ``inputs_embeds`` in place. Uses boolean-mask index
        assignment rather than ``masked_scatter_``.
    """
    flattened = _flatten_embeddings(multimodal_embeddings)
    try:
        # Equivalent to masked_scatter_ over the mask, row-wise.
        inputs_embeds[is_multimodal] = flattened
    except RuntimeError as e:
        # Diagnose the failure: a count mismatch gets a precise message,
        # anything else is re-raised with the original as the cause.
        num_expected_tokens = is_multimodal.sum().item()
        assert isinstance(num_expected_tokens, int)

        if flattened.shape[0] == num_expected_tokens:
            raise ValueError("Error during masked scatter operation") from e

        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders"
        ) from e

    return inputs_embeds
|
||||
|
||||
|
||||
# Monkey-patch: replace vLLM's multimodal-embedding merge helper with the
# index-assignment variant defined above (takes effect module-wide as soon
# as this module is imported).
vllm.model_executor.models.utils._merge_multimodal_embeddings = _merge_multimodal_embeddings_fit
|
||||
Reference in New Issue
Block a user