# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn

from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.forward_context import ForwardContext, get_forward_context

from .vars import USE_FUSED_BERT_ATTENTION, USE_FUSED_MLP_VISION


class BertLayer(nn.Module):

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if USE_FUSED_BERT_ATTENTION:
            tp_group = get_tp_group()
            world_size = tp_group.world_size
            rank = tp_group.rank_in_group
            # Size of the tensor that would be all-reduced across the TP group.
            total_bytes = (hidden_states.numel() *
                           hidden_states.element_size() * world_size)
            forward_context: ForwardContext = get_forward_context()
            attn_metadata_all = forward_context.attn_metadata
            if isinstance(attn_metadata_all, dict):
                # attn_metadata may be a dict keyed by layer name; take the
                # first entry.
                attn_metadata = next(iter(attn_metadata_all.values()))
            else:
                attn_metadata = attn_metadata_all
            # In the (matmul + bias_add) with TP + all_reduce pattern, the bias
            # must be added exactly once, so it is passed only to rank 0; all
            # other ranks pass an empty tensor. The BertSelfOutput and
            # BertOutput modules of the BERT layer follow this pattern, and the
            # corresponding arguments below are self_bias and output_bias (see
            # the illustrative sketch at the end of this file).
            if total_bytes < 4194304 or world_size == 1:  # 4 MiB
                # 1. With TP, when the all_reduce input is smaller than 4 MB,
                #    the fused op below performs the dsp all_reduce itself; at
                #    4 MB or more that is not supported, and the vccl
                #    all_reduce has to be issued outside the op (else branch).
                # 2. Without TP, the same fused op is used as well.
                output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=self.attention.output.dense.bias
                    if rank == 0 else torch.Tensor(),
                    self_norm_weight=self.attention.output.LayerNorm.weight,
                    self_norm_bias=self.attention.output.LayerNorm.bias,
                    intermediate_weight=self.intermediate.dense.weight,
                    intermediate_bias=self.intermediate.dense.bias,
                    output_weight=self.output.dense.weight,
                    output_bias=self.output.dense.bias
                    if rank == 0 else torch.Tensor(),
                    output_norm_weight=self.output.LayerNorm.weight,
                    output_norm_bias=self.output.LayerNorm.bias,
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=world_size > 1,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
            else:
                # Stage 1: QKV projection, attention, and attention output
                # dense; the all_reduce is issued outside the fused op.
                attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=self.attention.output.dense.bias
                    if rank == 0 else torch.Tensor(),
                    self_norm_weight=torch.Tensor(),
                    self_norm_bias=torch.Tensor(),
                    intermediate_weight=torch.Tensor(),
                    intermediate_bias=torch.Tensor(),
                    output_weight=torch.Tensor(),
                    output_bias=torch.Tensor(),
                    output_norm_weight=torch.Tensor(),
                    output_norm_bias=torch.Tensor(),
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=False,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    attn_out_stage_output = tensor_model_parallel_all_reduce(
                        attn_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    # Residual add + attention LayerNorm, then the fused MLP
                    # (intermediate dense with GELU + output dense).
                    attn_output = self.attention.output.LayerNorm(
                        attn_out_stage_output + hidden_states)
                    inter_out_stage_output = torch.vacc.fuse_mlp_vision(
                        src=attn_output,
                        weights_13=self.intermediate.dense.weight,
                        weights_2=self.output.dense.weight,
                        weights_13_bias=self.intermediate.dense.bias,
                        weights_2_bias=self.output.dense.bias
                        if rank == 0 else torch.Tensor(),
                        act_type=0)  # gelu
                else:
                    # Stage 2: attention LayerNorm, intermediate dense, and
                    # output dense; dense_out carries the stage-1 result.
                    inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                        hidden_states=hidden_states,
                        qkv_weight=torch.Tensor(),
                        qkv_bias=torch.Tensor(),
                        self_weight=torch.Tensor(),
                        self_bias=torch.Tensor(),
                        self_norm_weight=self.attention.output.LayerNorm.weight,
                        self_norm_bias=self.attention.output.LayerNorm.bias,
                        intermediate_weight=self.intermediate.dense.weight,
                        intermediate_bias=self.intermediate.dense.bias,
                        output_weight=self.output.dense.weight,
                        output_bias=self.output.dense.bias
                        if rank == 0 else torch.Tensor(),
                        output_norm_weight=torch.Tensor(),
                        output_norm_bias=torch.Tensor(),
                        dense_out=attn_out_stage_output,
                        seqs=attn_metadata.seq_lens,
                        vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
                        sm_scale=self.attention.self.scaling,
                        num_q_heads=self.attention.self.num_heads * world_size,
                        num_kv_heads=self.attention.self.num_kv_heads *
                        world_size,
                        flash_attention=False,
                        reduce_result=False,
                        world_size=world_size,
                        rank=rank,
                        group_id=tp_group.group_id,
                        dev_info=tp_group.rank_device_infos)
                if world_size > 1:
                    inter_out_stage_output = tensor_model_parallel_all_reduce(
                        inter_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    output = self.output.LayerNorm(inter_out_stage_output +
                                                   attn_output)
                else:
                    output = self.output.LayerNorm(inter_out_stage_output +
                                                   attn_out_stage_output)
        else:
            # Unfused reference path: standard BERT attention -> intermediate
            # -> output modules.
            attn_output = self.attention(hidden_states)
            intermediate_output = self.intermediate(attn_output)
            output = self.output(intermediate_output, attn_output)
        return output
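

# ---------------------------------------------------------------------------
# Illustration only (not part of the model, never executed on import): a
# minimal sketch of why the dense biases above are passed only to rank 0.
# With tensor parallelism the output dense is row-parallel: every rank
# computes a partial matmul and the partials are summed by all_reduce, so a
# bias added on every rank would end up in the result world_size times.
# The names below (world_size, x, w, b, ...) are hypothetical and chosen
# purely for illustration; the all_reduce is emulated by summing the
# per-rank partials on a single device.
if __name__ == "__main__":
    world_size, in_features, out_features = 4, 8, 6
    x = torch.randn(2, in_features)
    w = torch.randn(out_features, in_features)
    b = torch.randn(out_features)

    # Reference: the unsharded linear layer.
    ref = x @ w.t() + b

    # Row-parallel sharding: each rank owns a slice of the reduction dim.
    x_shards = x.chunk(world_size, dim=1)
    w_shards = w.chunk(world_size, dim=1)

    # Each rank adds the bias only if it is rank 0; summing the partials
    # plays the role of the all_reduce.
    partials = [
        x_shards[r] @ w_shards[r].t() + (b if r == 0 else 0)
        for r in range(world_size)
    ]
    reduced = torch.stack(partials).sum(dim=0)
    assert torch.allclose(ref, reduced, atol=1e-5)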