first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# Importing each submodule applies its SUPA-specific monkey-patches to the
# corresponding upstream vLLM model implementation as a side effect.
from . import (bert, bert_with_rope, chatglm, clip, config, deepseek_mtp,
               deepseek_v2, glm4, glm4_1v, glm4_moe, gpt_oss, intern_vit,
               internlm2, llama, qwen2, qwen2_5_vl, qwen2_vl, qwen3, qwen3_moe,
               qwen3_vl, qwen3_vl_moe, registry, utils)

# Explicit public API: every patched model submodule is re-exported.
__all__ = [
    "bert_with_rope",
    "bert",
    "chatglm",
    "clip",
    "config",
    "deepseek_mtp",
    "deepseek_v2",
    "glm4_1v",
    "glm4_moe",
    "glm4",
    "gpt_oss",
    "intern_vit",
    "internlm2",
    "llama",
    "qwen2_5_vl",
    "qwen2_vl",
    "qwen2",
    "qwen3_moe",
    "qwen3",
    "qwen3_vl",
    "qwen3_vl_moe",
    "registry",
    "utils",
]

View File

@@ -0,0 +1,42 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
from fastcore.basics import patch_to
from vllm.model_executor.models.bert import BertModel
from vllm.sequence import IntermediateTensors
@patch_to(BertModel)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Patched BertModel.forward for the SUPA backend.

    Embeds ``input_ids`` (unless ``inputs_embeds`` is given) with a leading
    batch dimension of 1 and runs the encoder, returning a 2-D result.
    """
    if inputs_embeds is None:
        # Give the ids a batch dimension of 1; the attention module raises
        # an error for any other batch size on this backend.
        batched_ids = input_ids.unsqueeze(0)
        hidden_states = self.embeddings(input_ids=batched_ids,
                                        position_ids=positions)
    else:
        hidden_states = inputs_embeds
    # Drop the artificial batch dimension again before returning.
    return self.encoder(hidden_states).squeeze(0)

View File

@@ -0,0 +1,42 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
from fastcore.basics import patch_to
from vllm.model_executor.models.bert_with_rope import BertWithRope
from vllm.sequence import IntermediateTensors
@patch_to(BertWithRope)
def forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Patched BertWithRope.forward for the SUPA backend.

    Embeds ``input_ids`` with an artificial batch dimension of 1 (unless
    ``inputs_embeds`` is supplied) and runs the rope-aware encoder,
    squeezing the result back to 2-D.
    """
    if inputs_embeds is None:
        if input_ids is None:
            raise ValueError("input_ids must be provided.")
        # Batch size must be 1 for this backend's attention module.
        hidden_states = self.embeddings(input_ids=input_ids.unsqueeze(0),
                                        token_type_ids=token_type_ids)
    else:
        hidden_states = inputs_embeds
    return self.encoder(positions, hidden_states).squeeze(0)

View File

@@ -0,0 +1,48 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch_br
def convBB(input_tensor):
    """Copy *input_tensor* into a fresh SUPA tensor created with sbp="BB".

    The destination keeps the source's shape, dtype, device and low-level
    layout (queried via the SUPA debug API).  NOTE(review): "BB" presumably
    means broadcast on both dims — confirm against torch_br docs.
    """
    tensor_info = torch_br.supa._debug.get_tensor_info(input_tensor)
    layout = tensor_info[0]["layout"].lower()
    result = torch_br._empty_ut_only(
        size=input_tensor.shape,
        dtype=input_tensor.dtype,
        is_numa=False,
        device=input_tensor.device,
        tensor_type=layout,
        sbp="BB",
    )
    result.copy_(input_tensor)
    return result
def convSB(input_tensor, axis: int):
    """Copy *input_tensor* into a fresh SUPA tensor created with sbp="SB".

    Like convBB, but the destination is created split along *axis*
    (NOTE(review): "SB" presumably means split/broadcast — confirm against
    torch_br docs).
    """
    tensor_info = torch_br.supa._debug.get_tensor_info(input_tensor)
    layout = tensor_info[0]["layout"].lower()
    result = torch_br._empty_ut_only(
        size=input_tensor.shape,
        dtype=input_tensor.dtype,
        is_numa=False,
        device=input_tensor.device,
        tensor_type=layout,
        sbp="SB",
        axis=axis,
    )
    result.copy_(input_tensor)
    return result

View File

@@ -0,0 +1,195 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/THUDM/ChatGLM2-6B
"""Inference-only ChatGLM model compatible with THUDM weights."""
from typing import Optional, Union
import torch
import torch.nn as nn
import vllm
from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.models.chatglm import GLMMLP
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    **kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Replacement ChatGLMModel.forward (installed below).

    On the first pipeline rank the token embeddings are computed locally;
    later ranks consume the hidden states handed over by the previous
    stage.  A batch dimension of 1 is added around the encoder call.
    """
    if not get_pp_group().is_first_rank:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
    elif inputs_embeds is not None:
        hidden_states = inputs_embeds
    else:
        hidden_states = self.get_input_embeddings(input_ids)
    # The RMSNorm op on this backend needs a leading (batch) dimension.
    hidden_states = self.encoder(
        hidden_states=hidden_states.unsqueeze(0),
        position_ids=positions,
    )
    # Squeeze back to the 2-D token-major shape callers expect.
    return hidden_states.squeeze(0)
class GLMAttention_fit(nn.Module):
    """Drop-in replacement for vLLM's GLMAttention (installed below).

    Mirrors the upstream layer layout; the visible difference is that the
    rotary embedding is created with ``op_type="Chatglm2"`` — presumably a
    backend-specific rope kernel selector (confirm against the SUPA rope
    implementation).
    """

    def __init__(
        self,
        config: ChatGLMConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        # Query heads are partitioned evenly across tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.multi_query_attention = config.multi_query_attention
        # MQA models carry a separate (smaller) KV-head count.
        self.total_num_kv_heads = (config.multi_query_group_num
                                   if config.multi_query_attention else
                                   config.num_attention_heads)
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = config.hidden_size // self.total_num_heads
        # Per-rank projection widths.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        # Fused QKV projection (column-parallel).
        self.query_key_value = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=config.add_bias_linear or config.add_qkv_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.query_key_value",
        )
        # Output projection (row-parallel).
        self.dense = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=config.add_bias_linear,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
        rope_ratio = getattr(config, "rope_ratio", 1.0)
        max_positions = getattr(config, "seq_length", 8192)
        # NOTE: THUDM/cogagent-9b-20241220 uses original_rope=False,
        # which is equivalent to is_neox_style=True
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim // 2,
            max_position=max_positions,
            base=10000 * rope_ratio,
            is_neox_style=False,
            op_type="Chatglm2",
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Standard QKV -> rope -> attention -> output-projection path."""
        qkv, _ = self.query_key_value(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(position_ids, q, k)
        context_layer = self.attn(q, k, v)
        attn_output, _ = self.dense(context_layer)
        return attn_output
def GLMMLP__init__(
    self,
    config: ChatGLMConfig,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
):
    """Replacement GLMMLP.__init__ (installed onto vllm's GLMMLP below).

    Same layer layout as upstream; additionally marks dense_h_to_4h with
    ``no_fuse_act`` — presumably telling the backend not to fuse the
    activation into the up-projection (confirm against the SUPA linear op).
    """
    # Explicit super() target: this function lives at module level and is
    # assigned onto GLMMLP afterwards, so zero-argument super() would fail.
    super(GLMMLP, self).__init__()
    self.add_bias = config.add_bias_linear
    # Project to 4h.
    self.dense_h_to_4h = MergedColumnParallelLinear(
        config.hidden_size,
        [config.ffn_hidden_size] * 2,
        bias=config.add_bias_linear,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_h_to_4h",
    )
    self.dense_h_to_4h.no_fuse_act = True
    self.activation_func = SiluAndMul()
    # Project back to h.
    self.dense_4h_to_h = RowParallelLinear(
        config.ffn_hidden_size,
        config.hidden_size,
        bias=config.add_bias_linear,
        quant_config=quant_config,
        prefix=f"{prefix}.dense_4h_to_h",
    )
def GLMMLP__forward(self, hidden_states):
    """Replacement GLMMLP.forward: up-projection then down-projection.

    NOTE(review): unlike upstream GLMMLP.forward, ``self.activation_func``
    is never invoked here — presumably the activation is handled inside the
    linear layers on this backend (see ``no_fuse_act`` in the patched
    __init__); confirm.
    """
    up_proj, _ = self.dense_h_to_4h(hidden_states)  # [s, b, 4hp]
    down_proj, _ = self.dense_4h_to_h(up_proj)      # [s, b, h]
    return down_proj
# Install the patches: replace upstream ChatGLM components with the
# SUPA-adapted implementations defined above.
vllm.model_executor.models.chatglm.GLMMLP.forward = GLMMLP__forward
vllm.model_executor.models.chatglm.GLMMLP.__init__ = GLMMLP__init__
vllm.model_executor.models.chatglm.ChatGLMModel.forward = model_forward
# Class-level replacement: decoder layers built after this point will
# instantiate GLMAttention_fit instead of the upstream GLMAttention.
vllm.model_executor.models.chatglm.GLMAttention = GLMAttention_fit

View File

@@ -0,0 +1,65 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
import torch
import torch_br
from vllm.model_executor.models.clip import CLIPVisionEmbeddings
def clip_vision_embeddings_forward(self,
                                   pixel_values: torch.Tensor) -> torch.Tensor:
    """Patched CLIPVisionEmbeddings.forward for the SUPA backend.

    Computes patch + class + position embeddings.  For patch_size == 14 the
    patch convolution is dispatched to a dedicated SUPA conv2d kernel with
    the "zero" debug switches temporarily enabled around the call.
    """
    batch_size = pixel_values.shape[0]
    target_dtype = self.patch_embedding.weight.dtype
    if self.patch_size == 14:
        import torch_br.supa._debug as supa_debug
        # Temporarily enable the zero-handling paths for this one kernel.
        # NOTE(review): the flags are not restored on exception (no
        # try/finally) — confirm whether that is acceptable here.
        supa_debug.set_disable_zero_ws(False)
        supa_debug.set_disable_zero_output_uma(False)
        supa_debug.set_disable_zero_output_numa(False)
        supa_debug.set_disable_reorder_zero(False)
        #TODO(shouqing): this op need to do internal clear_zeros operation
        patch_embeds = torch_br.supa_conv2d_knxn_snxn_p0x0_fwd(
            pixel_values.to(dtype=target_dtype), self.patch_embedding.weight,
            self.patch_size, self.patch_size, 0)
        supa_debug.set_disable_zero_ws(True)
        supa_debug.set_disable_zero_output_uma(True)
        supa_debug.set_disable_zero_output_numa(True)
        supa_debug.set_disable_reorder_zero(True)
    else:
        patch_embeds = self.patch_embedding(pixel_values.to(
            dtype=target_dtype))  # shape = [*, width, grid, grid]
    # [*, width, grid, grid] -> [*, grid*grid, width]
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    # Lazily migrate the position-id buffer onto the current SUPA device
    # the first time this forward runs.
    data_in_cpu = lambda t: t.device == torch.device('cpu')
    if data_in_cpu(self.position_ids):
        cur_device = torch.supa.current_device()
        self.position_ids = self.position_ids.to(cur_device)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings


#logger.debug('[Patch] patch CLIPVisionEmbeddings forward')
# Install the patched forward on the upstream class.
CLIPVisionEmbeddings.forward = clip_vision_embeddings_forward

View File

@@ -0,0 +1,51 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import TYPE_CHECKING
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastcore.basics import patch_to
from vllm.logger import init_logger
from vllm.model_executor.models.config import DeepseekV32ForCausalLM
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = init_logger(__name__)
@patch_to(DeepseekV32ForCausalLM)
def verify_and_update_config(vllm_config: "VllmConfig") -> None:
    """Switch the kv-cache dtype to backend-supported formats for V3.2.

    fp8* -> the custom "fp8_ds_mla" layout; bfloat16 -> "auto".  Any other
    cache dtype is left untouched.
    """
    hf_config = vllm_config.model_config.hf_config
    # Mirror the check in vllm/model_executor/models/deepseek_v2.py:
    # V3.2 is detected by the presence of index_topk on the HF config.
    assert hasattr(hf_config, "index_topk")
    cache_config = vllm_config.cache_config
    requested_dtype = cache_config.cache_dtype
    if requested_dtype.startswith("fp8"):
        # For DeepSeekV3.2, a custom fp8 format replaces plain fp8 caching.
        cache_config.cache_dtype = "fp8_ds_mla"
        logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
    elif requested_dtype == "bfloat16":
        cache_config.cache_dtype = "auto"
        logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")

View File

@@ -0,0 +1,99 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
import torch.nn as nn
from fastcore.basics import patch_to
from vllm.config import VllmConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.models.deepseek_mtp import (
DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, SharedHead)
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
# from vllm.model_executor.sampling_metadata import SamplingMetadata
# from vllm.model_executor.layers.sampler import get_sampler
@patch_to(DeepSeekMultiTokenPredictorLayer)
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
    """Patched constructor for the DeepSeek MTP layer.

    Builds the embedding/hidden norms, the eh projection and the decoder
    block; for V3.2 (detected via ``index_topk``) it also pre-allocates the
    indexer's top-k indices buffer shared with the decoder layer.
    """
    # Explicit super() target: this function is patched in from module
    # level, so zero-argument super() would not resolve.
    super(DeepSeekMultiTokenPredictorLayer, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    # Projects the concatenated [embedding-norm, hidden-norm] pair back to
    # hidden_size.
    self.eh_proj = nn.Linear(config.hidden_size * 2,
                             config.hidden_size,
                             bias=False)
    self.is_v32 = hasattr(config, "index_topk")
    if self.is_v32:
        topk_tokens = config.index_topk
        # One row of top-k indices per schedulable token.
        # NOTE(review): device is hard-coded to "cuda" — presumably mapped
        # to the SUPA device by this backend's torch build; confirm.
        topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            topk_tokens,
            dtype=torch.int32,
            device="cuda")
    else:
        topk_indices_buffer = None
    self.shared_head = SharedHead(config=config, quant_config=quant_config)
    self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix,
                                            topk_indices_buffer)
@patch_to(DeepSeekMultiTokenPredictorLayer)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    previous_hidden_states: torch.Tensor,
    inputs_embeds: Optional[torch.Tensor] = None,
    spec_step_index: int = 0,
) -> torch.Tensor:
    """Patched MTP-layer forward.

    Normalizes the (position-0-masked) embeddings and previous hidden
    states, fuses them through eh_proj, runs the decoder block and returns
    the residual-added 2-D hidden states.
    """
    assert inputs_embeds is not None
    # MTP never consumes the token at position 0, so zero its embedding.
    inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
                                torch.zeros_like(inputs_embeds),
                                inputs_embeds)
    # The norm kernels expect a leading batch dimension of 1.
    normed_embeds = self.enorm(inputs_embeds.unsqueeze(0))
    normed_prev = self.hnorm(previous_hidden_states.unsqueeze(0))
    fused = torch.cat([normed_embeds, normed_prev], dim=-1)
    hidden_states = self.eh_proj(fused)
    hidden_states, residual = self.mtp_block(positions=positions,
                                             hidden_states=hidden_states,
                                             residual=None)
    # Add the residual back and squeeze to the 2-D caller-facing shape.
    return (residual + hidden_states).squeeze(0)
@patch_to(DeepSeekMultiTokenPredictor)
def compute_logits(
    self,
    hidden_states: torch.Tensor,
    spec_step_idx: int = 0,
) -> torch.Tensor:
    """Patched logits computation for the multi-token predictor.

    Picks the MTP layer for the current speculative step and runs its
    shared head (with a temporary batch-1 leading dimension) through the
    logits processor.
    """
    step = spec_step_idx % self.num_mtp_layers
    mtp_layer = self.layers[str(self.mtp_start_layer_idx + step)]
    # shared_head runs on a batch-1 3-D view; squeeze back before the
    # logits processor and make the buffer contiguous for it.
    head_out = mtp_layer.shared_head(hidden_states.unsqueeze(0))
    head_out = head_out.squeeze(0).contiguous()
    return self.logits_processor(mtp_layer.shared_head.head, head_out)

View File

@@ -0,0 +1,924 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Any, Iterable, Optional, Union
import torch
from fastcore.basics import patch_to
from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
import vllm
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2ForCausalLM, DeepseekV2Model, FusedMoE, Indexer, PPMissingLayer,
default_weight_loader, get_spec_layer_idx_from_weight_name,
is_pp_missing_parameter, maybe_prefix, maybe_remap_kv_scale_name,
yarn_get_mscale)
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerMetadata
from vllm_br.v1.attention.backends.mla.indexer import (
SupaDeepseekV32IndexerBackend)
from .supa_module import (DeepseekV2MoE, MergedGateUpMLPSiluL2, SupaMLAModules,
SupaMultiHeadLatentAttention)
@patch_to(vllm.model_executor.models.deepseek_v2.DeepseekV32IndexerCache)
def get_attn_backend(self) -> AttentionBackend:
    # Route the V3.2 indexer cache to the SUPA-specific indexer backend
    # instead of the upstream default.
    return SupaDeepseekV32IndexerBackend
class SupaDeepseekV2MLAAttention(nn.Module):
    """SUPA-backend multi-head latent attention for DeepSeek V2/V3 models.

    Builds the MLA projection stack — for V3.2 ("sparse" models, detected
    via ``index_topk`` on the HF config) a fused qkv_a projection plus the
    sparse indexer; for earlier models the separate q_a / kv_a projections —
    together with yarn-scaled rotary embeddings, and bundles everything into
    a SupaMultiHeadLatentAttention module that performs the actual attention.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: Union[DeepseekV2Config, DeepseekV3Config],
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        topk_indices_buffer: Optional[torch.Tensor] = None,
    ) -> None:
        super().__init__()
        # V3.2 detection mirrors upstream deepseek_v2.py.
        self.is_v32 = hasattr(config, "index_topk")
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Pre-declare all projection slots; only a subset is instantiated
        # depending on model version and q_lora_rank.
        self.fused_qkv_a_proj = None
        self.kv_a_proj_with_mqa = None
        self.q_a_proj = None
        self.q_a_layernorm = None
        self.q_b_proj = None
        self.q_proj = None
        if self.is_v32:
            if self.q_lora_rank is not None:
                # V3.2 fuses the q_a and kv_a down-projections into one
                # merged linear (replicated, TP disabled).
                self.fused_qkv_a_proj = MergedColumnParallelLinear(
                    self.hidden_size, [
                        self.q_lora_rank,
                        self.kv_lora_rank + self.qk_rope_head_dim
                    ],
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.fused_qkv_a_proj",
                    disable_tp=True)
                # NOTE(review): backend-specific flag; presumably skips a
                # cross-device exchange for this layer — confirm.
                self.fused_qkv_a_proj.no_need_cross = True
            else:
                self.kv_a_proj_with_mqa = ReplicatedLinear(
                    self.hidden_size,
                    self.kv_lora_rank + self.qk_rope_head_dim,
                    bias=False,
                    quant_config=quant_config,
                    prefix=f"{prefix}.kv_a_proj_with_mqa")
        else:
            # Pre-V3.2: separate q_a / kv_a down-projections.
            if self.q_lora_rank is not None:
                self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                                 self.q_lora_rank,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_a_proj")
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa")
        if self.q_lora_rank is not None:
            # Low-rank query path: norm + up-projection.
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            # Full-rank query projection straight from the hidden states.
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        # Up-projects the latent kv into per-head nope-k and v.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        if rope_scaling:
            # Select the deepseek-yarn rope implementation; pre-V3.2 models
            # use the SUPA-specific variant.
            if self.is_v32:
                rope_scaling["rope_type"] = 'deepseek_yarn'
            else:
                rope_scaling["rope_type"] = 'deepseek_yarn_supa'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            # Apply the yarn mscale correction to the softmax scaling.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        if self.is_v32:
            # NOTE(review): SupaIndexer is expected to be defined elsewhere
            # in this module (not visible in this chunk).
            self.indexer: Optional[SupaIndexer] = SupaIndexer(
                vllm_config, config, hidden_size, q_lora_rank, quant_config,
                cache_config, topk_indices_buffer, f"{prefix}.indexer")
        else:
            self.indexer: Optional[SupaIndexer] = None
        # Bundle every sub-module into the SUPA MLA wrapper.
        mla_modules = SupaMLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm,
            q_b_proj=self.q_b_proj,
            q_proj=self.q_proj,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
            q_a_proj=self.q_a_proj,
        )
        self.mla_attn = SupaMultiHeadLatentAttention(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate to the SUPA MLA module (sparse path for V3.2)."""
        return self.mla_attn(positions, hidden_states, is_ds_v32=self.is_v32)
def indexer_k_cache(
    k: torch.Tensor,  # [num_tokens, head_dim]
    kv_cache: torch.
    Tensor,  # [1, num_blocks, block_size, cache_stride] # (1, 1024, 2048, 128)
    slot_mapping: torch.Tensor,  # [num_tokens], flat slot index per token
) -> None:
    """Scatter per-token indexer keys into the paged k-cache in place.

    Each token's flat slot index is split into (block, offset) within the
    paged cache and the first ``head_dim`` entries of that cache row are
    overwritten with the token's key.  The tail of the row is untouched —
    the cache stride can be longer than head_dim ([TODO] kv_cache shape is
    not aligned with nv).

    Assumes ``slot_mapping`` contains distinct slots; with duplicates the
    write order is unspecified (the previous per-token loop was last-wins).
    """
    head_dim = k.shape[1]
    cache_block_size = kv_cache.shape[-2]
    # Vectorized scatter: one indexed assignment instead of a Python loop
    # over tokens (the old implementation issued one copy per token).
    block_idx = slot_mapping // cache_block_size
    block_offset = slot_mapping % cache_block_size
    kv_cache[0, block_idx, block_offset, :head_dim] = k
def bf16_mqa_logits(
    q: torch.Tensor,
    kv: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
):
    """Reference MQA indexer logits, computed in float32.

    Args:
        q: [num_q_tokens, num_heads, head_dim] queries.
        kv: [num_kv_tokens, head_dim] shared keys.
        weights: [num_q_tokens, num_heads] per-head mixing weights.
        cu_seqlen_ks / cu_seqlen_ke: per-query-token [start, end) bounds of
            the kv positions that token may attend to.

    Returns:
        [num_q_tokens, num_kv_tokens] logits, -inf outside each token's
        [ks, ke) window.
    """
    seq_len_kv = kv.shape[0]
    q = q.float()
    k = kv.float()
    # Fix: build the position ramp on the kv tensor's device instead of a
    # hard-coded "cuda", so this reference also runs on CPU/other backends
    # (identical behavior when kv already lives on cuda).
    positions = torch.arange(0, seq_len_kv, device=kv.device)[None, :]
    mask = ((positions >= cu_seqlen_ks[:, None])
            & (positions < cu_seqlen_ke[:, None]))
    # score[h, m, n] = q[m, h, :] . k[n, :]
    score = torch.einsum("mhd,nd->hmn", q, k)
    # ReLU then weighted sum over heads.
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    logits = logits.masked_fill(~mask, float("-inf"))
    return logits
