enginex-vastai-va16-vllm/vllm_vacc/vllm/model_executor/models/bert.py

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn

from vllm.distributed import (get_tp_group,
                              tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context

from .vars import *


class BertLayer(nn.Module):

    def forward(self, hidden_states: torch.Tensor):
        if USE_FUSED_BERT_ATTENTION:
            tp_group = get_tp_group()
            world_size = tp_group.world_size
            rank = tp_group.rank_in_group
            total_bytes = (hidden_states.numel() *
                           hidden_states.element_size() * world_size)
            forward_context: ForwardContext = get_forward_context()
            attn_metadata_all = forward_context.attn_metadata
            if isinstance(attn_metadata_all, dict):
                attn_metadata = next(iter(attn_metadata_all.values()))
            else:
                attn_metadata = attn_metadata_all
            # For the (matmul + bias_add) with TP + all_reduce structure, the
            # bias is passed only to rank 0 so that it is added exactly once
            # rather than once per rank. In a BERT layer this structure appears
            # in the BertSelfOutput and BertOutput modules; the corresponding
            # bias arguments are self_bias and output_bias below.
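            # Illustrative sizing for the threshold check below (assumed
            # values, not taken from the model): a fp16 activation of shape
            # (1024 tokens, 1024 hidden) on 4 ranks gives total_bytes =
            # 1024 * 1024 * 2 * 4 = 8 MiB, so the external VCCL all_reduce
            # path is taken; at 256 tokens it is 2 MiB and the fused DSP
            # all_reduce path is used instead.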
            if total_bytes < 4 * 1024 * 1024 or world_size == 1:
                # 1. With TP, when the all_reduce input is smaller than 4 MiB,
                #    the fused op below performs the all_reduce itself on the
                #    DSP; at 4 MiB or larger, a hardware limit requires calling
                #    the VCCL all_reduce outside the op.
                # 2. Without TP, the fused op below is used as well.
                output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=(self.attention.output.dense.bias
                               if rank == 0 else torch.Tensor()),
                    self_norm_weight=self.attention.output.LayerNorm.weight,
                    self_norm_bias=self.attention.output.LayerNorm.bias,
                    intermediate_weight=self.intermediate.dense.weight,
                    intermediate_bias=self.intermediate.dense.bias,
                    output_weight=self.output.dense.weight,
                    output_bias=(self.output.dense.bias
                                 if rank == 0 else torch.Tensor()),
                    output_norm_weight=self.output.LayerNorm.weight,
                    output_norm_bias=self.output.LayerNorm.bias,
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=world_size > 1,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
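                # FullStage is handed every weight in the layer, so this
                # single launch covers attention, both residual LayerNorms and
                # the MLP; with reduce_result=True it also performs the DSP
                # all_reduce internally under TP, so no external reduction is
                # needed on this path.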
            else:
                attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=(self.attention.output.dense.bias
                               if rank == 0 else torch.Tensor()),
                    self_norm_weight=torch.Tensor(),
                    self_norm_bias=torch.Tensor(),
                    intermediate_weight=torch.Tensor(),
                    intermediate_bias=torch.Tensor(),
                    output_weight=torch.Tensor(),
                    output_bias=torch.Tensor(),
                    output_norm_weight=torch.Tensor(),
                    output_norm_bias=torch.Tensor(),
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=False,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    attn_out_stage_output = tensor_model_parallel_all_reduce(
                        attn_out_stage_output)
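                # attn_out_stage_output now holds the attention block's dense
                # projection, reduced across ranks above when TP is enabled;
                # its residual add and LayerNorm have not been applied yet. In
                # the non-fused-MLP branch below, InterOutStage receives this
                # tensor via dense_out together with the attention LayerNorm
                # weights, so that op appears to apply the residual add and
                # LayerNorm internally.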
                if USE_FUSED_MLP_VISION:
                    attn_output = self.attention.output.LayerNorm(
                        attn_out_stage_output + hidden_states)
                    inter_out_stage_output = torch.vacc.fuse_mlp_vision(
                        src=attn_output,
                        weights_13=self.intermediate.dense.weight,
                        weights_2=self.output.dense.weight,
                        weights_13_bias=self.intermediate.dense.bias,
                        weights_2_bias=(self.output.dense.bias
                                        if rank == 0 else torch.Tensor()),
                        act_type=0)  # 0 selects GELU
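                    # fuse_mlp_vision covers only the MLP
                    # (dense -> activation -> dense); the final residual add
                    # and LayerNorm are applied explicitly after the
                    # all_reduce below.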
                else:
                    inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                        hidden_states=hidden_states,
                        qkv_weight=torch.Tensor(),
                        qkv_bias=torch.Tensor(),
                        self_weight=torch.Tensor(),
                        self_bias=torch.Tensor(),
                        self_norm_weight=self.attention.output.LayerNorm.weight,
                        self_norm_bias=self.attention.output.LayerNorm.bias,
                        intermediate_weight=self.intermediate.dense.weight,
                        intermediate_bias=self.intermediate.dense.bias,
                        output_weight=self.output.dense.weight,
                        output_bias=(self.output.dense.bias
                                     if rank == 0 else torch.Tensor()),
                        output_norm_weight=torch.Tensor(),
                        output_norm_bias=torch.Tensor(),
                        dense_out=attn_out_stage_output,
                        seqs=attn_metadata.seq_lens,
                        vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
                        sm_scale=self.attention.self.scaling,
                        num_q_heads=self.attention.self.num_heads * world_size,
                        num_kv_heads=self.attention.self.num_kv_heads * world_size,
                        flash_attention=False,
                        reduce_result=False,
                        world_size=world_size,
                        rank=rank,
                        group_id=tp_group.group_id,
                        dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    inter_out_stage_output = tensor_model_parallel_all_reduce(
                        inter_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    output = self.output.LayerNorm(
                        inter_out_stage_output + attn_output)
                else:
                    output = self.output.LayerNorm(
                        inter_out_stage_output + attn_out_stage_output)
        else:
            attn_output = self.attention(hidden_states)
            intermediate_output = self.intermediate(attn_output)
            output = self.output(intermediate_output, attn_output)
        return output
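

# The sketch below is not part of the model; it is a minimal, self-contained
# illustration of the rank-0 bias rule used in the fused calls above. With
# tensor parallelism, a row-parallel linear computes a partial matmul on every
# rank and then all_reduces the partial sums, so a bias added on every rank
# would end up in the result world_size times. Passing the bias only on rank 0
# keeps exactly one copy after the reduction. All names and shapes here are
# illustrative assumptions, not the fused kernel's internals.
def _rank0_bias_demo(world_size: int = 4) -> None:
    torch.manual_seed(0)
    x = torch.randn(2, 8)                   # full activation (tokens, hidden)
    w = torch.randn(8, 8)                   # full (out, in) weight
    b = torch.randn(8)                      # bias, to be added exactly once
    x_shards = x.chunk(world_size, dim=-1)  # each rank's input-feature slice
    w_shards = w.chunk(world_size, dim=-1)  # matching weight column slice
    partials = [
        xs @ ws.t() + (b if rank == 0 else 0)  # bias only on rank 0
        for rank, (xs, ws) in enumerate(zip(x_shards, w_shards))
    ]
    reduced = torch.stack(partials).sum(0)  # stands in for the all_reduce
    assert torch.allclose(reduced, x @ w.t() + b, atol=1e-5)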