def _ref_fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
):
    """Reference (non-fused) paged MQA logits over a paged kv-cache.

    Args:
        q: [batch_size, next_n, num_heads, head_dim] queries for the last
            next_n positions of each sequence.
        kv_cache: 5-D paged cache; re-viewed below into 16x more blocks of
            1/16 the size to align the SUPA layout with the NV reference.
        weights: [batch_size * next_n, num_heads] per-head weights.
        context_lens: [batch_size] context length per sequence.
        block_tables: [batch_size, num_blocks] block ids per sequence.
        max_model_len: width of the output logits buffer.

    Returns:
        [batch_size * next_n, max_model_len] float32 logits, -inf outside
        each token's causal window.
    """
    batch_size, next_n, _, _ = q.size()
    # `unkonw_size` (sic): extra cache dimension whose meaning is not
    # established here.
    _, num_block, block_size, unkonw_size, head_dim = kv_cache.size(
    )  # [1, num_block, block_size, _]
    # Re-chunk: 16x more blocks, each 16x smaller (NV-aligned layout).
    num_block = num_block * 16
    block_size = block_size // 16
    kv_cache = kv_cache.view(num_block, block_size, unkonw_size, head_dim)
    logits = torch.full(
        [batch_size * next_n, max_model_len],
        float("-inf"),
        device=q.device,
        dtype=torch.float32,
    )
    context_lens_list = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_lens_list[i]
        # Absolute positions of this sequence's next_n query tokens.
        q_offsets = torch.arange(context_len - next_n,
                                 context_len,
                                 device="cuda")
        weight_slice = (weights[i * next_n:(i + 1) * next_n, :].transpose(
            0, 1).contiguous())
        for block_rk in range(cdiv(context_len, block_size)):
            block_idx = block_tables[i][block_rk]
            qx, kx = q[i], kv_cache[block_idx]
            k_offsets = torch.arange(
                block_rk * block_size,
                (block_rk + 1) * block_size,
                device="cuda",
            )
            # Causal mask: key inside the context and not after the query.
            mask = (k_offsets[None, :] < context_len) & (k_offsets[None, :]
                                                         <= q_offsets[:, None])
            s = torch.where(
                mask[None, :, :],
                (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to(
                    logits.dtype),
                float("-inf"),
            )
            # ReLU then weighted sum over heads.
            s = torch.relu(s) * weight_slice[..., None]
            s = s.sum(dim=0)
            logits[
                i * next_n:(i + 1) * next_n,
                block_rk * block_size:(block_rk + 1) * block_size,
            ] = torch.where(k_offsets[None, :] <= q_offsets[:, None], s,
                            float("-inf"))
    return logits
def cp_gather_indexer_k_quant_cache(
    kv_cache,  # [1, num_blocks, block_size, head_dim + 1]
    dst_value,  # [cu_seq_lens[-1], head_dim]
    dst_scale,  # [cu_seq_lens[-1], 4]
    block_table,  # [batch_size, num_blocks]
    cu_seq_lens,  # [batch_size + 1, ]
    batch_size,
):
    """Gather each sequence's indexer keys out of the paged cache into a
    flat ``dst_value`` buffer (in place).

    The scale-gathering path is currently disabled (commented out below);
    ``dst_scale`` is accepted but not written.  Per the [TODO] note, tensor
    indexing is not supported on this backend yet, so the gather runs on
    CPU and the result is copied back to ``dst_value``'s device.
    """
    _, num_blocks, block_size, _ = kv_cache.shape
    # align to nv
    # Re-chunk the cache: 16x more blocks of 1/16 the size.
    num_blocks = num_blocks * 16
    block_size = block_size // 16
    head_dim = dst_value.shape[-1]
    kv_cache = kv_cache.view(num_blocks, -1)
    expected_value = []
    # expected_scale = []
    for b in range(batch_size):
        # Number of tokens for sequence b.
        s = cu_seq_lens[b + 1] - cu_seq_lens[b]
        if s == 0:
            continue
        tot = cdiv(s, block_size)
        blocks = block_table[b, :tot]
        value = []
        # Indices of the fully-occupied blocks (all but the last one).
        full_block = torch.arange(tot - 1,
                                  device=kv_cache.device,
                                  dtype=torch.int32)
        # [TODO] not support index in tensor on br, run in cpu now
        non_remaining_value = kv_cache.cpu()[
            blocks.cpu()[full_block.cpu()], :block_size * head_dim].view(
                -1, head_dim)
        # non_remaining_scale = kv_cache[blocks[full_block],
        #                                block_size * head_dim:].view(-1, 4)
        # Tokens remaining in the final, partially-filled block.
        remaining = s - (tot - 1) * block_size
        value = torch.cat([
            non_remaining_value,
            kv_cache.cpu()[blocks[-1], :remaining * head_dim].view(
                -1, head_dim)
        ],
                          dim=0)
        # scale = torch.cat([
        #     non_remaining_scale,
        #     kv_cache[blocks[-1], block_size * head_dim:block_size * head_dim +
        #              remaining * 4].view(-1, 4)
        # ],
        #                   dim=0)
        expected_value.append(value)
        # expected_scale.append(scale)
    gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim)
    # gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4)
    # Reinterpret the raw cache bytes as bf16 and move back to the
    # destination device.
    gather_value = gather_value.view(torch.bfloat16).to(dst_value.device)
    # gather_scale = gather_scale.view(torch.float32)
    dst_value.copy_(gather_value)
    # dst_scale.copy_(gather_scale)
def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor],
) -> torch.Tensor:
    """Shape-only stand-in for :func:`sparse_attn_indexer`.

    Used during profile/dummy runs: it allocates the largest flattened KV
    buffer that could ever be needed, so the memory profiler observes the
    worst-case footprint, and returns ``topk_indices_buffer`` unchanged.
    """
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
    support_fp8 = False
    # fp8 path is disabled on this platform; keep the same allocation in
    # both branches so peak memory matches the real kernel.
    buf_dtype = torch.uint8 if support_fp8 else torch.bfloat16
    _flattened_kv = torch.empty([total_seq_lens, head_dim + 4],
                                device=k.device,
                                dtype=buf_dtype)
    if support_fp8:
        _k_fp8 = _flattened_kv[..., :head_dim].view(
            torch.float8_e4m3fn).contiguous()
        _k_scale = _flattened_kv[...,
                                 head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer
def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: Optional[str],
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Select the top-k context tokens per query token for the DeepSeek
    sparse indexer.

    Inserts the new K vectors into the paged indexer cache, then computes
    indexer logits separately for the prefill chunks and the decode batch,
    masks out-of-range positions with -inf, takes top-k on CPU (device topk
    is not supported / not -inf-aligned on BR yet) and writes int32 indices
    into ``topk_indices_buffer`` (-1 marks invalid slots). Returns the
    buffer.
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        # Dummy/profile run: fall back to the shape-only fake that just
        # performs the worst-case allocation and returns the buffer.
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    assert topk_indices_buffer is not None
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens
    # Write the freshly computed K vectors into the paged indexer cache.
    indexer_k_cache(
        k,
        kv_cache,
        slot_mapping,
    )
    # NOTE(review): indexing with hidden_states.shape[1] assumes a 3D
    # [1, num_tokens, hidden] activation layout (the patched
    # DeepseekV2Model.forward unsqueezes to 3D) — confirm for all callers.
    topk_indices_buffer[:hidden_states.shape[1]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            # Gather this chunk's cached K back out of the paged cache as
            # a flat bf16 buffer (scale gathering is disabled: no fp8).
            k_bf16 = torch.empty([chunk.total_seq_lens, head_dim],
                                 device=k.device,
                                 dtype=torch.bfloat16)
            k_scale = None
            cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_bf16,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
                chunk.num_reqs,
            )
            logits = bf16_mqa_logits(
                q_fp8[chunk.token_start:chunk.token_end],
                k_bf16,
                weights[chunk.token_start:chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            # [TODO] topk is not aligned with cpu if elements are -inf
            topk_indices = logits.cpu().topk(min(topk_tokens,
                                                 logits.shape[-1]),
                                             dim=-1)[1].supa()
            # Re-base indices to the chunk's own K range and invalidate
            # anything outside [0, ke - ks).
            topk_indices -= chunk.cu_seqlen_ks[:, None]
            mask_lo = topk_indices >= 0
            mask_hi = topk_indices - (chunk.cu_seqlen_ke -
                                      chunk.cu_seqlen_ks)[:, None] < 0
            mask = torch.full_like(topk_indices,
                                   False,
                                   dtype=torch.bool,
                                   device=topk_indices.device)
            mask = mask_lo & mask_hi
            topk_indices = topk_indices.masked_fill(~mask, -1)
            topk_indices_buffer[
                chunk.token_start:chunk.token_end, :topk_indices.
                shape[-1]] = topk_indices.to(dtype=torch.int32)
    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:])
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = _ref_fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            max_model_len=max_model_len,
        )
        # padded query len
        current_device = padded_q_fp8_decode_tokens.device
        padded_num_tokens = batch_size * next_n
        # For each padded query row, compute the last context position it
        # may attend to (causal mask over the padded layout).
        positions = torch.arange(max_model_len,
                                 device=current_device).unsqueeze(0).expand(
                                     batch_size * next_n, -1)
        row_indices = torch.arange(padded_num_tokens,
                                   device=current_device) // next_n
        next_n_offset = torch.arange(
            padded_num_tokens,
            device=padded_q_fp8_decode_tokens.device) % next_n
        index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n +
                         next_n_offset).unsqueeze(1)
        # index_end_pos: [B * N, 1]
        mask = positions <= index_end_pos
        # mask: [B * N, L]
        logits = logits.masked_fill(~mask, float('-inf'))
        # [TODO] topk is not supported
        # Device topk is unavailable on BR; round-trip through the CPU.
        device = logits.device
        logits = logits.to('cpu')
        topk_indices = logits.topk(topk_tokens,
                                   dim=-1)[1].to(torch.int32)  # [B * N, K]
        topk_indices = topk_indices.to(device)
        # ensure we don't set indices for the top k
        # that is out of range(masked already)
        # this will happen if context length is shorter than K
        topk_indices[topk_indices > index_end_pos] = -1
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens)
        topk_indices_buffer[:num_decode_tokens, :topk_indices.
                            shape[-1]] = topk_indices.to(dtype=torch.int32)
    return topk_indices_buffer
class SupaIndexer(Indexer):
    """BR/SUPA variant of the DeepSeek sparse-attention ``Indexer``.

    Differences from the base class: the weights projection is kept as a
    bf16 ``ReplicatedLinear`` (no quantization), the indexer K cache is
    forced to bf16 with ``index_head_dim`` entries, and the forward path
    routes through the bf16 (non-fp8) ``sparse_attn_indexer`` since fp8 is
    not supported on this platform.
    """

    def __init__(self,
                 vllm_config: VllmConfig,
                 config: Union[DeepseekV2Config, DeepseekV3Config],
                 hidden_size: int,
                 q_lora_rank: Optional[int],
                 quant_config: Optional[QuantizationConfig],
                 cache_config: Optional[CacheConfig],
                 topk_indices_buffer: Optional[torch.Tensor] = None,
                 prefix: str = "") -> None:
        super().__init__(
            vllm_config=vllm_config,
            config=config,
            hidden_size=hidden_size,
            q_lora_rank=q_lora_rank,
            quant_config=quant_config,
            cache_config=cache_config,
            topk_indices_buffer=topk_indices_buffer,
            prefix=prefix,
        )
        self.n_head = config.index_n_heads  # 64
        # Rebuild weights_proj without quantization: it must stay bf16.
        self.weights_proj = ReplicatedLinear(hidden_size,
                                             self.n_head,
                                             bias=False,
                                             quant_config=None,
                                             prefix=f"{prefix}.weights_proj")
        # Indexer K cache holds bf16 vectors of size index_head_dim
        # (the base class presumably set an fp8 layout — overridden here).
        self.k_cache.dtype = torch.bfloat16
        self.k_cache.head_dim = config.index_head_dim
        self.topk_indices_buffer.fill_(0)

    def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions,
                rotary_emb) -> torch.Tensor:
        """Project Q/K, apply partial RoPE, and return top-k indices via
        the sparse indexer op."""
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # Split into the rotary part (first rope_dim dims) and the rest.
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
        k, _ = self.wk(hidden_states)
        k = k.view(-1, self.head_dim)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1)
        # K is single-head for rotary purposes: unsqueeze/squeeze head dim.
        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)
        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        support_fp8 = False
        if support_fp8:
            # fp8 path (disabled on this platform): quantize Q per group
            # and fold the quant scale into the per-head weights.
            q_fp8, q_scale = per_token_group_quant_fp8(
                q,
                self.quant_block_size,
                column_major_scales=False,
                use_ue8m0=self.scale_fmt is not None)
            q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
            q_scale = q_scale.view(-1, self.n_head, 1)
            weights, _ = self.weights_proj(hidden_states)
            weights = weights.unsqueeze(
                -1) * q_scale * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)
            return torch.ops.vllm.sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            # bf16 path: no quant scale, plain softmax scaling per head.
            q = q.view(-1, self.n_head, self.head_dim)
            weights, _ = self.weights_proj(hidden_states)
            weights = weights.view(-1, self.n_head)
            weights = weights.unsqueeze(
                -1) * self.softmax_scale * self.n_head**-0.5
            weights = weights.squeeze(-1)
            # Direct call (not the registered custom op) on the bf16 path.
            return sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
@patch_to(DeepseekV2Model)
def forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched DeepseekV2Model.forward that runs the decoder layers on 3D
    [1, num_tokens, hidden] activations (the SUPA backend expects a batch
    dim) and squeezes back to 2D at the pipeline boundary."""
    if get_pp_group().is_first_rank:
        # First PP rank: start from embeddings (or caller-provided embeds).
        hidden_states = (inputs_embeds if inputs_embeds is not None else
                         self.get_input_embeddings(input_ids))
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        residual = residual.unsqueeze(0)  # NOTE: SUPA wants 3D input
    hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not get_pp_group().is_last_rank:
        # Hand 2D tensors to the next PP rank.
        out = {}
        out["hidden_states"] = (hidden_states.squeeze(0)
                                if hidden_states is not None else
                                hidden_states)
        out["residual"] = (residual.squeeze(0)
                           if residual is not None else residual)
        return IntermediateTensors(out)
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)
@patch_to(DeepseekV2ForCausalLM)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Patched DeepseekV2ForCausalLM constructor.

    Forces the DS-MLA attention path on this platform and, for V3.2-style
    checkpoints (detected via ``index_topk``), the sparse-MLA path as well.
    Otherwise mirrors the upstream constructor: model, LM head (last PP
    rank only), logits processor.
    """
    super(DeepseekV2ForCausalLM, self).__init__()
    config = vllm_config.model_config.hf_config
    model_config = vllm_config.model_config
    # Always use DeepSeek MLA kernels on this backend.
    model_config.use_ds_mla = True
    # Presence of index_topk marks a V3.2 (sparse indexer) checkpoint.
    is_v32 = hasattr(config, "index_topk")
    if is_v32:
        model_config.use_ds_mla_sparse = True
    quant_config = vllm_config.quant_config
    self.config = config
    self.quant_config = quant_config
    # `packed_modules_mapping` needs to be modified before
    # initializing DeepseekV2Model, as it is passed inplace to
    # quantization config init and may be used to select the
    # quant_method for relevant layers during initialization.
    self.fuse_qkv_a_proj = hasattr(
        config, "q_lora_rank") and config.q_lora_rank is not None
    if self.fuse_qkv_a_proj:
        self.packed_modules_mapping["fused_qkv_a_proj"] = [
            "q_a_proj",
            "kv_a_proj_with_mqa",
        ]
    self.model = DeepseekV2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))
    if get_pp_group().is_last_rank:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
    else:
        # Intermediate PP ranks carry no LM head.
        self.lm_head = PPMissingLayer()
    self.logits_processor = LogitsProcessor(config.vocab_size)
    self.make_empty_intermediate_tensors = (
        self.model.make_empty_intermediate_tensors)
@patch_to(DeepseekV2ForCausalLM)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights, handling three cases per tensor:
    stacked (fused gate_up / fused qkv_a) params, per-expert MoE params,
    and plain params. Norm / e_score_correction_bias weights are upcast to
    float32 after loading (BR kernels expect fp32 there)."""
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
        ("fused_qkv_a_proj", "q_a_proj", 0),
        ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
    ]
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts)
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue  # skip spec decode layers for main model
        # Case 1: stacked params (uses for/else — the else runs only when
        # no stacked mapping matched and we fell through without break).
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name_mapped = name.replace(weight_name, param_name)
            # QKV fusion is optional, fall back to normal
            # weight loading if it's not enabled
            # if go with fusion option, then update name
            if ((param_name == "fused_qkv_a_proj")
                    and name_mapped not in params_dict):
                continue
            else:
                name = name_mapped
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                # logger.debug(f'skip {name}')
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            # Upcast norm / correction-bias weights to fp32 for BR kernels.
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
            break
        else:
            # Case 2: per-expert MoE params.
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                    torch.supa.empty_cache()
                break
            else:
                # Case 3: plain (unstacked, non-expert) params.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    # logger.debug(f'skip {name}')
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                    torch.supa.empty_cache()
        loaded_params.add(name)
    return loaded_params
# Module-level monkey patches: swap the upstream DeepSeek-V2 building
# blocks for the BR/SUPA implementations defined in this plugin. These run
# at import time, before any model is constructed.
vllm.model_executor.models.deepseek_v2.DeepseekV2MLP = MergedGateUpMLPSiluL2
logger.debug('[Patch] patch DeepSeekV2 MLP with MergedGateUpMLPSiluL2')
vllm.model_executor.models.deepseek_v2.DeepseekV2MoE = DeepseekV2MoE
logger.debug('[Patch] patch DeepSeekV2 MoE with DeepseekV2MoE')
vllm.model_executor.models.deepseek_v2.DeepseekV2MLAAttention = SupaDeepseekV2MLAAttention
logger.debug('[Patch] patch DeepSeekV2 MLA with SupaDeepseekV2MLAAttention')
vllm.model_executor.models.deepseek_v2.Indexer = SupaIndexer
logger.debug('[Patch] patch DeepSeekV2 Indexer with SupaIndexer')
vllm.model_executor.models.deepseek_v2.MultiHeadLatentAttention = SupaMultiHeadLatentAttention
logger.debug(
    '[Patch] patch DeepSeekV2 MultiHeadLatentAttention with SupaMultiHeadLatentAttention'
)
# vllm.model_executor.models.deepseek_v2.DeepseekV2ForCausalLM.packed_modules_mapping = {
#     "gate_up_proj": ["gate_proj", "up_proj"],
#     # "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]
# }
# logger.debug(
#     '[Patch] patch DeepseekV2ForCausalLM with SupportsQuant packed_modules_mapping'
# )

View File

@@ -0,0 +1,299 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Copyright 2025 The Zhipu AI team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4-0414 model compatible with HuggingFace weights."""
#from typing import Any, Optional
#
#import torch
#from fastcore.basics import patch_to
#from transformers import Glm4Config
#
#import vllm
#from vllm.attention import Attention, AttentionType
#from vllm.config import CacheConfig
#from vllm.distributed import get_tensor_model_parallel_world_size
#from vllm.model_executor.layers.linear import (QKVParallelLinear,
# RowParallelLinear)
#from vllm.model_executor.layers.quantization import QuantizationConfig
#from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
# RotaryEmbedding)
#from vllm.model_executor.models.glm4 import Glm4Attention
#
#_ROPE_DICT: dict[tuple, RotaryEmbedding] = {}
#def get_rope_0_9_2(
# head_size: int,
# rotary_dim: int,
# max_position: int,
# base: float,
# is_neox_style: bool = True,
# rope_scaling: Optional[dict[str, Any]] = None,
# dtype: Optional[torch.dtype] = None,
# partial_rotary_factor: float = 1.0,
# dual_chunk_attention_config: Optional[dict[str, Any]] = None,
#) -> RotaryEmbedding:
#
# if dtype is None:
# dtype = torch.get_default_dtype()
# if rope_scaling is not None:
# # Transforms every value that is a list into a tuple for caching calls
# rope_scaling_tuple = {
# k: tuple(v) if isinstance(v, list) else v
# for k, v in rope_scaling.items()
# }
# rope_scaling_args = tuple(rope_scaling_tuple.items())
# else:
# rope_scaling_args = None
#
# if dual_chunk_attention_config is not None:
# dual_chunk_attention_tuple = {
# k: tuple(v) if isinstance(v, list) else v
# for k, v in dual_chunk_attention_config.items()
# if k != "sparse_attention_config"
# }
# dual_chunk_attention_args = tuple(dual_chunk_attention_tuple.items())
# else:
# dual_chunk_attention_args = None
#
# if partial_rotary_factor < 1.0:
# rotary_dim = int(rotary_dim * partial_rotary_factor)
# key = (head_size, rotary_dim, max_position, base, is_neox_style,
# rope_scaling_args, dual_chunk_attention_args, dtype)
# if key in _ROPE_DICT:
# return _ROPE_DICT[key]
#
# if dual_chunk_attention_config is not None:
# extra_kwargs = {
# k: v
# for k, v in dual_chunk_attention_config.items()
# if k in ("chunk_size", "local_size")
# }
# rotary_emb = DualChunkRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style, dtype,
# **extra_kwargs)
# elif not rope_scaling:
# rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
# is_neox_style, dtype)
# else:
# scaling_type = rope_scaling["rope_type"]
#
# if scaling_type == "llama3":
# scaling_factor = rope_scaling["factor"]
# low_freq_factor = rope_scaling["low_freq_factor"]
# high_freq_factor = rope_scaling["high_freq_factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style, dtype,
# scaling_factor, low_freq_factor,
# high_freq_factor,
# original_max_position)
# elif scaling_type == "mllama4":
# rotary_emb = Llama4VisionRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style, dtype)
# elif scaling_type == "default":
# if "mrope_section" in rope_scaling:
# rotary_emb = MRotaryEmbedding(
# head_size,
# rotary_dim,
# max_position,
# base,
# is_neox_style,
# dtype,
# mrope_section=rope_scaling["mrope_section"],
# )
# else:
# rotary_emb = RotaryEmbedding(
# head_size,
# rotary_dim,
# max_position,
# base,
# is_neox_style,
# dtype,
# )
# elif scaling_type == "linear":
# scaling_factor = rope_scaling["factor"]
# rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style,
# scaling_factor, dtype)
# elif scaling_type == "ntk":
# scaling_factor = rope_scaling["factor"]
# mixed_b = rope_scaling.get('mixed_b', None)
# rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
# max_position, base,
# is_neox_style,
# scaling_factor, dtype,
# mixed_b)
# elif scaling_type == "dynamic":
# if "alpha" in rope_scaling:
# scaling_alpha = rope_scaling["alpha"]
# rotary_emb = DynamicNTKAlphaRotaryEmbedding(
# head_size, rotary_dim, max_position, base, is_neox_style,
# scaling_alpha, dtype)
# elif "factor" in rope_scaling:
# scaling_factor = rope_scaling["factor"]
# rotary_emb = DynamicNTKScalingRotaryEmbedding(
# head_size, rotary_dim, max_position, base, is_neox_style,
# scaling_factor, dtype)
# else:
# raise ValueError("Dynamic rope scaling must contain either "
# "'alpha' or 'factor' field")
# elif scaling_type == "yarn":
# scaling_factor = rope_scaling["factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
# "beta_slow")
# }
# rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
# original_max_position,
# base, is_neox_style,
# scaling_factor, dtype,
# **extra_kwargs)
# elif scaling_type == "deepseek_yarn":
# scaling_factor = rope_scaling["factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# # assert max_position == original_max_position * scaling_factor
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("extrapolation_factor", "attn_factor", "beta_fast",
# "beta_slow", "mscale", "mscale_all_dim")
# }
# rotary_emb = DeepseekScalingRotaryEmbedding(
# head_size, rotary_dim, original_max_position, base,
# is_neox_style, scaling_factor, dtype, **extra_kwargs)
# elif scaling_type == "longrope":
# short_factor = rope_scaling["short_factor"]
# long_factor = rope_scaling["long_factor"]
# original_max_position = rope_scaling[
# "original_max_position_embeddings"]
# extra_kwargs = {
# k: v
# for k, v in rope_scaling.items()
# if k in ("short_mscale", "long_mscale")
# }
# rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(
# head_size, rotary_dim, max_position, original_max_position,
# base, is_neox_style, dtype, short_factor, long_factor,
# **extra_kwargs)
# else:
# raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
# _ROPE_DICT[key] = rotary_emb
# return rotary_emb
#
#
#@patch_to(vllm.model_executor.models.glm4.Glm4Attention)
#def __init__(self,
# config: Glm4Config,
# hidden_size: int,
# num_heads: int,
# num_kv_heads: int,
# max_position: int = 4096 * 32,
# head_dim: Optional[int] = None,
# qkv_bias: bool = False,
# rope_theta: float = 10000,
# cache_config: Optional[CacheConfig] = None,
# quant_config: Optional[QuantizationConfig] = None,
# rope_scaling: Optional[tuple] = None,
# prefix: str = "",
# attn_type: str = AttentionType.DECODER) -> None:
# super(Glm4Attention, self).__init__()
# self.hidden_size = hidden_size
# tp_size = get_tensor_model_parallel_world_size()
# self.total_num_heads = num_heads
# assert self.total_num_heads % tp_size == 0
# self.num_heads = self.total_num_heads // tp_size
# self.total_num_kv_heads = num_kv_heads
# if self.total_num_kv_heads >= tp_size:
# # Number of KV heads is greater than TP size, so we partition
# # the KV heads across multiple tensor parallel GPUs.
# assert self.total_num_kv_heads % tp_size == 0
# else:
# # Number of KV heads is less than TP size, so we replicate
# # the KV heads across multiple tensor parallel GPUs.
# assert tp_size % self.total_num_kv_heads == 0
# partial_rotary_factor = getattr(config, "partial_rotary_factor", 0.5)
# self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# self.head_dim = head_dim or hidden_size // self.total_num_heads
# self.rotary_dim = self.head_dim
# self.q_size = self.num_heads * self.head_dim
# self.kv_size = self.num_kv_heads * self.head_dim
# self.scaling = self.head_dim**-0.5
# self.rope_theta = rope_theta
# self.qkv_proj = QKVParallelLinear(
# hidden_size,
# self.head_dim,
# self.total_num_heads,
# self.total_num_kv_heads,
# bias=qkv_bias,
# quant_config=quant_config,
# prefix=f"{prefix}.qkv_proj",
# )
# self.o_proj = RowParallelLinear(
# self.total_num_heads * self.head_dim,
# hidden_size,
# bias=False,
# quant_config=quant_config,
# prefix=f"{prefix}.o_proj",
# )
# self.rotary_emb = get_rope_0_9_2(
# self.head_dim,
# rotary_dim=self.rotary_dim,
# max_position=max_position,
# base=self.rope_theta,
# rope_scaling=rope_scaling,
# partial_rotary_factor=partial_rotary_factor,
# is_neox_style=False,
# )
# self.attn = Attention(self.num_heads,
# self.head_dim,
# self.scaling,
# num_kv_heads=self.num_kv_heads,
# cache_config=cache_config,
# quant_config=quant_config,
# prefix=f"{prefix}.attn",
# attn_type=attn_type)

View File

@@ -0,0 +1,795 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/Glm4v/modeling_Glm4v.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The ZhipuAI Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4V model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable, Mapping
from functools import partial
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_br
from einops import rearrange, repeat
from torch_br.contrib import SueagerScaledDotProductAttention
import vllm
import vllm.model_executor.models.glm4
import vllm.model_executor.models.llama
import vllm.model_executor.models.qwen2_vl
import vllm_br.envs as br_envs
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
from vllm.distributed import utils as dist_utils
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.glm4_1v import (Glm4vForConditionalGeneration,
Glm4vVisionBlock,
Glm4vVisionMLP,
Glm4vVisionTransformer)
from vllm.model_executor.models.utils import (init_vllm_registered_model,
is_pp_missing_parameter,
maybe_prefix)
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from ..layers.activation import SiluAndMul
from ..layers.br_utils import is_br166_device
logger = init_logger(__name__)
def Glm4vVisionMLP_init_fit(self,
                            in_features: int,
                            hidden_features: int,
                            bias: bool = False,
                            quant_config: Optional[QuantizationConfig] = None,
                            prefix: str = "",
                            use_data_parallel: bool = False):
    """Replacement ``__init__`` for ``Glm4vVisionMLP`` on BR hardware.

    Builds a merged gate+up projection and a down projection.
    ``use_data_parallel`` is accepted for signature compatibility but not
    used here.
    """
    super(Glm4vVisionMLP, self).__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        input_size=in_features,
        output_sizes=[hidden_features] * 2,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj")
    self.down_proj = RowParallelLinear(hidden_features,
                                       in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj")
    # NOTE(review): act_fn is constructed but Glm4vVisionMLP_forward_fit
    # does not call it — presumably the SiLU-and-mul is fused elsewhere on
    # this backend; confirm before removing.
    self.act_fn = SiluAndMul()
def Glm4vVisionMLP_forward_fit(self, x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward`` for ``Glm4vVisionMLP`` on BR hardware.

    NOTE(review): the activation call is deliberately commented out —
    presumably the SiLU-and-mul (which also halves the gate_up output
    width back to ``hidden_features``) is fused into the projection's
    quant method on this backend; confirm, since without fusion the
    down_proj input width would not match.
    """
    x, _ = self.gate_up_proj(x)
    #x = self.act_fn(x)
    x, _ = self.down_proj(x)
    return x
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather the input tensor interleavely across model parallel group.

    Each rank holds ``tp_size`` column shards of width
    ``hidden_size // tp_size``; after the all-gather, the shards are
    re-ordered so rank-0's first shard is followed by rank-1's first shard,
    etc., reconstructing the original column layout.
    """
    import torch.distributed as dist
    shards = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(
        shards,
        local_tensor,
        group=parallel_state.get_tp_group().device_group,
    )
    chunk_width = hidden_size // tp_size
    per_rank_chunks = [torch.split(t, chunk_width, -1) for t in shards]
    # Transpose the (rank, chunk) grid so chunks interleave across ranks.
    interleaved = [t for group in zip(*per_rank_chunks) for t in group]
    return torch.cat(interleaved, dim=-1)
class Glm4vVisionAttention_fit(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ) -> None:
        """Vision attention for GLM-4V on BR hardware.

        Uses plain ``nn.Linear`` for qkv/proj (instead of the parallel
        linear layers — see note below) and a SUPA eager SDPA kernel.
        ``quant_config`` and ``prefix`` are accepted for signature
        compatibility but unused by the plain linears.
        """
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        get_tensor_model_parallel_world_size())
        self.tp_rank = (0 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_rank())
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)
        # NOTE: the original QKVParallelLinear / RowParallelLinear layers
        # were replaced with unpartitioned nn.Linear modules below — every
        # rank computes the full qkv/proj.
        qkv_output_size = (num_heads +
                           2 * num_heads) * self.hidden_size_per_attention_head
        self.qkv = nn.Linear(embed_dim, qkv_output_size, bias=False)
        self.proj = nn.Linear(projection_size, embed_dim, bias=False)
        self.sueager_attention = SueagerScaledDotProductAttention()
        # Detect attention implementation.
        self.attn_backend = get_vit_attn_backend(
            head_size=self.hidden_size_per_attention_head,
            dtype=torch.get_default_dtype())
        # self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
        self.use_upstream_fa = False
        # Prefer upstream flash-attn when the selected backend isn't FA but
        # flash-attn is importable for the current dtype.
        if self.attn_backend != _Backend.FLASH_ATTN and \
                check_upstream_fa_availability(torch.get_default_dtype()):
            self.attn_backend = _Backend.FLASH_ATTN
            self.use_upstream_fa = True
        if self.attn_backend not in {
                _Backend.FLASH_ATTN,
                _Backend.TORCH_SDPA,
                _Backend.XFORMERS,
        }:
            raise RuntimeError(
                f"GLM-4V does not support {self.attn_backend} backend now.")
def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape
if self.tp_size > 1:
qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
self.tp_size)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)
# 3 * [s, b, head * head_dim]
if self.tp_size > 1:
splitter = partial(
dist_utils.split_tensor_along_last_dim,
num_partitions=self.tp_size,
)
q = splitter(q)[self.tp_rank]
k = splitter(k)[self.tp_rank]
v = splitter(v)[self.tp_rank]
# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape = (
seq_len,
bs,
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
return q, k, v
def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
max_seqlen: Optional[int] = None, # Only used for Flash Attention
seqlens: Optional[list[int]] = None, # Only used for xFormers
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
# x, _ = self.qkv(x)
x = self.qkv(x)
# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
if rotary_pos_emb is not None:
q = glm_apply_rotary_pos_emb_vision(q, rotary_pos_emb)
k = glm_apply_rotary_pos_emb_vision(k, rotary_pos_emb)
if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
output = flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seqlen,
max_seqlen_k=max_seqlen,
dropout_p=0,
causal=False,
)
context_layer = rearrange(output,
"(b s) ... -> b s ...",
b=batch_size)
elif self.attn_backend == _Backend.TORCH_SDPA:
# Execute attention entry by entry for speed & less VRAM.
outputs = []
for i in range(1, len(cu_seqlens)):
start_idx = cu_seqlens[i - 1]
end_idx = cu_seqlens[i]
q_i = q[:, start_idx:end_idx]
k_i = k[:, start_idx:end_idx]
v_i = v[:, start_idx:end_idx]
q_i, k_i, v_i = (rearrange(x, "b s h d -> s h b d")
for x in [q_i, k_i, v_i])
output_i = torch_br.sueager_scaled_dot_product_attention_fwd(
q_i.squeeze(),
k_i.squeeze(),
v_i.squeeze(),
mask=None,
dropout_prob=0.0,
is_causal=False,
scale=1 / math.sqrt(q_i.shape[-1]),
algorithm="FMHA",
)[0]
output_i = output_i.unsqueeze(0)
if is_br166_device():
output_tmp = torch_br._empty_ut_only(output_i.shape,
"COLMAJOR",
is_numa=False,
sbp="BB",
axis=0,
dtype=torch.bfloat16)
output_tmp.copy_(output_i)
output_i = output_tmp
output_i = rearrange(output_i, "b s h d -> h b s d")
outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops
from xformers.ops.fmha.attn_bias import BlockDiagonalMask
attn_bias = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens,
kv_seqlen=None,
device=q.device)
context_layer = xops.memory_efficient_attention_forward(
q, k, v, attn_bias=attn_bias, p=0, scale=None)
context_layer = rearrange(context_layer,
"b s h d -> s b (h d)").contiguous()
# output, _ = self.proj(context_layer)
output = self.proj(context_layer)
return output
def Glm4vVisionBlock_init_fit(
    self,
    dim: int,
    num_heads: int,
    mlp_hidden_dim: int,
    norm_layer: Optional[Callable[[int], nn.Module]] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """Replacement ``Glm4vVisionBlock.__init__`` that wires in the patched
    attention module (``Glm4vVisionAttention_fit``) instead of the
    upstream TP-sharded attention."""
    super(Glm4vVisionBlock, self).__init__()
    # Default to LayerNorm with the eps used by the reference model.
    make_norm = norm_layer if norm_layer is not None else partial(
        nn.LayerNorm, eps=1e-6)
    self.norm1 = make_norm(dim)
    self.norm2 = make_norm(dim)
    self.attn = Glm4vVisionAttention_fit(
        embed_dim=dim,
        num_heads=num_heads,
        projection_size=dim,
        quant_config=quant_config,
        prefix=f"{prefix}.attn",
        use_data_parallel=use_data_parallel,
    )
    self.mlp = Glm4vVisionMLP(
        dim,
        mlp_hidden_dim,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
        use_data_parallel=use_data_parallel,
    )
def Glm4vVisionBlock_forward_fit(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    """Pre-norm transformer block: ``x + attn(norm1(x))`` followed by
    ``x + mlp(norm2(x))``.

    Patched to move ``rotary_pos_emb`` onto the current SUPA device
    before the attention call, since it may live on a different device.
    """
    normed = self.norm1(x)
    cur_device = torch.supa.current_device()
    x = x + self.attn(
        normed,
        cu_seqlens=cu_seqlens,
        rotary_pos_emb=rotary_pos_emb.to(cur_device),
        max_seqlen=max_seqlen,
        seqlens=seqlens,
    )
    x = x + self.mlp(self.norm2(x))
    return x
def Llama_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load checkpoint weights into the model's parameters.

    Handles three cases per weight: (1) stacked parameters where several
    checkpoint tensors map into one fused parameter (qkv_proj,
    gate_up_proj), (2) split parameters where one fused checkpoint tensor
    is sliced into separate gate/up parameters, and (3) plain 1:1 loads.
    Returns the set of parameter names that were loaded.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".qkv_proj", ".q_proj", "q"),
        (".qkv_proj", ".k_proj", "k"),
        (".qkv_proj", ".v_proj", "v"),
        (".gate_up_proj", ".gate_proj", 0),
        (".gate_up_proj", ".up_proj", 1),
    ]
    # One fused checkpoint tensor -> two separate parameters.
    split_params_mapping = [
        (".gate_up_proj", ".gate_proj", ".up_proj"),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        if "scale" in name:
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
        # Case (1): stacked parameters (q/k/v -> qkv_proj, gate/up ->
        # gate_up_proj).
        do_mapping_flag = False
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            do_mapping_flag = True
            loaded_params.add(name)
            break
        # Case (2): a fused gate_up checkpoint tensor split into separate
        # gate and up parameters (rows [0:gate) and [gate:end)).
        if not do_mapping_flag:
            for gate_up, gate, up in split_params_mapping:
                if gate_up not in name:
                    continue
                gate_name = name.replace(gate_up, gate)
                up_name = name.replace(gate_up, up)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param_gate = params_dict[gate_name]
                param_up = params_dict[up_name]
                assert loaded_weight.shape[0] == param_gate.shape[
                    0] + param_up.shape[0], "gate up shape is not match"
                weight_loader_gate = param_gate.weight_loader
                weight_loader_gate(param_gate, loaded_weight[
                    :param_gate.shape[0],
                ])
                weight_loader_up = param_up.weight_loader
                weight_loader_up(param_up, loaded_weight[
                    param_gate.shape[0]:,
                ])
                do_mapping_flag = True
                loaded_params.add(gate_name)
                loaded_params.add(up_name)
                break
        # Case (3): plain 1:1 parameter load.
        if not do_mapping_flag:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)
    return loaded_params
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Rotate the last dimension of ``x`` for rotary position embedding.

    With ``interleaved=False`` the last dim is split into contiguous
    halves ``(x1, x2)`` and the result is ``[-x2, x1]``.  With
    ``interleaved=True`` even/odd lanes form pairs and each pair
    ``(a, b)`` becomes ``(-b, a)``, keeping the interleaved layout.
    """
    if not interleaved:
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)
    # Pair up even/odd lanes, rotate each pair, then flatten the trailing
    # (d, 2) pair axis back into the interleaved layout.  Pure-torch
    # equivalent of einops rearrange("... d two -> ... (d two)", two=2).
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)
def glm_apply_rotary_emb_torch(x: torch.Tensor,
                               cos: torch.Tensor,
                               sin: torch.Tensor,
                               interleaved: bool = False) -> torch.Tensor:
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)

    Applies rotary embedding to the first ``rotary_dim`` features of x
    and passes the remaining features through unchanged.
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    # Expand cos/sin from rotary_dim/2 to rotary_dim and insert a
    # broadcast axis for the heads dimension. "(2 d)" tiles the values
    # (non-interleaved halves layout); "(d 2)" duplicates per-pair
    # (interleaved layout).
    cos = repeat(
        cos,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(
        sin,
        "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    # NOTE(review): the extra unsqueeze(2) adds a second broadcast axis —
    # presumably to align with a (batch, seq, heads, dim) x layout when
    # cos/sin carry a batch dimension; confirm against callers.
    cos = cos.unsqueeze(2)
    sin = sin.unsqueeze(2)
    res = torch.cat(
        [
            x[..., :ro_dim] * cos +
            rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]
        ],
        dim=-1,
    )
    return res
def glm_apply_rotary_pos_emb_vision(t: torch.Tensor,
                                    freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary position embedding to ``t`` from raw angles ``freqs``.

    Rotation is performed in float32 for accuracy and the result is cast
    back to the dtype of ``t``.  On CUDA the fused flash-attn rotary
    kernel is used; elsewhere the pure-torch fallback applies.
    """
    rotate = glm_apply_rotary_emb_torch
    if current_platform.is_cuda():
        from vllm.vllm_flash_attn.layers.rotary import (
            apply_rotary_emb as rotate)
    return rotate(t.float(), freqs.cos(), freqs.sin()).type_as(t)
def LlamaMLP_glm4_1v_forward(self, x):
    """Patched ``LlamaMLP.forward`` used for GLM-4.1V on this backend.

    NOTE(review): no activation function is applied between
    ``gate_up_proj`` and ``down_proj`` here — presumably the SiLU/gating
    is fused into one of the projection kernels on this backend; confirm
    against the SUPA fused-MLP op.
    """
    x, _ = self.gate_up_proj(x)
    x, _ = self.down_proj(x)
    return x
def Glm4Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
):
    """Patched ``Glm4Attention.forward`` with BR166 layout conversions.

    Computes qkv projection, applies rotary embedding, runs attention and
    the output projection.  On BR166 devices q/k are staged through
    explicitly allocated COLMAJOR tensors with specific sbp layouts
    before and after the rotary embedding.
    """
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    if is_br166_device():
        # NOTE(review): q/k are copied first into "SB" (axis=2) layout
        # tensors and then immediately into "BB" (axis=0) layout tensors
        # before rotary_emb — the first copy appears to be an
        # intermediate layout step required by the runtime; confirm
        # whether both copies are necessary.
        q_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.q_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        k_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.kv_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        q_tmp.copy_(q)
        k_tmp.copy_(k)
        q = q_tmp
        k = k_tmp
        q_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.q_size),
            "COLMAJOR",
            is_numa=False,
            sbp="BB",
            axis=0,
            dtype=torch.bfloat16)
        k_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.kv_size),
            "COLMAJOR",
            is_numa=False,
            sbp="BB",
            axis=0,
            dtype=torch.bfloat16)
        q_tmp.copy_(q)
        k_tmp.copy_(k)
        q = q_tmp
        k = k_tmp
    q, k = self.rotary_emb(positions, q, k)
    if is_br166_device():
        # Convert q/k back to the "SB" (axis=2) layout expected by the
        # attention kernel.
        q_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.q_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        k_tmp = torch_br._empty_ut_only(
            (qkv.shape[0], qkv.shape[1], self.kv_size),
            "COLMAJOR",
            is_numa=False,
            sbp="SB",
            axis=2,
            dtype=torch.bfloat16)
        q_tmp.copy_(q)
        k_tmp.copy_(k)
        q = q_tmp
        k = k_tmp
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output
def get_mm_max_tokens_per_item(
    self,
    seq_len: int,
    mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
    """Return the per-modality maximum token counts for one item.

    Image budget comes from ``get_max_image_tokens``; video budget is the
    token count for a single frame at the feature-richest image size.
    ``seq_len`` and ``mm_counts`` are part of the interface but unused.
    """
    image_budget = self.get_max_image_tokens()
    best_width, best_height = self.get_image_size_with_most_features()
    video_budget = self.get_num_video_tokens(
        image_width=best_width,
        image_height=best_height,
        num_frames=1,
    )
    return {"image": image_budget, "video": video_budget}
def glm4v_init(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Patched ``Glm4vForConditionalGeneration.__init__``.

    Builds the vision transformer and the registered text backbone
    (dense GLM-4V or GLM-4V-MoE, selected by ``config.model_type``) and
    enables the backend's mrope-0.9.2 compatibility flag.
    """
    super(Glm4vForConditionalGeneration, self).__init__()
    config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config
    multimodal_config = vllm_config.model_config.multimodal_config
    self.config = config
    self.multimodal_config = multimodal_config

    # Vision encoder runs data-parallel when the mm encoder TP mode
    # is "data" (each rank holds the full encoder).
    self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

    self.visual = Glm4vVisionTransformer(
        config.vision_config,
        norm_eps=getattr(config, "rms_norm_eps", 1e-5),
        quant_config=quant_config,
        prefix=maybe_prefix(prefix, "visual"),
        use_data_parallel=self.use_data_parallel,
    )

    # Pick the language-model architecture matching the checkpoint type.
    if config.model_type == "glm4v":
        architectures = ["Glm4ForCausalLM"]
    elif config.model_type == "glm4v_moe":
        architectures = ["Glm4MoeForCausalLM"]
    else:
        architectures = None

    self.language_model = init_vllm_registered_model(
        vllm_config=vllm_config,
        hf_config=config.text_config,
        prefix=maybe_prefix(prefix, "language_model"),
        architectures=architectures)

    self.make_empty_intermediate_tensors = (
        self.language_model.make_empty_intermediate_tensors)
    # GLM-4V uses the 0.9.2-style mrope position handling on this backend.
    br_envs.VLLM_BR_USE_MROPE_0_9_2 = True
def Glm4vPatchMerger_forward(self, x: torch.Tensor):
    """Patched ``Glm4vPatchMerger.forward``.

    proj -> (BR166 layout copy) -> post-norm + extra activation ->
    gate_up projection -> down projection.
    """
    x, _ = self.proj(x)
    if is_br166_device():
        # Stage through a COLMAJOR "BB" layout tensor required by the
        # BR166 runtime.
        output_tmp = torch_br._empty_ut_only(x.shape,
                                             "COLMAJOR",
                                             is_numa=False,
                                             sbp="BB",
                                             axis=0,
                                             dtype=torch.bfloat16)
        output_tmp.copy_(x)
        x = output_tmp
    x = self.extra_activation_func(self.post_projection_norm(x))
    gate_up, _ = self.gate_up_proj(x)
    # NOTE(review): act_fn is intentionally skipped here — presumably
    # fused into the gate_up_proj kernel on this backend; confirm.
    # x = self.act_fn(gate_up)
    x = gate_up
    x, _ = self.down_proj(x)
    return x
def Glm4vVisionEmbeddings_forward(self, embeddings, lengths, image_shapes,
                                  h_coords, w_coords) -> torch.Tensor:
    """Add resolution-adapted 2D position embeddings to patch embeddings.

    The learned position-embedding table (a square grid) is resampled via
    bicubic ``grid_sample`` at each patch's normalized (h, w) coordinate,
    so images of arbitrary resolution get a position encoding consistent
    with the pretraining grid.

    Args:
        embeddings: flattened patch embeddings, one row per patch.
        lengths: number of patches per image.
        image_shapes: per-image (t, h, w) grid sizes; columns 1 and 2 are
            the per-image target height/width used for normalization.
        h_coords, w_coords: per-patch integer grid coordinates.
    """
    pos_embed_weight = self.position_embedding.weight
    hidden_size = pos_embed_weight.shape[1]
    total_seq = h_coords.shape[0]
    device = pos_embed_weight.device

    # Move coordinates to correct device
    h_coords, w_coords = h_coords.to(device), w_coords.to(device)

    # Handle empty sequence case
    if total_seq == 0:
        adapted_pos_embed = torch.empty(0,
                                        hidden_size,
                                        device=device,
                                        dtype=pos_embed_weight.dtype)
    else:
        # Convert inputs to tensors if needed
        if isinstance(lengths, list):
            lengths = torch.tensor(lengths, device=device, dtype=torch.long)
        if not isinstance(image_shapes, torch.Tensor):
            image_shapes = torch.tensor(image_shapes,
                                        device=device,
                                        dtype=torch.long)

        # Prepare 2D position embedding: reshape the flat table into a
        # (1, hidden, grid, grid) image for grid_sample.
        orig_size_sq = pos_embed_weight.shape[0]
        orig_size = int(orig_size_sq**0.5)
        pos_embed_2d = (pos_embed_weight.view(orig_size,
                                              orig_size, hidden_size).permute(
                                                  2, 0, 1).unsqueeze(0))
        pos_embed_2d = pos_embed_2d.to(torch.float32)

        # Calculate target dimensions for each patch
        # Add bounds checking for data parallel mode
        if len(lengths) > image_shapes.shape[0]:
            # In data parallel mode, some GPUs might not have all
            # image shapes
            # Use available image shapes, cycling if necessary
            target_h_list = []
            target_w_list = []
            for i in range(len(lengths)):
                # Cycle through available shapes
                shape_idx = i % image_shapes.shape[0]
                target_h_list.append(image_shapes[shape_idx,
                                                  1].repeat(lengths[i]))
                target_w_list.append(image_shapes[shape_idx,
                                                  2].repeat(lengths[i]))
            target_h = torch.cat(target_h_list).to(device=device,
                                                   dtype=torch.float32)
            target_w = torch.cat(target_w_list).to(device=device,
                                                   dtype=torch.float32)
        else:
            target_h = torch.cat([
                image_shapes[i, 1].repeat(lengths[i])
                for i in range(len(lengths))
            ]).to(device=device, dtype=torch.float32)
            target_w = torch.cat([
                image_shapes[i, 2].repeat(lengths[i])
            for i in range(len(lengths))
            ]).to(device=device, dtype=torch.float32)

        # Normalize coordinates to [-1, 1] range for grid_sample
        h_coords = h_coords.to(device=device, dtype=torch.float32)
        w_coords = w_coords.to(device=device, dtype=torch.float32)
        norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
        norm_h = ((h_coords + 0.5) / target_h) * 2 - 1

        # Create sampling grid
        grid = (torch.stack((norm_w, norm_h),
                            dim=-1).unsqueeze(0).unsqueeze(2))

        # Perform bicubic interpolation
        interpolated_embed_fp32 = F.grid_sample(
            pos_embed_2d,
            grid,
            mode="bicubic",
            align_corners=False,
            padding_mode="border",
        )

        # Reshape and convert back to original dtype
        adapted_pos_embed_fp32 = (
            interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0))
        adapted_pos_embed = adapted_pos_embed_fp32.to(
            pos_embed_weight.dtype).to(embeddings.device)

    # Add adapted position encoding to embeddings
    embeddings = embeddings + adapted_pos_embed
    return embeddings
#LlamaModel.load_weights = Llama_load_weights
# Monkey-patch upstream vLLM GLM-4.1V / GLM-4 / Llama classes with the
# backend-specific implementations defined above.
vllm.model_executor.models.llama.LlamaMLP.forward = LlamaMLP_glm4_1v_forward
vllm.model_executor.models.glm4.Glm4Attention.forward = Glm4Attention_forward
#vllm.model_executor.models.glm4_1v.Glm4vVisionAttention = Glm4vVisionAttention_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionBlock.__init__ = Glm4vVisionBlock_init_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionBlock.forward = Glm4vVisionBlock_forward_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionMLP.forward = Glm4vVisionMLP_forward_fit
vllm.model_executor.models.glm4_1v.Glm4vVisionMLP.__init__ = Glm4vVisionMLP_init_fit
vllm.model_executor.models.glm4_1v.Glm4vProcessingInfo.get_mm_max_tokens_per_item = get_mm_max_tokens_per_item
vllm.model_executor.models.glm4_1v.Glm4vForConditionalGeneration.__init__ = glm4v_init
vllm.model_executor.models.glm4_1v.Glm4vPatchMerger.forward = Glm4vPatchMerger_forward
vllm.model_executor.models.glm4_1v.Glm4vVisionEmbeddings.forward = Glm4vVisionEmbeddings_forward

View File

@@ -0,0 +1,475 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright 2025 The ZhipuAI Team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only GLM-4.5 model compatible with HuggingFace weights."""
import typing
from collections.abc import Callable, Iterable
from typing import Optional, Union
import torch
import torch_br
from torch import nn
from transformers.models.glm4_moe import Glm4MoeConfig
import vllm
import vllm.model_executor.models.glm4_moe
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed import (get_ep_group, get_pp_group,
get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.glm4_moe import (
Glm4MoeAttention, Glm4MoeDecoderLayer, get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_br.v1.attention.backends.attention_v1 import (
SUPAFlashAttentionMetadata)
from .supa_module import MergedGateUpMLPSiluL2
class Glm4MoE(nn.Module):
    """GLM-4 MoE layer patched for the SUPA backend.

    The router gate is fused with the shared experts: ``forward`` passes
    the gate weight together with the shared-expert weights into the
    fused-MoE kernel instead of calling the gate separately.
    """

    def __init__(
        self,
        config: Glm4MoeConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        enable_eplb: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor

        self.ep_group = get_ep_group().device_group
        self.ep_rank = self.ep_group.rank()
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        # Router gate kept in float32 for routing-score precision.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     params_dtype=torch.float32,
                                     prefix=f"{prefix}.gate")

        self.gate.e_score_correction_bias = nn.Parameter(
            torch.empty(config.n_routed_experts, dtype=torch.float32))

        # Load balancing settings.
        vllm_config = get_current_vllm_config()
        eplb_config = vllm_config.parallel_config.eplb_config
        self.enable_eplb = enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        self.physical_expert_start = (self.ep_rank *
                                      self.n_local_physical_experts)
        self.physical_expert_end = (self.physical_expert_start +
                                    self.n_local_physical_experts)

        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func="sigmoid",
            # we do scaling outside, set factor to 1.0 to avoid double mul
            routed_scaling_factor=1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = MergedGateUpMLPSiluL2(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # reduce_results=self.experts.must_reduce_shared_expert_outputs(
                # ),
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the fused MoE + shared experts kernel."""
        orig_shape = hidden_states.shape
        assert self.n_shared_experts is not None, 'n_shared_experts must be set'
        # NOTE: gate has been fused with shared_experts, no more single gate call
        # and we packed router weights, shared_experts weights and down weights in a tuple
        tuple_router_shared_expert_weight = (
            self.gate.weight, self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight)
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=tuple_router_shared_expert_weight)
        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            final_hidden_states = (
                self.experts.maybe_all_reduce_tensor_model_parallel(
                    final_hidden_states))

        return final_hidden_states.view(orig_shape)


# Replace upstream Glm4MoE with the fused-gate implementation above.
vllm.model_executor.models.glm4_moe.Glm4MoE = Glm4MoE
def Glm4MoeAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Patched ``Glm4MoeAttention.forward`` using fused SUPA kernels.

    Short sequences (<= 512 tokens, i.e. decode-like) use the fused
    prefix-attention kernel that combines qkv projection, qk-norm, rope
    and KV-cache write in one op; longer (prefill) sequences use a fused
    split+rmsnorm+rope kernel after a regular qkv projection.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        ## for dummy run
        return hidden_states
    seq_len = hidden_states.shape[-2]
    # NOTE(review): threshold separating the decode-like fused path from
    # the prefill path — presumably tuned for this hardware; confirm.
    decode_seql = 512
    if seq_len <= decode_seql:
        if isinstance(attn_metadata, dict):
            attn_metadata = attn_metadata[self.attn.layer_name]
        kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
        if kv_cache is not None:
            # Pick up quantized weights when the qkv projection is
            # quantized (GPTQ-style qweight or compressed weight_packed).
            if hasattr(self.qkv_proj, "qweight"):
                qkv_weight = self.qkv_proj.qweight.data
                qkv_scales = self.qkv_proj.scales.data
            elif hasattr(self.qkv_proj, "weight_packed"):
                qkv_weight = self.qkv_proj.weight_packed.data
                qkv_scales = self.qkv_proj.weight_scale.data
            else:
                qkv_weight = self.qkv_proj.weight
                qkv_scales = None
            q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                hidden_states,
                qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                self.head_dim,
                self.q_norm.variance_epsilon,
                self.q_norm.weight,
                self.k_norm.weight,
                self.rotary_emb.sin_cache,
                self.rotary_emb.cos_cache,
                kv_cache,
                positions,
                attn_metadata.slot_mapping,
                rotary_dim=self.rotary_emb.rotary_dim,
                bias=self.qkv_proj.bias,
                scales=qkv_scales)
            if hasattr(attn_metadata, 'do_cache'):
                # The fused kernel already wrote the KV cache.
                attn_metadata.do_cache = False
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output
        else:
            # NOTE(review): no kv_cache bound for this virtual engine —
            # passes hidden_states through unchanged; confirm intended.
            return hidden_states
    else:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = torch_br.br_fused_split_rms_rope_infer(
            qkv, [self.q_size, self.kv_size, self.kv_size],
            self.head_dim,
            self.q_norm.variance_epsilon,
            self.q_norm.weight,
            self.k_norm.weight,
            self.rotary_emb.sin_cache,
            self.rotary_emb.cos_cache,
            positions,
            rotary_dim=self.rotary_emb.rotary_dim)
        if hasattr(attn_metadata, 'do_cache'):
            # Prefill path: the attention op must still write the cache.
            attn_metadata.do_cache = True
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


vllm.model_executor.models.glm4_moe.Glm4MoeAttention.forward = Glm4MoeAttention_forward
def Glm4MoeDecoderLayer__init__(
    self,
    config: Glm4MoeConfig,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    enable_eplb: bool = False,
) -> None:
    """Patched ``Glm4MoeDecoderLayer.__init__``.

    Identical to upstream except the dense-MLP layers (the first
    ``first_k_dense_replace`` layers) use the backend's fused
    ``MergedGateUpMLPSiluL2`` instead of the stock MLP.
    """
    super(Glm4MoeDecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    max_position_embeddings = getattr(config, "max_position_embeddings",
                                      131072)
    # DecoderLayers are created with `make_layers` which passes the prefix
    # with the layer's index.
    layer_idx = int(prefix.split(sep='.')[-1])
    self.layer_idx = layer_idx
    self.self_attn = Glm4MoeAttention(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=rope_theta,
        rope_scaling=rope_scaling,
        max_position_embeddings=max_position_embeddings,
        head_dim=config.head_dim,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=config.attention_bias,
        cache_config=cache_config,
        quant_config=quant_config,
        prefix=f"{prefix}.self_attn",
        use_qk_norm=config.use_qk_norm,
    )

    # MoE layers start at first_k_dense_replace; earlier layers stay dense.
    if (config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace):
        self.mlp = Glm4MoE(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
            enable_eplb=enable_eplb,
        )
    else:
        self.mlp = MergedGateUpMLPSiluL2(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp")
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
    self.routed_scaling_factor = config.routed_scaling_factor


vllm.model_executor.models.glm4_moe.Glm4MoeDecoderLayer.__init__ = Glm4MoeDecoderLayer__init__
def Glm4MoeModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched ``Glm4MoeModel.forward``.

    Same as upstream except tensors are carried in 3D (leading batch
    dim of 1) between pipeline stages because the SUPA kernels expect 3D
    input; they are squeezed back to 2D at stage boundaries.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        residual = residual.unsqueeze(0)  # NOTE: SUPA wants 3D input
        hidden_states = hidden_states.unsqueeze(0)

    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        hidden_states, residual = layer(positions, hidden_states, residual)

    if not get_pp_group().is_last_rank:
        # Strip the artificial batch dim before handing off to the next
        # pipeline stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })

    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)


vllm.model_executor.models.glm4_moe.Glm4MoeModel.forward = Glm4MoeModel_forward
def Glm4MoeModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Patched ``Glm4MoeModel.load_weights``.

    Upstream loading logic plus backend-specific post-processing: norm
    weights and e_score_correction_bias are promoted to float32, gate
    weights cast to bfloat16, and the SUPA allocator cache is flushed
    after each converted load.  Returns the set of loaded param names.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    expert_params_mapping = self.get_expert_mapping()
    for name, loaded_weight in weights:
        # Skip MTP/speculative-decoding layers; they are loaded elsewhere.
        spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
        if spec_layer is not None:
            continue

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if (("mlp.experts." in name) and name not in params_dict):
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            # Backend requires float32 for norm weights and routing bias.
            if name.find("norm.weight") != -1 or name.find(
                    "e_score_correction_bias") != -1:
                param.data = param.data.to(torch.float32)
            torch.supa.empty_cache()
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True
                # Do not modify `name` since the loop may continue here
                # Instead, create a new variable
                name_mapped = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name_mapped, self):
                    continue
                if name_mapped not in params_dict:
                    continue
                param = params_dict[name_mapped]
                # We should ask the weight loader to return success or not
                # here since otherwise we may skip experts with other
                # available replicas.
                weight_loader = typing.cast(Callable[..., bool],
                                            param.weight_loader)
                success = weight_loader(param,
                                        loaded_weight,
                                        name_mapped,
                                        shard_id=shard_id,
                                        expert_id=expert_id,
                                        return_success=True)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                torch.supa.empty_cache()
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                # weight layout infer
                if name.find("norm.weight") != -1 or name.find(
                        "e_score_correction_bias") != -1:
                    param.data = param.data.to(torch.float32)
                # Gate weight feeds the fused MoE kernel in bfloat16.
                if name.find("gate.weight") != -1:
                    param.data = param.data.to(torch.bfloat16)
                torch.supa.empty_cache()
        loaded_params.add(name)
    return loaded_params


vllm.model_executor.models.glm4_moe.Glm4MoeModel.load_weights = Glm4MoeModel_load_weights

View File

@@ -0,0 +1,358 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from collections.abc import Iterable
from typing import Optional
import torch
import torch.distributed as dist
import torch_br
from torch import nn
from transformers import GptOssConfig
import vllm
import vllm.model_executor.models.gpt_oss
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (extract_layer_index,
is_pp_missing_parameter)
from vllm.sequence import IntermediateTensors
from vllm.utils import cdiv
from vllm_br import envs
class OAIAttention(nn.Module):
    """GPT-OSS attention block patched for Biren devices.

    Replaces vllm's upstream ``gpt_oss.OAIAttention`` (installed at the
    bottom of this block): fused QKV projection, YaRN rotary embedding,
    per-head attention "sinks", and sliding-window attention on every
    other layer.
    """

    def __init__(
        self,
        config: GptOssConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.layer_idx = extract_layer_index(prefix)
        self.head_dim = config.head_dim
        self.num_attention_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        # YaRN rope; scaling parameters are taken directly from the HF
        # config's rope_scaling dict (KeyError here means a non-YaRN config).
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=config.max_position_embeddings,
            base=config.rope_theta,
            dtype=torch.float32,
            rope_scaling={
                "rope_type":
                "yarn",
                "factor":
                config.rope_scaling["factor"],
                "original_max_position_embeddings":
                config.rope_scaling["original_max_position_embeddings"],
                "beta_fast":
                config.rope_scaling["beta_fast"],
                "beta_slow":
                config.rope_scaling["beta_slow"],
            },
            is_neox_style=True,
        )
        tp_size = get_tensor_model_parallel_world_size()
        # One fp32 sink scalar per *local* attention head; handed to the
        # Attention backend below via the `sinks` kwarg.
        attention_sink_dtype = torch.float32
        self.sinks = torch.nn.Parameter(
            torch.empty(config.num_attention_heads // tp_size,
                        dtype=attention_sink_dtype,
                        requires_grad=False))
        # Per-rank widths of the q / k / v slices of the fused QKV output.
        self.q_size = self.num_attention_heads * self.head_dim // tp_size
        self.kv_size = self.num_key_value_heads * self.head_dim // tp_size
        self.scaling = self.head_dim**-0.5
        self.rope_theta = config.rope_theta
        self.qkv = QKVParallelLinear(
            hidden_size=self.hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.num_attention_heads,
            total_num_kv_heads=self.num_key_value_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            input_size=self.num_attention_heads * self.head_dim,
            output_size=self.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.num_local_attention_heads = config.num_attention_heads // tp_size
        self.num_local_key_value_heads = config.num_key_value_heads // tp_size
        # Only apply sliding window to every other layer
        sliding_window = (config.sliding_window if self.layer_idx %
                          2 == 0 else None)
        self.attn = Attention(
            self.num_local_attention_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_local_key_value_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            per_layer_sliding_window=sliding_window,
            attn_type=AttentionType.DECODER,
            prefix=f"{prefix}.attn",
            sinks=self.sinks,
        )

    def forward(self, hidden_states: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        """Fused QKV -> split -> rope -> attention -> output projection."""
        qkv, _ = self.qkv(hidden_states)
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # br166 devices (>16 SPCs): vendor split that also infers the
            # SBP layout of the shards (see the br166 NOTE in the llama
            # decoder layer of this package).
            q, k, v = torch_br.split_w_sbp_infer(
                qkv, [self.q_size, self.kv_size, self.kv_size])
        else:
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        v = v.contiguous()
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


# Install the patched class over the upstream implementation.
vllm.model_executor.models.gpt_oss.OAIAttention = OAIAttention
class MLPBlock(torch.nn.Module):
    """GPT-OSS MoE MLP block patched for Biren devices.

    Wraps a fused-MoE kernel plus a bf16 router linear; installed over
    vllm's upstream ``gpt_oss.MLPBlock`` at the bottom of this block.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        layer_idx: int,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
        self.layer_idx = layer_idx
        self.num_experts = config.num_local_experts
        self.experts_per_token = config.num_experts_per_tok
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        # Router kept in bf16 regardless of the model dtype.
        self.router = torch.nn.Linear(config.hidden_size,
                                      config.num_local_experts,
                                      dtype=torch.bfloat16)
        assert config.intermediate_size % self.world_size == 0
        self.experts = FusedMoE(num_experts=config.num_local_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.intermediate_size,
                                reduce_results=True,
                                renormalize=True,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts",
                                apply_router_weight_on_input=False,
                                has_bias=True,
                                activation="swigluoai",
                                is_sequence_parallel=self.is_sequence_parallel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the fused MoE on a [1, tokens, hidden] input (squeezed)."""
        # NOTE(review): the router *weight* is passed as router_logits —
        # presumably the fused kernel computes the logits from it
        # internally; confirm against the FusedMoE implementation.
        final_hidden_states = self.experts(hidden_states=x.squeeze(0),
                                           router_logits=self.router.weight)
        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states


# Install the patched class over the upstream implementation.
vllm.model_executor.models.gpt_oss.MLPBlock = MLPBlock
def GptOssModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Patched GptOssModel.forward.

    Runs the decoder layers on a [1, tokens, hidden] view of the batch
    (unsqueezed here, squeezed again before returning / handing off to
    the next pipeline-parallel stage). Also captures auxiliary hidden
    states for the layer indices in ``self.aux_hidden_state_layers``.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            x = inputs_embeds
        else:
            x = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        x = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        residual = residual.unsqueeze(0)
    # Layers operate on a leading batch dim of 1.
    x = x.unsqueeze(0)
    aux_hidden_states = []
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        if i in self.aux_hidden_state_layers:
            # Capture the pre-layer hidden state (residual folded back in).
            aux_hidden_states.append(x if residual is None else x + residual)
        x, residual = layer(x, positions, residual)
    if not get_pp_group().is_last_rank:
        # Hand off to the next pipeline stage with the batch dim stripped.
        return IntermediateTensors({
            "hidden_states":
            x.squeeze(0),
            "residual":
            residual.squeeze(0) if residual is not None else None,
        })
    x, _ = self.norm(x, residual)
    if len(aux_hidden_states) > 0:
        # NOTE(review): this path returns x without squeeze(0), unlike the
        # plain return below — confirm the aux-hidden-state consumer
        # expects the extra batch dim.
        return x, aux_hidden_states
    return x.squeeze(0)


# Install the patched forward over the upstream implementation.
vllm.model_executor.models.gpt_oss.GptOssModel.forward = GptOssModel_forward
def GptOssModel_load_weights_other(
    self,
    ep_rank_end: int,
    ep_rank_start: int,
    heads_per_rank: int,
    head_start: int,
    weights: Iterable[tuple[str, torch.Tensor]],
    stacked_params_mapping: list[tuple[str, ...]],
) -> set[str]:
    """Patched GptOssModel._load_weights_other.

    Loads non-stacked GPT-OSS checkpoint tensors, sharding MoE weights by
    expert-parallel rank (``ep_rank_start:ep_rank_end``) or by the
    tensor-parallel intermediate-size slice, and attention sinks by the
    head slice ``head_start:head_start + heads_per_rank``.

    Returns the set of parameter names that were loaded on this rank.
    """
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    use_ep = self.parallel_config.enable_expert_parallel
    tp_rank = get_tensor_model_parallel_rank()
    tp_size = get_tensor_model_parallel_world_size()
    intermediate_size = self.config.intermediate_size
    per_rank_intermediate_size = cdiv(intermediate_size, tp_size)
    # Calculate common slicing bounds for current rank
    tp_rank_start = tp_rank * per_rank_intermediate_size
    tp_rank_end = min((tp_rank + 1) * per_rank_intermediate_size,
                      intermediate_size)
    for name, weight in weights:
        # Skip layers on other devices.
        if is_pp_missing_parameter(name, self):
            continue
        if ".w13_weight" in name:
            # Handle MLP gate and up projection weights
            # Extract gate and up projection parts
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # Gate and up columns are interleaved, hence the 2x range.
                narrow_weight = weight[:, :, 2 * tp_rank_start:2 * tp_rank_end]
            # Permute to (experts, cols, rows) — assumed to be the layout
            # the fused-MoE kernel expects; confirm against FusedMoE.
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_weight" in name:
            # Handle MLP down projection weights
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, tp_rank_start:tp_rank_end, :]
            narrow_weight = narrow_weight.permute(0, 2, 1).contiguous()
            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w13_bias" in name:
            # Handle MLP gate and up projection biases
            # Extract gate and up projection bias parts
            if use_ep:
                narrow_weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                narrow_weight = weight[:, 2 * tp_rank_start:2 * tp_rank_end]
            param = params_dict[name]
            param.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        elif ".w2_bias" in name:
            # Handle MLP down projection bias
            if use_ep:
                weight = weight[ep_rank_start:ep_rank_end, ...]
            else:
                # (only load on rank 0 to avoid duplication)
                if tp_rank != 0:
                    weight.zero_()
            param = params_dict[name]
            param.copy_(weight)
            loaded_params.add(name)
            continue
        elif "sinks" in name:
            # Handle attention sinks (distributed across ranks)
            param = params_dict[name]
            narrow_weight = weight.narrow(0, head_start, heads_per_rank)
            param.data.copy_(narrow_weight)
            loaded_params.add(name)
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            # Custom loaders take the shard id; the default one does not.
            if weight_loader == default_weight_loader:
                weight_loader(param, weight)
            else:
                weight_loader(param, weight, shard_id)
            break
        else:
            # Handle all other weights with potential renaming
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, weight)
        loaded_params.add(name)
    return loaded_params


# Install the patched loader over the upstream implementation.
vllm.model_executor.models.gpt_oss.GptOssModel._load_weights_other = GptOssModel_load_weights_other

View File

@@ -0,0 +1,242 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional
import torch
import torch_br
from fastcore.basics import patch_to
from transformers import PretrainedConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
# isort: off
from vllm.model_executor.models.intern_vit import (InternMLP,
InternVisionEmbeddings,
InternVisionModel,
InternVisionEncoder)
from vllm.model_executor.models.intern_vit import InternParallelAttention
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size
from vllm.distributed.utils import divide
from vllm.model_executor.layers.layernorm import RMSNorm
# isort: on
@patch_to(InternVisionModel)
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_hidden_layers_override: Optional[int] = None,
    num_dummy_heads: int = 0,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """
    [Patch] enable data parallelism for InternVisionModel

    Builds the embeddings and the encoder, threading the
    ``use_data_parallel`` flag through to the encoder layers.
    """
    super(InternVisionModel, self).__init__()
    self.config = config
    self.use_data_parallel = use_data_parallel
    self.embeddings = InternVisionEmbeddings(config)
    # NOTE(review): the encoder is built with quant_config=None, so the
    # quant_config argument is accepted but never applied — confirm this
    # is intentional (vision tower kept unquantized).
    self.encoder = InternVisionEncoder(
        config=config,
        quant_config=None,
        num_hidden_layers_override=num_hidden_layers_override,
        num_dummy_heads=num_dummy_heads,
        prefix=f"{prefix}.encoder",
        use_data_parallel=use_data_parallel,
    )
@patch_to(InternVisionEmbeddings)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Patched patch-embedding forward.

    For patch_size == 14 the convolution is routed through a vendor
    (supa) conv kernel with several zero-handling debug switches toggled
    around the call; otherwise the stock nn.Conv2d patch embedding is
    used. The class token is prepended and position embeddings added.
    """
    target_dtype = self.patch_embedding.weight.dtype
    if self.patch_size == 14:
        import torch_br.supa._debug as supa_debug
        # Temporarily enable zero workspaces/outputs for the vendor conv,
        # then restore the disabled state afterwards.
        supa_debug.set_disable_zero_ws(False)
        supa_debug.set_disable_zero_output_uma(False)
        supa_debug.set_disable_zero_output_numa(False)
        supa_debug.set_disable_reorder_zero(False)
        patch_embeds = torch_br.supa_conv2d_knxn_snxn_p0x0_fwd(
            pixel_values.to(dtype=target_dtype), self.patch_embedding.weight,
            self.patch_size, self.patch_size, 0)
        if self.patch_embedding.bias is not None:
            # Vendor conv does not apply bias; add it manually.
            patch_embeds += self.patch_embedding.bias[None, :, None, None]
        supa_debug.set_disable_zero_ws(True)
        supa_debug.set_disable_zero_output_uma(True)
        supa_debug.set_disable_zero_output_numa(True)
        supa_debug.set_disable_reorder_zero(True)
    else:
        patch_embeds = self.patch_embedding(pixel_values.to(
            target_dtype))  # shape = [*, channel, width, height]
    batch_size, _, height, width = patch_embeds.shape
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
    class_embeds = self.class_embedding.expand(batch_size, 1,
                                               -1).to(target_dtype)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    # NOTE(review): the position-embedding path is keyed off whether the
    # patch embedding has a bias — presumably a proxy for the model
    # variant; confirm the intended discriminator.
    if self.patch_embedding.bias is None:
        position_embedding = self._get_position_embedding(height, width)
    else:
        position_embedding = torch.cat([
            self.position_embedding[:, :1, :],
            self._get_pos_embed(self.position_embedding[:, 1:, :], height,
                                width)
        ],
                                       dim=1)
    embeddings = embeddings + position_embedding.to(target_dtype)
    return embeddings
@patch_to(InternParallelAttention)
def __init__(
    self,
    config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig] = None,
    *,
    num_dummy_heads: int = 0,
    prefix: str = "",
    use_data_parallel: bool = False,
) -> None:
    """Patched InternParallelAttention constructor.

    Forces data-parallel mode (tensor parallelism is asserted off) and
    replaces the parallel QKV/output linears with plain nn.Linear layers;
    the original parallel layers are kept below as commented-out code.
    """
    super(InternParallelAttention, self).__init__()
    # [Patch] enable data parallelism — unconditionally, regardless of the
    # use_data_parallel argument.
    self.use_data_parallel = True
    self.config = config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads
    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(f'embed_dim must be divisible by num_heads '
                         f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
                         f' {self.num_heads}).')
    self.tp_size = (1 if use_data_parallel else
                    get_tensor_model_parallel_world_size())
    self.tp_rank = (0
                    if use_data_parallel else get_tensor_model_parallel_rank())
    # Additional dummy heads are used to enable TP for common GPU counts.
    self.dummy_dim = (num_dummy_heads + self.num_heads) * self.head_dim
    self.num_heads_per_partition = divide(num_dummy_heads + self.num_heads,
                                          self.tp_size)
    # This patch only supports the non-TP configuration.
    assert self.tp_size == 1
    self.scale = self.head_dim**-0.5
    # self.qkv = QKVParallelLinear(
    #     self.embed_dim,
    #     self.head_dim,
    #     num_dummy_heads + self.num_heads,
    #     bias=config.qkv_bias,
    #     quant_config=quant_config,
    #     prefix=f"{prefix}.qkv",
    #     disable_tp=use_data_parallel,
    # )
    self.qkv = torch.nn.Linear(self.embed_dim,
                               3 * self.dummy_dim,
                               bias=config.qkv_bias)
    self.qk_normalization = config.qk_normalization
    if self.qk_normalization:
        self.q_norm = RMSNorm(self.dummy_dim,
                              eps=config.layer_norm_eps,
                              var_hidden_size=self.embed_dim)
        self.k_norm = RMSNorm(self.dummy_dim,
                              eps=config.layer_norm_eps,
                              var_hidden_size=self.embed_dim)
    # self.proj = RowParallelLinear(
    #     self.dummy_dim,
    #     self.embed_dim,
    #     quant_config=quant_config,
    #     prefix=f"{prefix}.proj",
    #     disable_tp=use_data_parallel,
    # )
    self.proj = torch.nn.Linear(self.dummy_dim, self.embed_dim)
    # self.attn = MultiHeadAttention(self.num_heads_per_partition,
    #                                self.head_dim, self.scale)
@patch_to(InternParallelAttention)
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Patched attention forward: naive per-sample softmax attention.

    Processes each batch element separately and splits the attn @ v
    matmul at key index 512 — presumably a workaround for a vendor
    kernel/shape limit (slicing past the end yields an empty, zero
    contribution, so short sequences are unaffected).
    """
    B, N, C = x.shape
    x_tmp = []
    for i in range(B):
        qkv = self.qkv(x[i:i + 1, :]).reshape(1, N, 3, self.num_heads,
                                              C // self.num_heads)
        q, k, v = qkv.unbind(
            2)  # make torchscript happy (cannot use tensor as tuple)
        if self.qk_normalization:
            q = self.q_norm(q.flatten(-2, -1)).view(1, N, self.num_heads,
                                                    qkv.shape[4])
            k = self.k_norm(k.flatten(-2, -1)).view(1, N, self.num_heads,
                                                    qkv.shape[4])
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        attn = ((q * self.scale) @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        # x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x0 = attn[:, :, :, :512] @ v[:, :, :512, :]
        x1 = attn[:, :, :, 512:] @ v[:, :, 512:, :]
        x_tmp.append((x0 + x1).transpose(1, 2).reshape(1, N, C))
    x = torch.cat(x_tmp, dim=0)
    x = self.proj(x)
    return x
@patch_to(InternMLP)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Patched MLP forward.

    For multi-sample batches, runs fc1 -> activation -> fc2 one sample at
    a time into a vendor-allocated column-major output buffer; single
    samples go through the plain path.
    """
    if hidden_states.shape[0] > 1:
        # Vendor buffer with an explicit layout/SBP so per-sample writes
        # land in the expected memory arrangement.
        output = torch_br._empty_ut_only(hidden_states.shape,
                                         "COLMAJOR",
                                         is_numa=False,
                                         sbp="BB",
                                         axis=0,
                                         dtype=torch.bfloat16)
        for i in range(hidden_states.shape[0]):
            hidden_states_tmp, _ = self.fc1(hidden_states[i:i + 1, :, :])
            hidden_states_tmp = self.activation_fn(hidden_states_tmp)
            hidden_states_tmp, _ = self.fc2(hidden_states_tmp)
            # NOTE(review): fc2's bias is added manually here but not in
            # the single-sample branch below — presumably fc2 skips bias
            # on this call path; confirm against the linear layer config.
            hidden_states_tmp += self.fc2.bias[None, None, :]
            output[i] = hidden_states_tmp[0]
        return output
    else:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        return hidden_states

View File

@@ -0,0 +1,140 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
from functools import partial
from typing import Optional, Union
import torch
from vllm.distributed import (get_pp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_gather)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.internlm2 import (InternLM2Attention,
InternLM2MLP, InternLM2Model)
from vllm.sequence import IntermediateTensors
def internlm2_attention_split_qkv(self, qkv: torch.Tensor):
    """Split a fused wqkv projection into separate q, k and v tensors.

    InternLM2 stores the projection grouped per kv head: each group is
    ``key_value_groups`` query heads followed by one key and one value
    head. Under tensor parallelism the full projection is first gathered
    and regrouped, and the per-rank slice is taken again at the end.
    """
    num_tokens = qkv.shape[1]
    if self.tp_size > 1:
        # Gather all TP shards, then reorder the interleaved
        # [q, k, v] slices into contiguous q-, k- and v-runs.
        shard_sizes = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
        gathered = tensor_model_parallel_all_gather(qkv)
        pieces = torch.split(gathered, shard_sizes, dim=-1)
        qkv = torch.cat(pieces[0::3] + pieces[1::3] + pieces[2::3], dim=-1)
    # One row per kv head: its query group, then one k and one v head.
    grouped = qkv.view(num_tokens, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
    q, k, v = torch.split(grouped, [self.key_value_groups, 1, 1], dim=-2)
    q = q.reshape(num_tokens, self.q_size * self.tp_size).unsqueeze(0)
    k = k.reshape(num_tokens, self.kv_size * self.tp_size).unsqueeze(0)
    v = v.reshape(num_tokens, self.kv_size * self.tp_size).unsqueeze(0)
    if self.tp_size > 1:
        # Keep only this rank's slice of each tensor.
        take_local = partial(split_tensor_along_last_dim,
                             num_partitions=self.tp_size)
        q = take_local(q)[self.tp_rank]
        k = take_local(k)[self.tp_rank]
        v = take_local(v)[self.tp_rank]
    return q, k, v
def internlm2_attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Attention forward: fused wqkv -> split -> rope -> attn -> wo."""
    fused_qkv, _ = self.wqkv(hidden_states)
    query, key, value = self.split_qkv(fused_qkv)
    query, key = self.rotary_emb(positions, query, key)
    context = self.attn(query, key, value)
    attn_out, _ = self.wo(context)
    return attn_out
def internlm2_model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Model forward patched to run layers on a [1, tokens, hidden] view.

    The first pipeline stage embeds the tokens; later stages resume from
    the received intermediate tensors. The final stage applies the output
    norm and strips the leading batch dim again.
    """
    pp_group = get_pp_group()
    if pp_group.is_first_rank:
        residual = None
        if inputs_embeds is None:
            hidden_states = self.get_input_embeddings(input_ids)
        else:
            hidden_states = inputs_embeds
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    # Decoder layers expect a leading batch dim of 1.
    hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not pp_group.is_last_rank:
        # Hand off to the next pipeline stage with the batch dim stripped.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    return hidden_states.squeeze(0)
def internlm2_mlp_init(
    self,
    hidden_size: int,
    intermediate_size: int,
    hidden_act: str,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Patched InternLM2MLP constructor.

    Builds the merged gate/up projection and the down projection (w2),
    with the merged projection tagged ``no_need_cross`` for the Biren
    backend. Only the silu activation is supported.
    """
    super(InternLM2MLP, self).__init__()
    self.gate_up_proj = MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.gate_up_proj",
    )
    # NOTE(review): flag presumably consumed by the vendor linear layers
    # to skip a cross-device step — confirm against the backend.
    self.gate_up_proj.no_need_cross = True
    self.w2 = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.w2",
    )
    if hidden_act != "silu":
        raise ValueError(f"Unsupported activation: {hidden_act}. "
                         "Only silu is supported for now.")
    self.act_fn = SiluAndMul()
# Install the Biren-patched implementations onto vllm's InternLM2 classes.
InternLM2Attention.split_qkv = internlm2_attention_split_qkv
InternLM2Attention.forward = internlm2_attention_forward
InternLM2Model.forward = internlm2_model_forward
InternLM2MLP.__init__ = internlm2_mlp_init

View File

@@ -0,0 +1,367 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch_br
from fastcore.basics import patch_to
from transformers import LlamaConfig
import vllm.model_executor.models.llama
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama import (LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM, LlamaModel)
from vllm.model_executor.models.utils import (extract_layer_index,
is_pp_missing_parameter)
from vllm.sequence import IntermediateTensors
from vllm_br import envs
from ..layers.quantization.compressed_tensors.utils import (
get_compressed_tensors_cache_scale)
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2
def LlamaDecoderLayer__init__(self,
                              vllm_config: VllmConfig,
                              prefix: str = "",
                              config: Optional[LlamaConfig] = None) -> None:
    """Patched LlamaDecoderLayer constructor.

    Chooses between the Biren ``AttentionSplit`` module and the stock
    ``LlamaAttention`` based on device SPC count and the per-rank kv-head
    width, and replaces the MLP with ``MergedGateUpMLPSiluL2``.
    """
    super(LlamaDecoderLayer, self).__init__()
    config = config or vllm_config.model_config.hf_config
    cache_config = vllm_config.cache_config
    quant_config = vllm_config.quant_config
    self.hidden_size = config.hidden_size
    rope_theta = getattr(config, "rope_theta", 10000)
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is not None and getattr(
            config, "original_max_position_embeddings", None):
        rope_scaling["original_max_position_embeddings"] = (
            config.original_max_position_embeddings)
    max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
    # Support abacusai/Smaug-72B-v0.1 with attention_bias
    # Support internlm/internlm-7b with bias
    attention_bias = getattr(config, "attention_bias", False) or getattr(
        config, "bias", False)
    tp_size = get_tensor_model_parallel_world_size()
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # determine whether use qkv merge weights
    min_w_gran = 32
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    # NOTE: current br166 don't support s(2)b split, so br166 can only use AttentionSplit
    if is_166 or (config.num_key_value_heads *
                  (self.hidden_size // config.num_attention_heads)
                  >= tp_size * spc_num * min_w_gran):
        # kv width per rank is large enough (or br166): use the vendor
        # split-attention module with merged qkv weights.
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
    else:
        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            cache_config=cache_config,
            prefix=f"{prefix}.self_attn",
        )
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        bias=getattr(config, "mlp_bias", False),
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Patched llama weight loader.

    Maps checkpoint names onto (possibly fused) parameters, skipping
    rotary caches and drops the qkv fusion mapping when the model was
    built without a merged qkv_proj (see LlamaDecoderLayer__init__).
    Norm weights are promoted to fp32 after loading.

    Returns the set of parameter names that were loaded.
    """
    loaded_params = []
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters())
    # determine whether is qkv merge weights
    qkv_merge = False
    for key in params_dict:
        if "qkv_proj" in key:
            qkv_merge = True
            break
    if not qkv_merge and len(stacked_params_mapping) >= 3:
        # Drop the three qkv entries; keep only the gate/up fusion.
        stacked_params_mapping = stacked_params_mapping[3:]
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue
        if ("rotary_emb.cos_cached" in name
                or "rotary_emb.sin_cached" in name):
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            continue
        if scale_name := get_compressed_tensors_cache_scale(name):
            # Loading kv cache scales for compressed-tensors quantization
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = loaded_weight[0]
            weight_loader(param, loaded_weight)
            loaded_params.append(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            # weight layout infer
            param.data = param.data + 0
            loaded_params.append(name)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            # weight layout infer
            param.data = param.data + 0
            if name.find("norm.weight") != -1:
                param.data = param.data.to(torch.float32)
            loaded_params.append(name)
    return set(loaded_params)
def llamamodel_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors, tuple[torch.Tensor,
                                                    list[torch.Tensor]]]:
    """Patched LlamaModel.forward.

    Runs the decoder layers on a [1, tokens, hidden] view of the batch,
    optionally capturing auxiliary hidden states for layers listed in
    ``self.aux_hidden_state_layers``.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
        hidden_states = hidden_states.unsqueeze(0)
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
        hidden_states = hidden_states.unsqueeze(0)
        residual = residual.unsqueeze(0)
    aux_hidden_states = []
    for idx, layer in enumerate(self.layers[self.start_layer:self.end_layer]):
        # NOTE(review): idx here is relative to start_layer, while the
        # gpt_oss variant in this package compares absolute layer ids; and
        # `hidden_states + residual` would fail if idx 0 is an aux layer on
        # the first rank (residual is None there) — confirm both.
        if idx in self.aux_hidden_state_layers:
            aux_hidden_states.append(hidden_states + residual)
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not get_pp_group().is_last_rank:
        # Hand off to the next pipeline stage with the batch dim stripped.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0)
            if hidden_states is not None else hidden_states,
            "residual":
            residual.squeeze(0) if residual is not None else residual
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    if len(aux_hidden_states) > 0:
        return hidden_states, aux_hidden_states
    return hidden_states.squeeze(0)
def LlamaAttention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Attention forward: fused qkv_proj -> split -> rope -> attn -> o_proj.

    On br166 devices (>16 SPCs) the split uses the vendor op that also
    infers the SBP layout of the shards.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    section_sizes = [self.q_size, self.kv_size, self.kv_size]
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        q, k, v = torch_br.split_w_sbp_infer(fused_qkv, section_sizes)
    else:
        q, k, v = fused_qkv.split(section_sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    context = self.attn(q, k, v)
    attn_out, _ = self.o_proj(context)
    return attn_out
@patch_to(LlamaAttention)
def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        cache_config: Optional[CacheConfig] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None) -> None:
    """Patched ``LlamaAttention.__init__``.

    Mirrors the upstream constructor with one Biren-specific change: the
    fused QKV projection is only quantized when the quantization config
    explicitly sets ``qkv_quantized``; the output projection always uses
    ``quant_config`` as-is.
    """
    super(LlamaAttention, self).__init__()
    self.hidden_size = hidden_size
    tp_size = get_tensor_model_parallel_world_size()
    self.total_num_heads = num_heads
    assert self.total_num_heads % tp_size == 0
    self.num_heads = self.total_num_heads // tp_size
    self.total_num_kv_heads = num_kv_heads
    if self.total_num_kv_heads >= tp_size:
        # Number of KV heads is greater than TP size, so we partition
        # the KV heads across multiple tensor parallel GPUs.
        assert self.total_num_kv_heads % tp_size == 0
    else:
        # Number of KV heads is less than TP size, so we replicate
        # the KV heads across multiple tensor parallel GPUs.
        assert tp_size % self.total_num_kv_heads == 0
    self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
    # MistralConfig has an optional head_dim introduced by Mistral-Nemo
    self.head_dim = getattr(config, "head_dim",
                            self.hidden_size // self.total_num_heads)
    self.q_size = self.num_heads * self.head_dim
    self.kv_size = self.num_kv_heads * self.head_dim
    self.scaling = self.head_dim**-0.5
    self.rope_theta = rope_theta
    self.max_position_embeddings = max_position_embeddings
    # Biren-specific: quantize qkv_proj only when explicitly enabled;
    # otherwise it runs unquantized even under a quantized model.
    qconfig = None
    if quant_config is not None and quant_config.qkv_quantized:
        qconfig = quant_config
    self.qkv_proj = QKVParallelLinear(
        hidden_size=hidden_size,
        head_size=self.head_dim,
        total_num_heads=self.total_num_heads,
        total_num_kv_heads=self.total_num_kv_heads,
        bias=bias,
        quant_config=qconfig,
        prefix=f"{prefix}.qkv_proj",
    )
    self.o_proj = RowParallelLinear(
        input_size=self.total_num_heads * self.head_dim,
        output_size=hidden_size,
        bias=bias,
        quant_config=quant_config,
        prefix=f"{prefix}.o_proj",
    )
    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=max_position_embeddings,
        base=rope_theta,
        rope_scaling=rope_scaling,
    )
    # dual_chunk kwargs are forwarded only when a dual-chunk config exists.
    self.attn = Attention(
        self.num_heads,
        self.head_dim,
        self.scaling,
        num_kv_heads=self.num_kv_heads,
        cache_config=cache_config,
        quant_config=quant_config,
        attn_type=attn_type,
        prefix=f"{prefix}.attn",
        **{
            "layer_idx": extract_layer_index(prefix),
            "dual_chunk_attention_config": dual_chunk_attention_config,
        } if dual_chunk_attention_config else {})
# Install the Biren-specific implementations over the upstream vLLM Llama
# classes (monkey-patching at import time of this module).
vllm.model_executor.models.llama.LlamaDecoderLayer.__init__ = LlamaDecoderLayer__init__
LlamaForCausalLM.load_weights = load_weights
LlamaModel.forward = llamamodel_forward
LlamaAttention.forward = LlamaAttention_forward

View File

@@ -0,0 +1,349 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import gc
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch_br
from transformers import Qwen2Config
import vllm.model_executor.models.qwen2
from vllm.attention import AttentionType
from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen2 import (Qwen2Attention,
Qwen2DecoderLayer, Qwen2Model)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
#import vllm.envs as envs
from vllm_br import envs
from .supa_module import AttentionSplit, MergedGateUpMLPSiluL2
def Qwen2DecoderLayer__init__(
    self,
    config: Qwen2Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Patched Qwen2 decoder-layer constructor for Biren devices.

    Chooses between the SUPA-optimized ``AttentionSplit`` and the upstream
    ``Qwen2Attention`` based on device/shard geometry, and always replaces
    the MLP with ``MergedGateUpMLPSiluL2``.
    """
    super(Qwen2DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    # Requires transformers > 4.32.0
    rope_theta = getattr(config, "rope_theta", 1000000)
    rope_scaling = getattr(config, "rope_scaling", None)
    dual_chunk_attention_config = getattr(config,
                                          "dual_chunk_attention_config", None)
    # By default, Qwen2 uses causal attention as it is a decoder-only model.
    # You can override the HF config with `is_causal=False` to enable
    # bidirectional attention, which is used in some embedding models
    # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct)
    if getattr(config, "is_causal", True):
        attn_type = AttentionType.DECODER
    else:
        attn_type = AttentionType.ENCODER_ONLY
    # NOTE(review): with both defaults True this is True unless the config
    # sets attention_bias AND bias to False (Qwen2 uses QKV bias).
    attention_bias = getattr(config, "attention_bias", True) or getattr(
        config, "bias", True)
    tp_size = get_tensor_model_parallel_world_size()
    spc_num = torch_br.supa.get_device_properties("supa").max_compute_units
    # determine whether use qkv merge weights
    min_w_gran = 32  # minimum per-SPC weight granularity (elements)
    is_166 = envs.VLLM_BR_DEVICE_SPC_NUM > 16
    # NOTE: current br166 don't support s(2)b split, so br166 can only use AttentionSplit
    if is_166 or (config.num_key_value_heads *
                  (self.hidden_size // config.num_attention_heads)
                  >= tp_size * spc_num * min_w_gran):
        # KV width is large enough to shard per SPC, or we are on br166
        # where AttentionSplit is the only supported path.
        self.self_attn = AttentionSplit(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            bias=attention_bias,
        )
        logger.debug('[Patch] Use AttentionSplit instead of Qwen2Attention')
    else:
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
            dual_chunk_attention_config=dual_chunk_attention_config,
        )
    # Unmerged gate/up MLP variant tuned for SUPA kernels.
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load Qwen2 weights with Biren-specific post-processing.

    On top of the upstream stacked-parameter loading this:
      * drops the qkv fusion mapping when the model was built without a
        fused ``qkv_proj`` (e.g. AttentionSplit layers),
      * on platform 0 (br106-class) touches each loaded tensor (``+ 0``)
        to materialize a fresh copy, and casts norm weights to float32,
      * on platform 1 (br166, > 16 SPCs) re-allocates norm / embedding /
        lm_head weights and the rotary caches into SUPA "ut" layouts.

    Returns the set of parameter names actually loaded.
    """
    spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # platform 0: br106-class device; platform 1: br166 (> 16 SPCs).
    self.platform = 0
    if spc_num > 16:
        self.platform = 1
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    logger.info('[Patch] Qwen2 MLP do not merge up/gate weight')
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        # Detect whether this model has a fused qkv_proj; if not, drop the
        # first three (qkv) mapping entries so q/k/v load as plain params.
        # NOTE(review): this scan repeats per weight although params_dict
        # never changes.
        qkv_merge = False
        for key in params_dict:
            if "qkv_proj" in key:
                qkv_merge = True
                break
        if not qkv_merge and len(stacked_params_mapping) >= 3:
            stacked_params_mapping = stacked_params_mapping[3:]
        if "rotary_emb.inv_freq" in name:
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
        # --- Biren post-processing of the just-loaded parameter ---
        if self.platform == 0:
            # '+ 0' forces a fresh materialized copy of the tensor.
            param.data = param.data + 0
        if name.find("norm.weight") != -1:
            if self.platform == 1:
                # Re-home the norm weight into a SUPA broadcast ("BB")
                # linear_bias layout, staged through the CPU.
                w_cpu = param.data.to(torch.float32).cpu()
                w_supa = torch_br._empty_ut_only(w_cpu.shape,
                                                 dtype=w_cpu.dtype,
                                                 is_numa=False,
                                                 device=param.data.device,
                                                 tensor_type="linear_bias",
                                                 axis=0,
                                                 sbp="BB")
                w_supa.copy_(w_cpu)
                param.data = w_supa
            else:
                # Norm weights are kept in float32 on platform 0.
                param.data = param.data.to(torch.float32)
        if name.find("embed_tokens.weight") != -1 and self.platform == 1:
            # Embedding table: column-major, broadcast ("BB") layout.
            w_shape = param.data.shape
            w_supa = torch_br._empty_ut_only(size=(w_shape[0], w_shape[1]),
                                             dtype=param.data.dtype,
                                             is_numa=False,
                                             device=param.data.device,
                                             tensor_type="colmajor",
                                             axis=0,
                                             sbp="BB")
            w_supa.copy_(param.data.cpu())
            param.data = w_supa
        if name.find("lm_head.weight") != -1 and self.platform == 1:
            # LM head: column-major, sharded ("SB") along axis 0.
            w_shape = param.data.shape
            w_supa = torch_br._empty_ut_only(size=(w_shape[0], w_shape[1]),
                                             dtype=param.data.dtype,
                                             is_numa=False,
                                             device=param.data.device,
                                             tensor_type="colmajor",
                                             axis=0,
                                             sbp="SB")
            w_supa.copy_(param.data.cpu())
            param.data = w_supa
        loaded_params.add(name)
    # inference rope sin_cos layout
    for _, module in self.named_modules():
        rotary_emb = getattr(module, "rotary_emb", None)
        if rotary_emb is not None:
            if self.platform == 1:
                if isinstance(rotary_emb, MRotaryEmbedding):
                    # MRoPE keeps a single fused cos/sin cache.
                    w_shape = rotary_emb.cos_sin_cache.shape
                    cos_sin_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.cos_sin_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.cos_sin_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    cos_sin_supa.copy_(rotary_emb.cos_sin_cache.cpu())
                    rotary_emb.cos_sin_cache = cos_sin_supa
                else:
                    # Standard RoPE keeps separate sin/cos caches.
                    w_shape = rotary_emb.sin_cache.shape
                    sin_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.sin_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.sin_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    sin_supa.copy_(rotary_emb.sin_cache.cpu())
                    rotary_emb.sin_cache = sin_supa
                    cos_supa = torch_br._empty_ut_only(
                        size=(w_shape[0], w_shape[1]),
                        dtype=rotary_emb.cos_cache.dtype,
                        is_numa=False,
                        device=rotary_emb.cos_cache.device,
                        tensor_type="colmajor",
                        axis=0,
                        sbp="BB")
                    cos_supa.copy_(rotary_emb.cos_cache.cpu())
                    rotary_emb.cos_cache = cos_supa
            else:
                # Platform 0: just force materialized copies of the caches.
                if isinstance(rotary_emb, MRotaryEmbedding):
                    rotary_emb.cos_sin_cache = rotary_emb.cos_sin_cache + 0
                else:
                    rotary_emb.sin_cache = rotary_emb.sin_cache + 0
                    rotary_emb.cos_cache = rotary_emb.cos_cache + 0
    # Flush device work and drop host-side staging buffers before serving.
    torch.supa.synchronize()
    gc.collect()
    torch.supa.empty_cache()
    return loaded_params
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched ``Qwen2Model.forward`` running the decoder layers on a 3-D
    (1, seq, hidden) tensor as the SUPA kernels expect, converting back to
    the upstream 2-D layout at pipeline boundaries and on return."""
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        # Resume from tensors shipped by the previous PP stage.
        # NOTE(review): unlike the llama patch, residual is not unsqueezed
        # here — presumably the layer handles 2-D residuals; confirm.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    # NOTE: supa wants 3d shape for llm
    if len(hidden_states.shape) == 2:
        hidden_states = hidden_states.unsqueeze(0)
    for layer in self.layers[self.start_layer:self.end_layer]:
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )
    if not get_pp_group().is_last_rank:
        # Ship 2-D tensors to the next pipeline-parallel stage.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # NOTE: convert back to 2D.  squeeze() drops every singleton dim, so a
    # single-token batch collapses to 1-D and is restored to (1, hidden).
    hidden_states = hidden_states.squeeze()
    if hidden_states.dim() == 1:
        hidden_states = hidden_states.unsqueeze(0)
    return hidden_states
def Qwen2Attention_forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Fused-QKV attention forward used to patch ``Qwen2Attention.forward``."""
    qkv, _ = self.qkv_proj(hidden_states)
    sizes = [self.q_size, self.kv_size, self.kv_size]
    # br166 devices (> 16 SPCs) require the SBP-aware split kernel.
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        q, k, v = torch_br.split_w_sbp_infer(qkv, sizes)
    else:
        q, k, v = qkv.split(sizes, dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_out = self.attn(q, k, v)
    result, _ = self.o_proj(attn_out)
    return result
# Install the Biren implementations over the upstream vLLM Qwen2 classes.
vllm.model_executor.models.qwen2.Qwen2DecoderLayer.__init__ = Qwen2DecoderLayer__init__
logger.debug('[Patch] patch Qwen2 MLP with LlaMA_MLP_SiLU_3L')
Qwen2Model.load_weights = load_weights
Qwen2Model.forward = model_forward
Qwen2Attention.forward = Qwen2Attention_forward

View File

@@ -0,0 +1,530 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2.5-VL model compatible with HuggingFace weights."""
import math
from collections.abc import Iterable
from functools import partial
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_br
from einops import rearrange
from fastcore.basics import patch_to
import vllm
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2_5_vl import (Qwen2_5_VisionBlock,
Qwen2_5_VisionMLP,
Qwen2_5_VisionPatchMerger,
Qwen2_5_VisionTransformer)
from vllm.model_executor.models.qwen2_vl import apply_rotary_pos_emb_vision
from vllm.model_executor.models.utils import cast_overflow_tensors
from vllm.platforms import _Backend
from vllm_br import envs
from .br_utils import convBB, convSB
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
    """All-gather *local_tensor* across the TP group and re-interleave.

    Each rank holds an interleaved shard of the last dimension; after the
    gather, the per-rank chunks are reordered so the concatenated result
    matches the original (unsharded) column layout.
    """
    import torch.distributed as dist
    buffers = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
    dist.all_gather(buffers,
                    local_tensor,
                    group=parallel_state.get_tp_group().device_group)
    chunk = hidden_size // tp_size
    per_rank_chunks = [torch.split(t, chunk, -1) for t in buffers]
    interleaved = []
    for group in zip(*per_rank_chunks, strict=False):
        interleaved.extend(group)
    return torch.cat(interleaved, dim=-1)
class Qwen2_5_VisionAttention_fit(nn.Module):
    """Biren (SUPA) replacement for vLLM's ``Qwen2_5_VisionAttention``.

    Keeps the upstream constructor signature but routes the attention
    computation through ``torch_br.sueager_scaled_dot_product_attention_fwd``
    and the BB/SB layout conversions required on br166 devices.
    ``attn_backend`` and ``use_upstream_fa`` are accepted for interface
    compatibility only and are not used here.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        projection_size: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False,
        attn_backend: _Backend = _Backend.TORCH_SDPA,
        use_upstream_fa: bool = False,
    ) -> None:
        super().__init__()
        # Per attention head and per partition values.
        self.tp_size = (1 if use_data_parallel else
                        parallel_state.get_tensor_model_parallel_world_size())
        self.tp_rank = parallel_state.get_tensor_model_parallel_rank()
        self.hidden_size_per_attention_head = dist_utils.divide(
            projection_size, num_heads)
        self.num_attention_heads_per_partition = dist_utils.divide(
            num_heads, self.tp_size)
        # sqrt(head_dim); 1/norm_factor is handed to the fused SDPA as scale.
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        self.qkv = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.hidden_size_per_attention_head,
            total_num_heads=num_heads,
            total_num_kv_heads=num_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv")
        self.proj = RowParallelLinear(input_size=projection_size,
                                      output_size=embed_dim,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.proj")

    def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Split a fused [s, b, 3*head*head_dim] tensor into q, k, v of
        shape [s, b, heads_per_partition, head_dim] each, re-sharding
        across TP ranks when needed."""
        # [s, b, 3 * head * head_dim]
        seq_len, bs, width = qkv.shape
        qkv = qkv.reshape(-1, width)
        if self.tp_size > 1:
            # Reassemble the full hidden width before the q/k/v chunking.
            qkv = all_gather_interleave(qkv, self.qkv.hidden_size,
                                        self.tp_size)
        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
        q, k, v = qkv.chunk(3, dim=-1)
        # 3 * [s, b, head * head_dim]
        if self.tp_size > 1:
            # Take this rank's shard of each of q, k, v.
            splitter = partial(dist_utils.split_tensor_along_last_dim,
                               num_partitions=self.tp_size)
            q = splitter(q)[self.tp_rank]
            k = splitter(k)[self.tp_rank]
            v = splitter(v)[self.tp_rank]
        # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
        new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
                     self.hidden_size_per_attention_head)
        q, k, v = (x.view(*new_shape) for x in (q, k, v))
        return q, k, v

    def transform_qkv_shape(self,
                            qkv_layer,
                            cur_qkv_shape_state,
                            obj_qkv_shape_state,
                            obj_shape=None):
        """Convert ``qkv_layer`` between the supported q/k/v layouts.

        Layout labels: ``"bn_s_h"`` = (batch*heads, seq, head_dim),
        ``"b_n_s_h"`` = (batch, heads, seq, head_dim),
        ``"b_s_n_h"`` = (batch, seq, heads, head_dim).  ``obj_shape[0]``
        supplies the batch size when un-flattening from ``"bn_s_h"``.

        Raises:
            AssertionError: if the (cur, obj) layout pair is unsupported.
        """
        if obj_qkv_shape_state == "bn_s_h":
            if cur_qkv_shape_state == "bn_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                # [b, sq, np or nkvp, hn] --> [b, np or nkvp, sq, hn]
                # --> [b*(np or nkvp), sq, hn]
                qkv_layer = qkv_layer.permute(0, 2, 1, 3)
                # view 4d matrix to 3d matrix, TODO: use fused_split_view here
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3)).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.reshape(-1, qkv_layer.size(2),
                                              qkv_layer.size(3))
                return qkv_layer
        if obj_qkv_shape_state == "b_n_s_h":
            if cur_qkv_shape_state == "b_n_s_h":
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                return qkv_layer
            if cur_qkv_shape_state == "b_s_n_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
        if obj_qkv_shape_state == "b_s_n_h":
            if cur_qkv_shape_state == "b_s_n_h":
                return qkv_layer
            if cur_qkv_shape_state == "b_n_s_h":
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
            if cur_qkv_shape_state == "bn_s_h":
                qkv_layer = qkv_layer.reshape(obj_shape[0], -1,
                                              qkv_layer.size(1),
                                              qkv_layer.size(2))
                qkv_layer = qkv_layer.permute(0, 2, 1, 3).contiguous()
                return qkv_layer
        # Fix: the original built this AssertionError without raising it,
        # silently returning None for unsupported combinations.
        raise AssertionError(
            f"unsupported shape transform, ori:{cur_qkv_shape_state} obj:{obj_qkv_shape_state}"
        )

    def forward(
        self,
        x: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
        seqlens: Optional[list[int]] = None,  # Only used for xFormers
        mask: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run vision self-attention through the SUPA fused SDPA kernel.

        ``mask`` is the precomputed block-diagonal mask; ``max_seqlen`` /
        ``seqlens`` are accepted for interface parity with upstream.
        """
        # [s, b, c] --> [s, b, head * 3 * head_dim]
        x, _ = self.qkv(x)
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # br166: leave BB layout, then regroup the fused q/k/v columns.
            x = convBB(x)
            seql = x.shape[-2]
            x = x.reshape(seql, 2, 3,
                          -1).permute(0, 2, 1,
                                      3).contiguous().reshape(1, seql, -1)
        if x.shape[0] == 1:
            # (1, s, c) -> (s, 1, c): split_qkv expects seq-first layout.
            x = x.permute(1, 0, 2).contiguous()
        # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
        q, k, v = self.split_qkv(x)
        q, k, v = (rearrange(t, "s b ... -> b s ...").contiguous()
                   for t in (q, k, v))
        if rotary_pos_emb is not None:
            q = apply_rotary_pos_emb_vision(
                q,
                rotary_pos_emb,
            )
            k = apply_rotary_pos_emb_vision(
                k,
                rotary_pos_emb,
            )
        # q, k, v: [b, s, n, h] -> reshape: [b, n, s, h] -> reshape: [b * n, s, h]
        q = q.permute(0, 2, 1, 3).contiguous()
        k = k.permute(0, 2, 1, 3).contiguous()
        v = v.permute(0, 2, 1, 3).contiguous()
        q = self.transform_qkv_shape(q, "b_n_s_h", "bn_s_h")
        k = self.transform_qkv_shape(k, "b_n_s_h", "bn_s_h")
        v = self.transform_qkv_shape(v, "b_n_s_h", "bn_s_h")
        # TODO(qingqi): workaround for a sueager kernel issue at these exact
        # sequence lengths; drop the bfloat16 cast once the op is fixed.
        if q.shape[1] == 8192 or q.shape[1] == 8424 or q.shape[1] == 8464:
            mask = mask.to(torch.bfloat16)
        context_layer, _ = torch_br.sueager_scaled_dot_product_attention_fwd(
            query=q,
            key=k,
            value=v,
            mask=mask,
            dropout_prob=0.0,
            is_causal=False,
            scale=1 / self.norm_factor,
            algorithm="FMHA",
        )
        # reshape attn out: [b*n, s, h] -> [s, b, h*n]
        context_layer = torch_br.supa_shape_transform_qkv(
            context_layer, 1, context_layer.shape[-2],
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head, False, False, None)
        if context_layer.shape[0] != 1:
            context_layer = context_layer.permute(1, 0, 2).contiguous()
        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # Back to the sharded ("SB") layout for the row-parallel proj.
            context_layer = convSB(context_layer, -1)
        output, _ = self.proj(context_layer)
        return output
def vision_block_forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
    mask: torch.Tensor = None,
) -> torch.Tensor:
    """Pre-norm transformer block: attention then MLP, each residual-added."""
    # Kernels expect a singleton leading batch dim; transpose the
    # (s, b, c) layout unless the leading dim is already 1.
    if x.shape[0] != 1:
        x = x.permute(1, 0, 2).contiguous()
    attn_out = self.attn(self.norm1(x),
                         cu_seqlens=cu_seqlens,
                         rotary_pos_emb=rotary_pos_emb,
                         max_seqlen=max_seqlen,
                         seqlens=seqlens,
                         mask=mask)
    x = x + attn_out
    return x + self.mlp(self.norm2(x))
class Qwen2_5_VisionPatchEmbed_fit(nn.Module):
    """Patch embedding implemented as a plain linear projection.

    Replaces the upstream conv-based patch embed: the caller supplies
    already-flattened patch vectors of size
    ``in_channels * temporal_patch_size * patch_size**2`` and this module
    projects them to ``hidden_size`` with a bias-free ColumnParallelLinear
    whose output is gathered across TP ranks.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        hidden_size: int = 1152,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.hidden_size = hidden_size
        # Linear over flattened (C * T * P * P) patch vectors.
        self.proj = ColumnParallelLinear(in_channels * temporal_patch_size *
                                         patch_size * patch_size,
                                         hidden_size,
                                         bias=False,
                                         gather_output=True,
                                         quant_config=quant_config,
                                         prefix="")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Add a batch dim for the parallel linear, then flatten back to
        # (num_patches, hidden_size); proj returns an (output, bias) pair.
        x = x.unsqueeze(0)
        L, _ = x.shape[-2], x.shape[-1]
        x = self.proj(x)[0].view(L, self.hidden_size)
        return x
@patch_to(vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer)
def gen_normal_mask(self, cu_seqlens, grid_thw, device):
    """Build a block-diagonal attention mask from cumulative seq lengths.

    Entries are 0 inside each ``[cu_seqlens[i-1], cu_seqlens[i])`` diagonal
    block and 1 everywhere else (cross-sequence pairs); semantics of the
    0/1 values depend on the consuming SUPA SDPA kernel.  ``grid_thw`` is
    accepted for interface compatibility but unused.

    Returns:
        int32 tensor of shape ``(1, seq_len, seq_len)`` on ``device``.
    """
    # NOTE: for mask-mock-pack, we precompute mask and store in PackedSeqParams
    # Convert the boundaries to Python ints once up front: the original
    # `max(cu_seqlens)` relied on Tensor.__index__ inside torch.full and
    # iterated the (device) tensor element-by-element.
    bounds = [int(v) for v in cu_seqlens]
    seq_len = max(bounds)
    attention_mask = torch.full([1, seq_len, seq_len],
                                1,
                                dtype=torch.int32,
                                device=device)
    # Zero out each same-sequence diagonal block.
    for start, end in zip(bounds, bounds[1:]):
        attention_mask[..., start:end, start:end] = 0
    return attention_mask
def vision_transformer_forward(
    self,
    x: torch.Tensor,
    grid_thw: list[list[int]],
) -> torch.Tensor:
    """Patched Qwen2.5-VL vision-tower forward for Biren devices.

    Differences from upstream: a dense block-diagonal attention mask is
    precomputed via ``gen_normal_mask`` and handed to every block, and the
    hidden states flow through the blocks in (seq, 1, hidden) layout.
    """
    # patchify
    seq_len, _ = x.size()
    rotary_pos_emb_list = []
    window_index_list: list = []
    cu_window_seqlens_list: list = [
        torch.tensor([0], dtype=torch.int32, device="cpu")
    ]
    cu_seqlens_list: list = []
    hidden_states = x.to(device=self.device, dtype=self.dtype)
    hidden_states = self.patch_embed(hidden_states)
    # Accumulate per-image rotary embeddings, window permutation indices
    # and (window) sequence boundaries across all images/videos.
    window_index_id = 0
    cu_window_seqlens_last = 0
    for t, h, w in grid_thw:
        t, h, w = int(t), int(h), int(w)
        llm_h = h // self.spatial_merge_size
        llm_w = w // self.spatial_merge_size
        (
            rotary_pos_emb_thw,
            window_index_thw,
            cu_seqlens_window_thw,
            cu_seqlens_thw,
        ) = self.get_rope_by_thw(t, h, w)
        window_index_list.append(window_index_thw + window_index_id)
        window_index_id += (t * llm_h * llm_w)
        # Shift this image's window boundaries past the previous images'.
        cu_seqlens_window_thw = (cu_seqlens_window_thw +
                                 cu_window_seqlens_last)
        cu_window_seqlens_last = cu_seqlens_window_thw[-1]
        cu_window_seqlens_list.append(cu_seqlens_window_thw)
        rotary_pos_emb_list.append(rotary_pos_emb_thw)
        cu_seqlens_list.append(cu_seqlens_thw)
    rotary_pos_emb = torch.cat(rotary_pos_emb_list)
    window_index = torch.cat(window_index_list)
    cu_window_seqlens = torch.cat(cu_window_seqlens_list)
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    cu_seqlens = torch.cat(cu_seqlens_list)
    cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
    # transformers
    # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
    max_seqlen_full, seqlens_full = self.compute_attn_mask_seqlen(cu_seqlens)
    max_seqlen_window, seqlens_window = self.compute_attn_mask_seqlen(
        cu_window_seqlens)
    cu_seqlens = cu_seqlens.to(device=self.device, non_blocking=True)
    cu_window_seqlens = cu_window_seqlens.to(device=self.device,
                                             non_blocking=True)
    rotary_pos_emb = rotary_pos_emb.to(device=self.device, non_blocking=True)
    window_index = window_index.to(device=hidden_states.device,
                                   non_blocking=True)
    # Permute tokens into window order, one spatial-merge unit at a time.
    hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit,
                                          self.spatial_merge_unit, -1)
    hidden_states = hidden_states[window_index, :, :]
    hidden_states = hidden_states.reshape(seq_len, -1)
    hidden_states = hidden_states.unsqueeze(1)  # -> (seq, 1, hidden)
    # NOTE(review): the same full-sequence mask is passed to every block,
    # including window-attention layers — confirm the SUPA kernel derives
    # windowing from cu_seqlens in that case.
    attention_mask = self.gen_normal_mask(cu_seqlens, grid_thw, x.device)
    for layer_num, blk in enumerate(self.blocks):
        # Alternate between full attention and window attention boundaries.
        if layer_num in self.fullatt_block_indexes:
            cu_seqlens_now = cu_seqlens
            max_seqlen_now = max_seqlen_full
            seqlens_now = seqlens_full
        else:
            cu_seqlens_now = cu_window_seqlens
            max_seqlen_now = max_seqlen_window
            seqlens_now = seqlens_window
        hidden_states = blk(hidden_states,
                            cu_seqlens=cu_seqlens_now,
                            rotary_pos_emb=rotary_pos_emb,
                            max_seqlen=max_seqlen_now,
                            seqlens=seqlens_now,
                            mask=attention_mask)
    # For Qwen2.5-VL-3B, float16 will overflow at last block
    # for long visual tokens sequences.
    if hidden_states.dtype == torch.float16:
        hidden_states = cast_overflow_tensors(hidden_states)
    # adapter
    hidden_states = self.merger(hidden_states).squeeze(0)
    # Undo the window permutation so tokens return to raster order.
    reverse_indices = torch.argsort(window_index)
    hidden_states = hidden_states[reverse_indices, :]
    return hidden_states
def vision_transformer_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Load vision-tower weights, fusing separate q/k/v into ``attn.qkv``.

    Also flattens ``patch_embed.proj.weight`` from its conv layout to 2-D,
    since the patched patch-embed uses a linear projection.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            # Map the checkpoint's split q/k/v name onto the fused param
            # and let the shard-aware loader place the slice.
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if name == 'patch_embed.proj.weight':
                # Flatten the conv kernel to (out, -1) for the linear
                # patch embedding (presumably (out, C*T*P*P) — confirm).
                loaded_weight = loaded_weight.reshape(loaded_weight.shape[0],
                                                      -1).contiguous()
            weight_loader(param, loaded_weight)
        loaded_params.add(name)
    return loaded_params
def Qwen2_5_VisionPatchMerger_forward_fit(self,
                                          x: torch.Tensor) -> torch.Tensor:
    """LayerNorm, flatten to (1, tokens, hidden_size), then MLP-project."""
    normed = self.ln_q(x)
    flat = normed.view(-1, self.hidden_size).unsqueeze(0)
    return self.mlp(flat)
def Qwen2_5_VisionMLP__init__(
        self,
        in_features: int,
        hidden_features: int,
        bias: bool = False,
        act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        use_data_parallel: bool = False):
    """Patched vision-MLP constructor using separate gate/up projections.

    NOTE(review): the ``act_fn`` argument is ignored — the activation is
    hard-coded to ``F.silu`` below.  Also only ``down_proj`` honors
    ``use_data_parallel`` (via ``disable_tp``); gate/up do not — confirm
    this is intended.
    """
    super(Qwen2_5_VisionMLP, self).__init__()
    self.gate_proj = ColumnParallelLinear(in_features,
                                          hidden_features,
                                          bias=bias,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.gate_proj")
    self.up_proj = ColumnParallelLinear(in_features,
                                        hidden_features,
                                        bias=bias,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.up_proj")
    self.down_proj = RowParallelLinear(hidden_features,
                                       in_features,
                                       bias=bias,
                                       quant_config=quant_config,
                                       prefix=f"{prefix}.down_proj",
                                       disable_tp=use_data_parallel)
    # Hard-coded SiLU (SwiGLU-style MLP); see NOTE above re: act_fn.
    self.act_fn = F.silu
def Qwen2_5_VisionMLP_forward(self, x: torch.Tensor):
    """SwiGLU MLP: ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``."""
    gated, _ = self.gate_proj(x)
    up, _ = self.up_proj(x)
    out, _ = self.down_proj(self.act_fn(gated) * up)
    return out
# Install the Biren vision implementations over the upstream vLLM
# Qwen2.5-VL classes (monkey-patching at import time of this module).
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionAttention = Qwen2_5_VisionAttention_fit
vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionPatchEmbed = Qwen2_5_VisionPatchEmbed_fit
Qwen2_5_VisionBlock.forward = vision_block_forward
Qwen2_5_VisionTransformer.forward = vision_transformer_forward
Qwen2_5_VisionTransformer.load_weights = vision_transformer_load_weights
Qwen2_5_VisionPatchMerger.forward = Qwen2_5_VisionPatchMerger_forward_fit
Qwen2_5_VisionMLP.__init__ = Qwen2_5_VisionMLP__init__
Qwen2_5_VisionMLP.forward = Qwen2_5_VisionMLP_forward

View File

@@ -0,0 +1,47 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from collections.abc import Mapping
from vllm.model_executor.models.qwen2_vl import (Qwen2VLDummyInputsBuilder,
Qwen2VLProcessingInfo)
from vllm.multimodal.parse import ImageSize
def get_image_size_with_most_features(self) -> ImageSize:
    """Return the profiling image size (shared by Qwen2-VL and Qwen2.5-VL;
    patched here in qwen2_vl.py).

    Caps the dummy image at 240x240 so memory profiling does not allocate
    the full-resolution worst case.
    """
    size, _ = self._get_vision_info(
        image_width=240,
        image_height=240,
        image_processor=None,
    )
    return size
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
    """Build dummy prompt text for profiling: one image token, no video.

    Intentionally ignores ``mm_counts`` so profiling always uses a single
    image placeholder regardless of configured modality limits.
    """
    processor = self.info.get_hf_processor()
    image_token: str = processor.image_token
    video_token: str = processor.video_token
    # num_images = 1, num_videos = 0.
    return image_token * 1 + video_token * 0
# Apply the patches to the upstream Qwen2-VL processing classes.
Qwen2VLProcessingInfo.get_image_size_with_most_features = (
    get_image_size_with_most_features)
Qwen2VLDummyInputsBuilder.get_dummy_text = get_dummy_text

View File

@@ -0,0 +1,254 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from collections.abc import Iterable
from typing import Optional
import torch
import torch_br
from fastcore.basics import patch_to
from transformers import Qwen3Config
import vllm.model_executor.models.qwen3
from vllm.attention import AttentionType
from vllm.config import CacheConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen3 import (Qwen3Attention,
Qwen3DecoderLayer, Qwen3Model)
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm_br.v1.attention.backends.attention_v1 import (
SUPAFlashAttentionMetadata)
from .qwen2 import model_forward
from .supa_module import MergedGateUpMLPSiluL2
@patch_to(vllm.model_executor.models.qwen3.Qwen3Attention)
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Biren replacement for ``Qwen3Attention.forward``.

    Routes QKV projection, Q/K RMSNorm, (M)RoPE and the KV-cache store
    through fused ``torch_br`` kernels.  Returns ``hidden_states``
    unchanged during dummy/profiling runs (no attention metadata) or
    when no KV cache is bound to this layer.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        # Dummy run (profiling): skip attention entirely.
        return hidden_states
    seq_len = hidden_states.shape[-2]
    # Sequences at or below this length take the fully fused kernel that
    # consumes the raw projection weights directly (decode-style path).
    decode_seql = 512
    if isinstance(attn_metadata, dict):
        # Per-layer metadata mapping: select this layer's entry.
        attn_metadata = attn_metadata[self.attn.layer_name]
    kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
    if kv_cache is not None:
        if seq_len <= decode_seql:
            # Hand the raw (possibly quantized) QKV weights to the fused
            # kernel; it performs the projection itself.
            if hasattr(self.qkv_proj, "qweight"):
                # GPTQ-style quantized linear layout.
                qkv_weight = self.qkv_proj.qweight.data
                qkv_scales = self.qkv_proj.scales.data
            elif hasattr(self.qkv_proj, "weight_packed"):
                # compressed-tensors style quantized linear layout.
                qkv_weight = self.qkv_proj.weight_packed.data
                qkv_scales = self.qkv_proj.weight_scale.data
            else:
                # Unquantized weights.
                qkv_weight = self.qkv_proj.weight
                qkv_scales = None
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                # Multimodal RoPE: kernel takes a single section size, so
                # the height/width sections must match.
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_qwen3_vl_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    self.rotary_emb.mrope_section[1],
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
            else:
                q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.sin_cache,
                    self.rotary_emb.cos_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
        else:
            # Longer sequences: run the regular QKV projection, then a
            # fused RMSNorm + RoPE + KV-cache-store kernel.
            qkv, _ = self.qkv_proj(hidden_states)
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_fused_rms_mrope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache, kv_cache, positions,
                    attn_metadata.slot_mapping, attn_metadata.block_table,
                    attn_metadata.query_start_loc, attn_metadata.context_lens,
                    self.rotary_emb.mrope_section[1])
            else:
                q, k, v = torch_br.br_fused_rms_rope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
                    kv_cache, positions, attn_metadata.slot_mapping,
                    attn_metadata.block_table, attn_metadata.query_start_loc,
                    attn_metadata.context_lens)
        if hasattr(attn_metadata, 'do_cache'):
            # Presumably the fused kernels above already wrote K/V into the
            # cache, so the backend must not store them again — TODO confirm
            # against the torch_br kernel contract.
            attn_metadata.do_cache = False
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
    else:
        # No KV cache bound to this layer (e.g. memory profiling).
        return hidden_states
def Qwen3DecoderLayer__init__(
    self,
    config: Qwen3Config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    """Patched ``Qwen3DecoderLayer.__init__``.

    Mirrors upstream vLLM except that the MLP is replaced with the Biren
    fused ``MergedGateUpMLPSiluL2`` implementation.
    """
    super(Qwen3DecoderLayer, self).__init__()
    self.hidden_size = config.hidden_size
    # Requires transformers > 4.32.0
    rope_theta = getattr(config, "rope_theta", 1000000)
    rope_scaling = getattr(config, "rope_scaling", None)
    # By default, Qwen3 uses causal attention as it is a decoder-only model.
    # You can override the HF config with `is_causal=False` to enable
    # bidirectional attention, which is used in some embedding models
    # (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct)
    if getattr(config, "is_causal", True):
        attn_type = AttentionType.DECODER
    else:
        attn_type = AttentionType.ENCODER_ONLY
    self.self_attn = Qwen3Attention(
        hidden_size=self.hidden_size,
        num_heads=config.num_attention_heads,
        max_position=config.max_position_embeddings,
        num_kv_heads=config.num_key_value_heads,
        rope_theta=rope_theta,
        rms_norm_eps=config.rms_norm_eps,
        qkv_bias=getattr(config, 'attention_bias', False),
        head_dim=getattr(config, 'head_dim', None),
        cache_config=cache_config,
        quant_config=quant_config,
        rope_scaling=rope_scaling,
        prefix=f"{prefix}.self_attn",
        attn_type=attn_type,
    )
    # Biren-fused gate/up MLP (replaces the upstream Qwen3MLP).
    self.mlp = MergedGateUpMLPSiluL2(
        hidden_size=self.hidden_size,
        intermediate_size=config.intermediate_size,
        hidden_act=config.hidden_act,
        quant_config=quant_config,
        prefix=f"{prefix}.mlp",
    )
    self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                            eps=config.rms_norm_eps)
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Patched ``Qwen3Model.load_weights``.

    Same checkpoint-name remapping as upstream vLLM (stacked q/k/v and
    gate/up projections, kv-cache scales), plus one Biren-specific step:
    each loaded ``*norm.weight`` is cast to float32 — presumably required
    by the fused torch_br kernels; TODO confirm.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            # RoPE inverse frequencies are recomputed, not loaded.
            continue
        if (self.quant_config is not None
                and (scale_name := self.quant_config.get_cache_scale(name))):
            # Loading kv cache quantization scales
            param = params_dict[scale_name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            loaded_weight = (loaded_weight
                             if loaded_weight.dim() == 0 else loaded_weight[0])
            weight_loader(param, loaded_weight)
            loaded_params.add(scale_name)
            continue
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Non-stacked parameter: load directly.
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            # Remapping the name of FP8 kv-scale.
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            if name.find("norm.weight") != -1:
                # Biren: norm weights kept in fp32 after loading.
                param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params
# Install the patched constructor / loader / forward onto the upstream vLLM
# Qwen3 classes (model_forward is shared with the Qwen2 patch module).
vllm.model_executor.models.qwen3.Qwen3DecoderLayer.__init__ = Qwen3DecoderLayer__init__
logger.debug('[Patch] patch Qwen3 MLP with MergedGateUpMLPSiluL2')
Qwen3Model.load_weights = load_weights
Qwen3Model.forward = model_forward

View File

@@ -0,0 +1,300 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch_br
from fastcore.basics import patch_to
import vllm
from vllm.distributed import get_pp_group, tensor_model_parallel_all_reduce
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_moe import Qwen3MoeModel
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
from vllm_br.v1.attention.backends.attention_v1 import (
SUPAFlashAttentionMetadata)
logger = init_logger(__name__)
@patch_to(vllm.model_executor.models.qwen3_moe.Qwen3MoeSparseMoeBlock)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """Patched sparse-MoE block forward.

    Accepts either a 2D ``[tokens, dim]`` or batched 3D ``[1, tokens, dim]``
    input (the patched model forward keeps a leading batch dim of 1) and
    flattens to 2D before dispatching to the experts.  The router logits
    are passed as the raw gate weight tuple instead of a precomputed
    projection — the fused experts kernel computes the gating itself.
    """
    orig_shape = hidden_states.shape
    if len(hidden_states.shape) == 3:
        # Drop the leading batch dim of 1 added by the patched model forward.
        hidden_states = hidden_states.squeeze(0)
    final_hidden_states = self.experts(hidden_states=hidden_states,
                                       router_logits=(self.gate.weight, None,
                                                      None))
    if hasattr(final_hidden_states, 'all_reduced'):
        # NOTE: this flag indicates that the final_hidden_states has been
        # reduced in fused_moe already; consume the flag and skip the
        # explicit all-reduce below.
        delattr(final_hidden_states, 'all_reduced')
    elif self.tp_size > 1:
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)
    return final_hidden_states.view(orig_shape)
@patch_to(vllm.model_executor.models.qwen3_moe.Qwen3MoeAttention)
def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    """Biren replacement for ``Qwen3MoeAttention.forward``.

    Same fused-kernel dispatch as the dense Qwen3 attention patch in
    qwen3.py: short sequences use a fully fused kernel over the raw QKV
    weights; longer sequences run the QKV projection followed by a fused
    RMSNorm + RoPE + KV-cache-store kernel.  Returns ``hidden_states``
    unchanged during dummy runs or when no KV cache is bound.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata: SUPAFlashAttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        # Dummy run (profiling): skip attention entirely.
        return hidden_states
    seq_len = hidden_states.shape[-2]
    # Threshold selecting the fully fused short-sequence kernel below.
    decode_seql = 512
    if isinstance(attn_metadata, dict):
        # Per-layer metadata mapping: select this layer's entry.
        attn_metadata = attn_metadata[self.attn.layer_name]
    kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
    if kv_cache is not None:
        if seq_len <= decode_seql:
            # Raw (possibly quantized) QKV weights for the fused kernel.
            if hasattr(self.qkv_proj, "qweight"):
                # GPTQ-style quantized linear layout.
                qkv_weight = self.qkv_proj.qweight.data
                qkv_scales = self.qkv_proj.scales.data
            elif hasattr(self.qkv_proj, "weight_packed"):
                # compressed-tensors style quantized linear layout.
                qkv_weight = self.qkv_proj.weight_packed.data
                qkv_scales = self.qkv_proj.weight_scale.data
            else:
                # Unquantized weights.
                qkv_weight = self.qkv_proj.weight
                qkv_scales = None
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                # Multimodal RoPE: kernel takes one section size, so the
                # height/width sections must match.
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_qwen3_vl_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    self.rotary_emb.mrope_section[1],
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
            else:
                q, k, v = torch_br.br_qwen3_prefix_attn_infer(
                    hidden_states,
                    qkv_weight, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim,
                    self.q_norm.variance_epsilon,
                    self.q_norm.weight,
                    self.k_norm.weight,
                    self.rotary_emb.sin_cache,
                    self.rotary_emb.cos_cache,
                    kv_cache,
                    positions,
                    attn_metadata.slot_mapping,
                    bias=self.qkv_proj.bias,
                    scales=qkv_scales)
        else:
            # Longer sequences: regular projection, then fused
            # RMSNorm + RoPE + KV-store.
            qkv, _ = self.qkv_proj(hidden_states)
            if isinstance(self.rotary_emb, MRotaryEmbedding):
                assert len(
                    self.rotary_emb.mrope_section
                ) == 3 and self.rotary_emb.mrope_section[
                    1] == self.rotary_emb.mrope_section[
                        2], "current only support mrope_section width and height are equal!"
                q, k, v = torch_br.br_fused_rms_mrope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.cos_sin_cache, kv_cache, positions,
                    attn_metadata.slot_mapping, attn_metadata.block_table,
                    attn_metadata.query_start_loc, attn_metadata.context_lens,
                    self.rotary_emb.mrope_section[1])
            else:
                q, k, v = torch_br.br_fused_rms_rope_kvstore_infer(
                    qkv, [self.q_size, self.kv_size, self.kv_size],
                    self.head_dim, self.q_norm.variance_epsilon,
                    self.q_norm.weight, self.k_norm.weight,
                    self.rotary_emb.sin_cache, self.rotary_emb.cos_cache,
                    kv_cache, positions, attn_metadata.slot_mapping,
                    attn_metadata.block_table, attn_metadata.query_start_loc,
                    attn_metadata.context_lens)
        if hasattr(attn_metadata, 'do_cache'):
            # Presumably K/V were already written by the fused kernels —
            # TODO confirm against the torch_br kernel contract.
            attn_metadata.do_cache = False
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
    else:
        # No KV cache bound to this layer (e.g. memory profiling).
        return hidden_states
def model_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched ``Qwen3MoeModel.forward``.

    Same flow as upstream vLLM except the hidden states are carried as a
    3D ``[1, tokens, dim]`` tensor through the decoder layers (the Biren
    kernels index the sequence axis at ``dim -2``); the batch dim is
    stripped again before returning / sending across pipeline ranks.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    if len(hidden_states.shape) == 2:
        # Add a leading batch dim of 1 for the patched decoder layers.
        hidden_states = hidden_states.unsqueeze(0)
    for i in range(self.start_layer, self.end_layer):
        layer = self.layers[i]
        hidden_states, residual = layer(positions, hidden_states, residual)
    if not get_pp_group().is_last_rank:
        # Strip the batch dim before sending to the next pipeline rank.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.squeeze(0) if hidden_states is not None else None,
            "residual":
            residual.squeeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # NOTE: convert back to 2D
    hidden_states = hidden_states.squeeze()
    if hidden_states.dim() == 1:
        # A single token would be squeezed to 1D; restore [1, dim].
        hidden_states = hidden_states.unsqueeze(0)
    return hidden_states


Qwen3MoeModel.forward = model_forward
def Qwen3MoeModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Patched ``Qwen3MoeModel.load_weights``.

    Same stacked / expert checkpoint-name remapping as upstream vLLM,
    plus a Biren-specific step: loaded ``*norm.weight`` parameters are
    cast to float32 — presumably for the fused torch_br kernels; TODO
    confirm.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Params for weights, fp8 weight scales, fp8 activation scales
    # (param_name, weight_name, expert_id, shard_id)
    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.num_experts)
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if "mlp.experts" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param,
                              loaded_weight,
                              name,
                              shard_id=shard_id,
                              expert_id=expert_id)
                break
            else:
                # Plain (non-stacked, non-expert) parameter.
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        logger.warning_once(
                            "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.",  # noqa: E501
                            name,
                            remapped_kv_scale_name,
                        )
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                if name.find("norm.weight") != -1:
                    # Biren: norm weights kept in fp32 after loading.
                    param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params


Qwen3MoeModel.load_weights = Qwen3MoeModel_load_weights

View File

@@ -0,0 +1,207 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen3_vl import (Qwen3_VisionBlock,
Qwen3_VisionPatchEmbed,
Qwen3_VisionTransformer,
Qwen3LLMModel)
from vllm.sequence import IntermediateTensors
from vllm_br import envs
from .br_utils import convBB
def Qwen3_VisionPatchEmbed__init__(
    self,
    patch_size: int = 14,
    temporal_patch_size: int = 2,
    in_channels: int = 3,
    hidden_size: int = 1152,
) -> None:
    """Patched ``Qwen3_VisionPatchEmbed.__init__``.

    Replaces the upstream patch projection with a ``ReplicatedLinear``
    over the flattened patch (channels * temporal * spatial), so the
    checkpoint Conv weight is reshaped to 2D at load time (see the
    patched vision-transformer ``load_weights`` below).
    """
    super(Qwen3_VisionPatchEmbed, self).__init__()
    self.patch_size = patch_size
    self.temporal_patch_size = temporal_patch_size
    self.hidden_size = hidden_size
    # Input features: in_channels * temporal_patch_size * patch_size^2.
    self.proj = ReplicatedLinear(in_channels * temporal_patch_size *
                                 patch_size * patch_size,
                                 hidden_size,
                                 bias=True,
                                 prefix="")


Qwen3_VisionPatchEmbed.__init__ = Qwen3_VisionPatchEmbed__init__
def Qwen3_VisionPatchEmbed_forward(self, x: torch.Tensor) -> torch.Tensor:
    """Project flattened patches ``[L, features]`` to ``[L, hidden_size]``.

    Adds a leading batch dim for the linear projection and strips it on
    the way out.
    """
    x = x.unsqueeze(0)
    # L = number of patches (sequence length of the vision tower).
    L, _ = x.shape[-2], x.shape[-1]
    x = self.proj(x)[0].view(L, self.hidden_size)
    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        # Device-specific layout conversion on larger-SPC parts —
        # see br_utils.convBB; semantics not visible here, TODO confirm.
        x = convBB(x)
    return x


Qwen3_VisionPatchEmbed.forward = Qwen3_VisionPatchEmbed_forward
def Qwen3_VisionBlock_forward(
    self,
    x: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: torch.Tensor,
    max_seqlen: Optional[int] = None,  # Only used for Flash Attention
    seqlens: Optional[list[int]] = None,  # Only used for xFormers
) -> torch.Tensor:
    """Patched vision block: pre-norm attention + MLP with residuals.

    Differs from upstream only in the leading permute below.
    """
    if x.shape[0] != 1:
        # Looks like a [seq, batch, dim] -> [batch, seq, dim] transpose
        # when the input is not already batch-first — TODO confirm the
        # expected layout of the caller.
        x = x.permute(1, 0, 2).contiguous()
    x = x + self.attn(self.norm1(x),
                      cu_seqlens=cu_seqlens,
                      rotary_pos_emb=rotary_pos_emb,
                      max_seqlen=max_seqlen,
                      seqlens=seqlens)
    x = x + self.mlp(self.norm2(x))
    return x


Qwen3_VisionBlock.forward = Qwen3_VisionBlock_forward
def Qwen3_VisionTransformer_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Patched vision-transformer weight loading.

    Stacks q/k/v into the fused qkv projection; reshapes the checkpoint's
    Conv-style ``patch_embed.proj.weight`` to 2D for the ReplicatedLinear
    patch embed above; casts loaded ``*norm.weight`` to float32.
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("attn.qkv.", "attn.q.", "q"),
        ("attn.qkv.", "attn.k.", "k"),
        ("attn.qkv.", "attn.v.", "v"),
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    loaded_params: set[str] = set()
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if name == 'patch_embed.proj.weight':
                # Conv3d checkpoint weight -> 2D linear weight.
                loaded_weight = loaded_weight.reshape(loaded_weight.shape[0],
                                                      -1).contiguous()
            weight_loader(param, loaded_weight)
            if name.find("norm.weight") != -1:
                # Biren: norm weights kept in fp32 after loading.
                param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params


Qwen3_VisionTransformer.load_weights = Qwen3_VisionTransformer_load_weights
def Qwen3LLMModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    # args for deepstack
    deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched ``Qwen3LLMModel.forward``.

    Carries hidden states as 3D ``[1, tokens, dim]`` through the decoder
    layers (for the Biren fused kernels) and injects deepstack visual
    embeddings into the first ``len(deepstack_input_embeds)`` layers.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    # Add the leading batch dim of 1 expected by the patched layers.
    hidden_states = hidden_states.unsqueeze(0)
    residual = residual.unsqueeze(0) if residual is not None else None
    for layer_idx, layer in enumerate(
            self.layers[self.start_layer:self.end_layer]):
        layer_idx = layer_idx + self.start_layer
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )
        if deepstack_input_embeds is not None and \
                layer_idx in range(0, len(deepstack_input_embeds)):
            # Deepstack: add the visual embedding for this layer depth.
            hidden_states = hidden_states + deepstack_input_embeds[
                f"deepstack_input_embeds_{layer_idx}"].to(
                    hidden_states.device).unsqueeze(0)
    if not get_pp_group().is_last_rank:
        # NOTE(review): hidden_states is already 3D here, so unsqueeze(0)
        # yields a 4D tensor while the receiving rank unsqueezes again;
        # Qwen3MoeModel.model_forward squeezes instead — verify this path
        # under pipeline parallelism.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.unsqueeze(0),
            "residual":
            residual.unsqueeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # Strip the batch dim before handing back to the sampler.
    return hidden_states.squeeze(0)


Qwen3LLMModel.forward = Qwen3LLMModel_forward

View File

@@ -0,0 +1,258 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-VL MOE model compatible with HuggingFace weights."""
import typing
from collections.abc import Iterable
from typing import Callable, Optional, Union
import torch
from vllm.distributed import get_pp_group
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.qwen3_vl_moe import Qwen3MoeLLMModel
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors
def Qwen3MoeLLMModel_forward(
    self,
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    deepstack_input_embeds: Optional[IntermediateTensors] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    """Patched ``Qwen3MoeLLMModel.forward``.

    Same shape handling as the dense Qwen3-VL patch: 3D ``[1, tokens,
    dim]`` hidden states through the layers, deepstack visual embeddings
    injected into the earliest layers.
    """
    if get_pp_group().is_first_rank:
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None
    else:
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]
        residual = intermediate_tensors["residual"]
    # Add the leading batch dim of 1 expected by the patched layers.
    hidden_states = hidden_states.unsqueeze(0)
    residual = residual.unsqueeze(0) if residual is not None else None
    for layer_idx, layer in enumerate(
            self.layers[self.start_layer:self.end_layer]):
        layer_idx = layer_idx + self.start_layer
        hidden_states, residual = layer(
            positions,
            hidden_states,
            residual,
        )
        if deepstack_input_embeds is not None and \
                layer_idx in range(0, len(deepstack_input_embeds)):
            # Deepstack: add the visual embedding for this layer depth.
            hidden_states = hidden_states + deepstack_input_embeds[
                f"deepstack_input_embeds_{layer_idx}"].to(
                    hidden_states.device).unsqueeze(0)
    if not get_pp_group().is_last_rank:
        # NOTE(review): hidden_states is already 3D here, so unsqueeze(0)
        # yields a 4D tensor (cf. the squeeze in
        # Qwen3MoeModel.model_forward) — verify under pipeline parallelism.
        return IntermediateTensors({
            "hidden_states":
            hidden_states.unsqueeze(0),
            "residual":
            residual.unsqueeze(0) if residual is not None else None
        })
    hidden_states, _ = self.norm(hidden_states, residual)
    # Strip the batch dim before handing back to the sampler.
    return hidden_states.squeeze(0)


Qwen3MoeLLMModel.forward = Qwen3MoeLLMModel_forward
def Qwen3MoeLLMModel_load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
    """Patched ``Qwen3MoeLLMModel.load_weights``.

    Handles three checkpoint layouts in order: stacked q/k/v and
    gate/up projections; per-expert (or fused-expert) MoE weights;
    plain parameters.  Biren-specific additions: ``patch_embed`` conv
    weights are flattened to 2D, and loaded ``*norm.weight`` parameters
    are cast to float32 (presumably for the fused kernels — TODO
    confirm).
    """
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]
    # Skip loading extra parameters for GPTQ/modelopt models.
    ignore_suffixes = (".bias", "_bias", ".k_scale", "_k_scale", ".v_scale",
                       "_v_scale", ".weight_scale", "_weight_scale",
                       ".input_scale", "_input_scale")
    params_dict = dict(self.named_parameters())
    loaded_params: set[str] = set()
    expert_params_mapping = self.get_expert_mapping()
    is_fused_expert = False
    fused_expert_params_mapping = [
        ("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
        ("experts.w2_weight", "experts.down_proj", 0, "w2"),
    ]
    num_experts = self.config.num_experts
    for name, loaded_weight in weights:
        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            # Detect a fused-expert checkpoint and switch the expert
            # mapping accordingly (sticky for the rest of the load).
            if ("experts.gate_up_proj" in name or "experts.down_proj" in name):
                is_fused_expert = True
                expert_params_mapping = fused_expert_params_mapping
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if "mlp.experts" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra parameters for GPTQ/modelopt models.
            if name.endswith(ignore_suffixes) and name not in params_dict:
                continue
            # Skip layers on other devices.
            if is_pp_missing_parameter(name, self):
                continue
            if name.endswith("scale"):
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            if name not in params_dict:
                continue
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            if weight_loader == default_weight_loader:
                weight_loader(param, loaded_weight)
            else:
                weight_loader(param, loaded_weight, shard_id)
            break
        else:
            is_expert_weight = False
            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue
                # Anyway, this is an expert weight and should not be
                # attempted to load as other weights later
                is_expert_weight = True
                name_mapped = name.replace(weight_name, param_name)
                if is_pp_missing_parameter(name_mapped, self):
                    continue
                if is_fused_expert:
                    loaded_weight = loaded_weight.transpose(-1, -2)  # no bias
                    if "experts.gate_up_proj" in name:
                        # Fused gate_up is split into w1 (gate) and w3 (up).
                        loaded_weight = loaded_weight.chunk(2, dim=-2)
                        success_w1 = self.load_fused_expert_weights(
                            name_mapped, params_dict, loaded_weight[0], "w1",
                            num_experts)
                        success_w3 = self.load_fused_expert_weights(
                            name_mapped, params_dict, loaded_weight[1], "w3",
                            num_experts)
                        success = success_w1 and success_w3
                    else:
                        # down_proj
                        success = self.load_fused_expert_weights(
                            name_mapped, params_dict, loaded_weight, shard_id,
                            num_experts)
                else:
                    # Skip loading extra parameters for GPTQ/modelopt models
                    if name_mapped.endswith(
                            ignore_suffixes) and name_mapped not in params_dict:
                        continue
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or
                    # not here since otherwise we may skip experts with
                    # other available replicas.
                    weight_loader = typing.cast(Callable[..., bool],
                                                param.weight_loader)
                    success = weight_loader(param,
                                            loaded_weight,
                                            name_mapped,
                                            shard_id=shard_id,
                                            expert_id=expert_id,
                                            return_success=True)
                if success:
                    name = name_mapped
                    break
            else:
                if is_expert_weight:
                    # We've checked that this is an expert weight
                    # However it's not mapped locally to this rank
                    # So we simply skip it
                    continue
                # Skip loading extra parameters for GPTQ/modelopt models.
                if name.endswith(ignore_suffixes) and name not in params_dict:
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                # Remapping the name of FP8 kv-scale.
                if name.endswith("kv_scale"):
                    remapped_kv_scale_name = name.replace(
                        ".kv_scale", ".attn.kv_scale")
                    if remapped_kv_scale_name not in params_dict:
                        # Unmatched kv-scale entries are skipped silently
                        # (upstream logs a warning here).
                        continue
                    else:
                        name = remapped_kv_scale_name
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                if name == 'patch_embed.proj.weight':
                    # Conv-style checkpoint weight -> 2D linear weight for
                    # the patched ReplicatedLinear patch embed.
                    loaded_weight = loaded_weight.reshape(
                        loaded_weight.shape[0], -1).contiguous()
                weight_loader(param, loaded_weight)
                if name.find("norm.weight") != -1:
                    # Biren: norm weights kept in fp32 after loading.
                    param.data = param.data.to(torch.float32)
        loaded_params.add(name)
    return loaded_params


Qwen3MoeLLMModel.load_weights = Qwen3MoeLLMModel_load_weights

View File

@@ -0,0 +1,27 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from vllm import ModelRegistry
from vllm.model_executor.models.registry import _MULTIMODAL_MODELS
#from .glm4_1v import Glm4vForConditionalGeneration
# Re-point the in-tree multimodal registry entry for GLM-4V at the local
# module name so vLLM resolves the architecture to the patched implementation.
_MULTIMODAL_MODELS["Glm4vForConditionalGeneration"] = (
    "glm4_1v", "Glm4vForConditionalGeneration")
# Also register under the fully-qualified vllm_br path; ModelRegistry imports
# the class lazily from this "module:Class" string when the model is created.
ModelRegistry.register_model(
    "Glm4vForConditionalGeneration",
    "vllm_br.model_executor.models.glm4_1v:Glm4vForConditionalGeneration")

View File

@@ -0,0 +1,89 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import torch
# Adapted from transformers
from fastcore.basics import patch_to
import vllm
from vllm.model_executor.models.roberta import (
create_position_ids_from_input_ids)
@patch_to(vllm.model_executor.models.roberta.RobertaClassificationHead)
def forward(self, features, **kwargs):
    """Classify from the <s> (CLS) token of ``features``.

    The first row of ``features`` is used as the pooled representation and
    pushed through dense -> tanh -> out_proj.  A temporary batch dimension
    is added so the linear layers see a 2-D input, then removed again.
    """
    cls_token = features[0, :].unsqueeze(0)  # <s> token + batch dim
    hidden = torch.tanh(self.dense(cls_token))
    logits = self.out_proj(hidden)
    return logits.squeeze(0)  # drop the temporary batch dim
@patch_to(vllm.model_executor.models.roberta.RobertaEmbedding)
def forward(
    self,
    input_ids: torch.Tensor,
    seq_lens: torch.Tensor,
    position_ids: torch.Tensor,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Compute RoBERTa input embeddings for a packed batch of sequences.

    ``input_ids`` arrives with a leading batch dimension of 1 (it is
    squeezed immediately); ``seq_lens`` gives the length of each packed
    sequence so the flat token/position streams can be split per sequence.
    Incoming ``position_ids`` are asserted to be 0..N-1 per sequence and are
    replaced with RoBERTa-style ids that start at ``padding_idx + 1`` and
    skip padding tokens.  Returns embeddings with the batch dimension
    re-added (shape (1, total_tokens, hidden)) for the BR attention path.
    """
    input_ids = input_ids.squeeze(0)  # notice here input_ids is 2-dim tensor
    input_shape = input_ids.size()
    inputs_embeds = self.word_embeddings(input_ids)
    # Replace position ids because in RoBERTa models
    # they have to start at padding_idx + 1 and ignore
    # existing padding tokens
    # References:
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
    # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
    pos_list = []
    token_list = []
    offset = 0
    # Split the flat (packed) token/position streams into per-sequence slices.
    for seq_len in seq_lens:
        pos_list.append(position_ids[offset:offset + seq_len])
        token_list.append(input_ids[offset:offset + seq_len])
        offset += seq_len
    new_pos_list = []
    for positions, tokens in zip(pos_list, token_list, strict=False):
        # Verify assumption that incoming position are
        # always a sequence from 0 to N.
        expected_pos = torch.arange(positions.size()[0],
                                    dtype=torch.long,
                                    device=inputs_embeds.device)
        assert torch.equal(positions, expected_pos)
        new_pos_list.append(
            create_position_ids_from_input_ids(tokens, self.padding_idx))
    position_ids = torch.cat(new_pos_list)
    # Position embeddings.
    position_embeddings = self.position_embeddings(position_ids)
    if token_type_ids is None:
        # Default to a single segment (all-zero token type ids).
        token_type_ids = torch.zeros(input_shape,
                                     dtype=torch.long,
                                     device=inputs_embeds.device)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)
    embeddings = inputs_embeds + token_type_embeddings + position_embeddings
    embeddings = self.LayerNorm(embeddings)
    return embeddings.unsqueeze(0)  # add batch dimension for BR attention

View File

@@ -0,0 +1,25 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from .attention import AttentionSplit
from .mla import SupaMLAModules, SupaMultiHeadLatentAttention
from .mlp import LlamaMlpSiluL3, MergedGateUpMLPSiluL2
from .moe import DeepseekV2MoE
# Public API of the fused-layer package; keep in sync with the imports above.
__all__ = [
    'LlamaMlpSiluL3', 'AttentionSplit', 'MergedGateUpMLPSiluL2',
    'DeepseekV2MoE', 'SupaMLAModules', 'SupaMultiHeadLatentAttention'
]

View File

@@ -0,0 +1,206 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Any, Optional, Tuple
import torch
import torch_br
from torch import nn
from torch_br.supa.profiler_kineto import record_function
from vllm.attention import Attention, AttentionType
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import (MRotaryEmbedding,
get_rope)
from vllm.model_executor.models.utils import extract_layer_index
class AttentionSplit(nn.Module):
    """Attention block with *separate* q/k/v projections (instead of a fused
    qkv projection) so the Biren fused decode kernel can consume the three
    weight tensors directly.

    ``forward`` dispatches between a fused-kernel decode path
    (``torch_br.supa_qkv_rope_decode_infer``) and the standard
    q/k/v -> rotary -> attention -> o_proj path; see the conditions there.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        max_position: int = 4096 * 32,
        rope_theta: int = 10000,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        rope_scaling: Optional[Tuple] = None,
        attn_type: str = AttentionType.DECODER,
        prefix: str = "",
        dual_chunk_attention_config: Optional[dict[str, Any]] = None,
        bias: bool = False,
    ) -> None:
        """Build per-projection linears, rotary embedding and the Attention op.

        Head counts are given as *global* totals; they are divided by the
        tensor-parallel world size to obtain this rank's local counts.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        # Only quantize q/k/v when the quant config says the qkv projections
        # are quantized; o_proj always receives the full quant_config.
        qconfig = None
        if quant_config is not None and quant_config.qkv_quantized:
            qconfig = quant_config
        self.q_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.q_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.q_proj")
        self.k_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.kv_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.k_proj")
        self.v_proj = ColumnParallelLinear(input_size=hidden_size,
                                           output_size=self.kv_size * tp_size,
                                           bias=bias,
                                           quant_config=qconfig,
                                           prefix=f"{prefix}.v_proj")
        self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            attn_type=attn_type,
            prefix=f"{prefix}.attn",
            # Extra kwargs only when dual-chunk attention is configured.
            **{
                "layer_idx": extract_layer_index(prefix),
                "dual_chunk_attention_config": dual_chunk_attention_config,
            } if dual_chunk_attention_config else {})

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention for ``hidden_states`` at ``positions``.

        Fused path is taken when ALL hold: the q weight tensor is 3-D
        (NUMA weight layout — presumably set up by the BR weight loader;
        TODO confirm), rotary is not MRotaryEmbedding (i.e. not qwen-vl
        mrope), and seq_len <= 512.  Otherwise falls back to the standard
        unfused path.
        """
        forward_context: ForwardContext = get_forward_context()
        attn_metadata = forward_context.attn_metadata
        if attn_metadata is None:
            ## for dummy run
            return hidden_states
        seq_len = hidden_states.shape[-2]
        decode_seql = 512
        # numa weight and not use mrope (qwen-vl)
        if ((hasattr(self.q_proj, "qweight")
             and len(self.q_proj.qweight.shape) == 3) or
            (hasattr(self.q_proj, "weight")
             and len(self.q_proj.weight.shape) == 3)) and not isinstance(
                 self.rotary_emb, MRotaryEmbedding) and seq_len <= decode_seql:
            if isinstance(attn_metadata, dict):
                attn_metadata = attn_metadata[self.attn.layer_name]
            kv_cache = self.attn.kv_cache[forward_context.virtual_engine]
            if kv_cache is not None:
                with record_function('attention qkv_rope'):
                    # int8 weight version
                    # Prefer quantized qweight/scales when present, else the
                    # plain fp weight with no scales.
                    q_weight = self.q_proj.qweight if hasattr(
                        self.q_proj, "qweight") else self.q_proj.weight
                    k_weight = self.k_proj.qweight if hasattr(
                        self.k_proj, "qweight") else self.k_proj.weight
                    v_weight = self.v_proj.qweight if hasattr(
                        self.v_proj, "qweight") else self.v_proj.weight
                    q_scale = self.q_proj.scales if hasattr(
                        self.q_proj, "scales") else None
                    k_scale = self.k_proj.scales if hasattr(
                        self.k_proj, "scales") else None
                    v_scale = self.v_proj.scales if hasattr(
                        self.v_proj, "scales") else None
                    q_bias = self.q_proj.bias if hasattr(self.q_proj,
                                                         "bias") else None
                    k_bias = self.k_proj.bias if hasattr(self.k_proj,
                                                         "bias") else None
                    v_bias = self.v_proj.bias if hasattr(self.v_proj,
                                                         "bias") else None
                    # Single fused kernel: qkv projection + rope + KV-cache
                    # write (via slot_mapping).
                    q, k, v = torch_br.supa_qkv_rope_decode_infer(
                        hidden_states,
                        q_weight,
                        k_weight,
                        v_weight,
                        self.rotary_emb.sin_cache,
                        self.rotary_emb.cos_cache,
                        kv_cache,
                        positions,
                        attn_metadata.slot_mapping,
                        self.rotary_emb.head_size,
                        self.q_size,
                        self.kv_size,
                        q_scale=q_scale,
                        k_scale=k_scale,
                        v_scale=v_scale,
                        q_bias=q_bias,
                        k_bias=k_bias,
                        v_bias=v_bias)
                if hasattr(attn_metadata, 'do_cache'):
                    # KV cache already filled by the fused kernel above, so
                    # tell the attention op not to cache again.
                    attn_metadata.do_cache = False
                with record_function('attention'):
                    attn_output = self.attn(q, k, v)
                with record_function('attention o_proj'):
                    output, _ = self.o_proj(attn_output)
                return output
            else:
                # No KV cache bound for this virtual engine — pass input
                # through unchanged (e.g. profiling/dummy execution).
                return hidden_states
        else:
            # uma weight or use mrope (qwen-vl)
            q, _ = self.q_proj(hidden_states)
            k, _ = self.k_proj(hidden_states)
            v, _ = self.v_proj(hidden_states)
            q, k = self.rotary_emb(positions, q, k)
            if hasattr(attn_metadata, 'do_cache'):
                attn_metadata.do_cache = True
            attn_output = self.attn(q, k, v)
            output, _ = self.o_proj(attn_output)
            return output

View File

@@ -0,0 +1,210 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.attention import Attention
from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig
@dataclass
class SupaMLAModules(MLAModules):
    """Extends the upstream MLAModules bundle with the unfused q_a_proj so
    the Biren fused-MLA attention backend can receive it separately."""

    # Low-rank q down-projection; None when q_lora_rank is None.
    q_a_proj: Optional[torch.nn.Module]
@CustomOp.register("supa_multi_head_latent_attention")
class SupaMultiHeadLatentAttention(CustomOp):
    """Multi-head latent attention (MLA) custom op for Biren devices.

    Two execution modes:
    - ``forward_native``: computes q / latent kv / rope in Python and calls
      the MLA ``Attention`` op (used for the sparse / DS-v3.2 path).
    - ``forward_supa``: hands ``hidden_states`` straight to the fused MLA
      attention, which was constructed with the projection and norm modules
      so the whole pipeline runs inside the backend.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        scale: float,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        mla_modules: MLAModules,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Unpack the MLA module bundle and build the Attention op.

        ``mla_modules`` carries all projection/norm submodules; the sparse
        variant omits the BIREN fused-kernel extra args, the dense variant
        passes them so the backend can fuse the projections.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        self.fused_qkv_a_proj = mla_modules.fused_qkv_a_proj
        self.kv_a_proj_with_mqa = mla_modules.kv_a_proj_with_mqa
        self.q_a_layernorm = mla_modules.q_a_layernorm
        self.q_b_proj = mla_modules.q_b_proj
        self.q_proj = mla_modules.q_proj
        self.kv_a_layernorm = mla_modules.kv_a_layernorm
        self.kv_b_proj = mla_modules.kv_b_proj
        self.rotary_emb = mla_modules.rotary_emb
        self.o_proj = mla_modules.o_proj
        self.indexer = mla_modules.indexer
        self.is_sparse = mla_modules.is_sparse
        self.q_a_proj = mla_modules.q_a_proj
        if self.indexer is not None:
            assert hasattr(self.indexer, "topk_tokens")
            self.topk_tokens = self.indexer.topk_tokens
        self.topk_indices_buffer = mla_modules.topk_indices_buffer
        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        # k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        # kv_lora_rank + qk_rope_head_dim == head_size
        if self.is_sparse:
            self.mla_attn = Attention(
                num_heads=self.num_heads,
                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                scale=scale,
                num_kv_heads=1,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                use_mla=True,
                use_sparse=mla_modules.is_sparse,
                # MLA Args
                q_lora_rank=self.q_lora_rank,
                kv_lora_rank=self.kv_lora_rank,
                qk_nope_head_dim=self.qk_nope_head_dim,
                qk_rope_head_dim=self.qk_rope_head_dim,
                qk_head_dim=self.qk_head_dim,
                v_head_dim=self.v_head_dim,
                kv_b_proj=self.kv_b_proj,
                indexer=self.indexer,
            )
        else:
            self.mla_attn = Attention(
                num_heads=self.num_heads,
                head_size=self.kv_lora_rank + self.qk_rope_head_dim,
                scale=scale,
                num_kv_heads=1,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                use_mla=True,
                use_sparse=mla_modules.is_sparse,
                # MLA Args
                q_lora_rank=self.q_lora_rank,
                kv_lora_rank=self.kv_lora_rank,
                qk_nope_head_dim=self.qk_nope_head_dim,
                qk_rope_head_dim=self.qk_rope_head_dim,
                qk_head_dim=self.qk_head_dim,
                v_head_dim=self.v_head_dim,
                kv_b_proj=self.kv_b_proj,
                indexer=self.indexer,
                # BIREN args for fused MLA
                rotary_emb=self.rotary_emb,
                q_proj=self.q_proj
                if self.q_lora_rank is None else self.q_b_proj,
                o_proj=self.o_proj,
                kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
                kv_a_layernorm=self.kv_a_layernorm,
                q_a_proj=None if self.q_lora_rank is None else self.q_a_proj,
                q_a_layernorm=None
                if self.q_lora_rank is None else self.q_a_layernorm,
            )
        self.prefix = prefix
        # Layer index parsed from a prefix like "...layers.<idx>.self_attn".
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward_native(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Unfused MLA: compute q and the latent kv in Python, apply rope,
        optionally run the sparse indexer, then call the MLA attention op
        and the output projection."""
        q_c = None
        kv_lora = None
        if self.q_lora_rank is not None:
            assert self.fused_qkv_a_proj is not None, \
                "fused_qkv_a_proj is required when q_lora_rank is not None"
            assert self.q_a_layernorm is not None, \
                "q_a_layernorm is required when q_lora_rank is not None"
            assert self.q_b_proj is not None, \
                "q_b_proj is required when q_lora_rank is not None"
            # One projection produces both the low-rank q and the latent kv.
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_lora = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            q_c = self.q_a_layernorm(q_c)
            q = self.q_b_proj(q_c)[0].view(-1,
                                           self.num_heads * self.qk_head_dim)
        else:
            assert self.kv_a_proj_with_mqa is not None, \
                "kv_a_proj_with_mqa is required when q_lora_rank is None"
            assert self.q_proj is not None, \
                "q_proj is required when q_lora_rank is None"
            kv_lora = self.kv_a_proj_with_mqa(hidden_states)[0]
            q = self.q_proj(hidden_states)[0]
        kv_lora = kv_lora.view(-1, self.kv_lora_rank + self.qk_rope_head_dim)
        kv_c, k_pe = kv_lora.split([self.kv_lora_rank, self.qk_rope_head_dim],
                                   dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c)
        q = q.view(-1, self.num_heads, self.qk_head_dim)
        # Add head dim of 1 to k_pe
        k_pe = k_pe.unsqueeze(1)
        # Rope only on the rope part of q (in place via the returned slice).
        q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb(
            positions, q[..., self.qk_nope_head_dim:], k_pe)
        if self.indexer and self.is_sparse:
            # Indexer fills topk_indices_buffer as a side effect; the return
            # value itself is unused here.
            _topk_indices = self.indexer(hidden_states, q_c, positions,
                                         self.rotary_emb)
        seq_len = hidden_states.shape[1]
        attn_out = self.mla_attn(q,
                                 kv_c_normed,
                                 k_pe,
                                 output_shape=(seq_len, self.num_heads *
                                               self.v_head_dim))
        return self.o_proj(attn_out)[0].unsqueeze(0)

    def forward_supa(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Fully fused MLA: the attention op (constructed with all projection
        modules) consumes hidden_states and positions directly."""
        return self.mla_attn(hidden_states,
                             positions,
                             hidden_states,
                             output_shape=hidden_states.shape)

    def forward_oot(self, *args, is_ds_v32: Optional[int], **kwargs):
        """Out-of-tree dispatch: DS-v3.2 uses the unfused native path,
        everything else the fused supa path."""
        if is_ds_v32:
            return self.forward_native(*args, **kwargs)
        else:
            return self.forward_supa(*args, **kwargs)

View File

@@ -0,0 +1,170 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch_br
from torch import nn
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import (get_pp_group,
get_tensor_model_parallel_rank)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm_br import envs
from vllm_br.utils import get_grandparent_pid
class LlamaMlpSiluL3(nn.Module):
    """Llama-style gated MLP with separate gate/up projections.

    The activation is computed by the Biren fused silu-mul kernel
    (``torch_br.supa_silumul``) instead of the generic ``SiluAndMul`` op;
    the latter is still instantiated so the module layout matches the
    upstream implementation.  Only ``silu`` is supported.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Gate and up projections are column-parallel; down is row-parallel.
        self.gate_proj = ColumnParallelLinear(input_size=hidden_size,
                                              output_size=intermediate_size,
                                              bias=bias,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.gate_proj")
        self.up_proj = ColumnParallelLinear(input_size=hidden_size,
                                            output_size=intermediate_size,
                                            bias=bias,
                                            quant_config=quant_config,
                                            prefix=f"{prefix}.up_proj")
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projections, fused silu*mul, then down projection."""
        gate_out, _ = self.gate_proj(x)
        up_out, _ = self.up_proj(x)
        activated = torch_br.supa_silumul(gate_out, up_out)
        out, _ = self.down_proj(activated)
        return out
class MergedGateUpMLPSiluL2(nn.Module):
    """SiLU-gated MLP with a merged gate+up projection, run through Biren
    fused kernels.

    For regular layers the forward pass runs ``br_fused_mlp_infer`` (merged
    gate_up + swiglu) and then either a fused linear+all-reduce kernel or a
    plain down projection followed by all-reduce, depending on topology/env
    limits.  For shared-expert layers (``"shared_experts"`` in ``prefix``)
    no compute happens: the forward returns the raw (gate_up, down) weight
    tensors, which the caller's fused-MoE kernel consumes directly.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.intermediate_size = intermediate_size
        self.prefix = prefix
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj")
        # Flag read elsewhere (presumably by the BR weight loader for the
        # interleaved gate/up layout) — TODO confirm against the loader.
        self.gate_up_proj.has_cross_weight = True
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=bias,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results,
                                           prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # Lazily capture the grandparent PID once; the CPU/PCIe all-reduce
        # kernel needs it to identify the process group.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()
        if "shared_experts" not in self.prefix:
            # Quantized layers expose qweight/scales; fp layers only weight.
            quant_flag = hasattr(self.gate_up_proj, "qweight")
            hidden_size = x.shape[-1]
            seq_len = x.shape[-2]
            gu_weight = self.gate_up_proj.qweight if quant_flag else self.gate_up_proj.weight
            gu_scales = self.gate_up_proj.scales if quant_flag else None
            # Fused gate_up matmul + swiglu activation in one kernel.
            gate_up_output = torch_br.br_fused_mlp_infer(
                x, [gu_weight],
                output_w=self.intermediate_size // self.tp_size,
                scales=[gu_scales] if gu_scales is not None else None,
                activation_mode="act_swiglu")
            down_weight = self.down_proj.qweight if quant_flag else self.down_proj.weight
            down_scales = self.down_proj.scales if quant_flag else None
            # bypass tp8 and tp4pp2 allreduce
            pp_size = get_pp_group().world_size
            all_rank = self.tp_size * pp_size
            # Supported (device SPC count, tp_size) pairs for the fused
            # linear+all-reduce kernel.
            support_types = ((16, 4), (32, 2), (32, 4))
            if all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and \
                (envs.VLLM_BR_DEVICE_SPC_NUM, self.tp_size) in support_types:
                tp_rank = get_tp_group().rank_in_group
                global_rank = get_tp_group().rank
                rank_i = global_rank % self.tp_size
                assert rank_i == tp_rank
                # Down projection and all-reduce fused into one kernel.
                down_output = torch_br.supa_fused_linear_allreduce_opt(
                    gate_up_output,
                    down_weight,
                    hidden_size,
                    tp_rank,
                    self.tp_size,
                    global_rank,
                    0,
                    scales=down_scales)
                return down_output
            else:
                # Unfused fallback: down projection, then explicit reduce.
                down_output = torch_br.br_fused_mlp_infer(
                    gate_up_output, [down_weight],
                    output_w=hidden_size,
                    scales=[down_scales] if down_scales is not None else None)
                if self.tp_size > 1:
                    out = down_output
                    # Small decode batches on tp>=4 use the CPU/PCIe
                    # all-reduce path when enabled.
                    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4 and out.shape[
                            1] <= 32:
                        tp_rank = get_tensor_model_parallel_rank()
                        output = torch_br.supa_allreduce_pcie_infer(
                            out, tp_rank, self.tp_size, self.grandparent_pid)
                    else:
                        output = tensor_model_parallel_all_reduce(out)
                    return output
                else:
                    return down_output
        else:
            # Shared-experts layer: return the weights themselves for the
            # caller's fused-MoE kernel (no computation here).
            return self.gate_up_proj.weight, self.down_proj.weight

View File

@@ -0,0 +1,116 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2MLP,
ParallelConfig)
from vllm_br import envs
from vllm_br.utils import get_grandparent_pid
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 MoE layer for Biren devices.

    The router gate and shared experts are NOT invoked as separate modules
    in ``forward``; their weight tensors are packed into a tuple and handed
    to the fused ``FusedMoE`` op via the ``router_logits`` argument, so
    routing, expert compute and shared-expert compute all happen inside the
    fused kernel.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        parallel_config: ParallelConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.static_moe_decoder_max_len = 512
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")
        # Router gate is replicated (not TP-sharded) and never quantized.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            # Bias correction for the aux-free top-k method; kept on CPU.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, device="cpu"))
        else:
            self.gate.e_score_correction_bias = None
        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias)
        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the fused MoE (router + experts + shared experts) and reduce
        across TP ranks unless the fused kernel already did."""
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()
        orig_shape = hidden_states.shape
        assert self.n_shared_experts is not None, 'n_shared_experts must be set'
        # NOTE: gate has been fused with shared_experts, no more single gate call
        # and we packed router weights, shared_experts weights and down weights in a tuple
        tuple_router_shared_expert_weight = (
            self.gate.weight, self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight)
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=tuple_router_shared_expert_weight)
        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that the final_hidden_states has been reduced in fused_moe
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)
        return final_hidden_states.view(orig_shape)

View File

@@ -0,0 +1,86 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
# SPDX-License-Identifier: Apache-2.0
# Adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
# Copyright 2025 The vLLM team.
# Copyright 2025 The Qwen Team.
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen3-VL model compatible with HuggingFace weights."""
import torch
import vllm
from vllm.model_executor.models.utils import (_embedding_count_expression,
_flatten_embeddings)
from vllm.multimodal import NestedTensors
def _merge_multimodal_embeddings_fit(
    inputs_embeds: torch.Tensor,
    is_multimodal: torch.Tensor,
    multimodal_embeddings: NestedTensors,
) -> torch.Tensor:
    """Overwrite the placeholder rows of ``inputs_embeds`` (selected by the
    boolean mask ``is_multimodal``) with the flattened multimodal embeddings.

    Note:
        ``inputs_embeds`` is modified in place and also returned.
    """
    flattened = _flatten_embeddings(multimodal_embeddings)
    try:
        # Plain boolean-mask assignment (equivalent to the masked_scatter_
        # variant used upstream, but supported on the BR backend).
        inputs_embeds[is_multimodal] = flattened
    except RuntimeError as e:
        num_expected_tokens = is_multimodal.sum().item()
        assert isinstance(num_expected_tokens, int)
        if flattened.shape[0] == num_expected_tokens:
            # Counts match, so the failure came from the scatter itself.
            raise ValueError("Error during masked scatter operation") from e
        expr = _embedding_count_expression(multimodal_embeddings)
        raise ValueError(
            f"Attempted to assign {expr} = {flattened.shape[0]} "
            f"multimodal tokens to {num_expected_tokens} placeholders"
        ) from e
    return inputs_embeds
# Monkeypatch: replace vLLM's merge helper with the BR-compatible variant
# defined in this module (boolean-mask assignment instead of masked_scatter_).
vllm.model_executor.models.utils._merge_multimodal_embeddings = _merge_multimodal_embeddings_fit