# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch import nn
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
from vllm.forward_context import ForwardContext, get_forward_context
from .vars import *
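

# USE_FUSED_BERT_ATTENTION and USE_FUSED_MLP_VISION are feature flags expected
# to be provided by the .vars star import above.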
class BertLayer(nn.Module):

    def forward(self, hidden_states: torch.Tensor):
        if USE_FUSED_BERT_ATTENTION:
            tp_group = get_tp_group()
            world_size = tp_group.world_size
            rank = tp_group.rank_in_group
            total_bytes = hidden_states.numel() * hidden_states.element_size() * world_size
            forward_context: ForwardContext = get_forward_context()
            attn_metadata_all = forward_context.attn_metadata
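            # forward_context.attn_metadata may be a per-layer dict in newer vLLM;
            # any entry works here because only seq_lens is consumed below.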
            if isinstance(attn_metadata_all, dict):
                attn_metadata = next(iter(attn_metadata_all.values()))
            else:
                attn_metadata = attn_metadata_all
            # For the (matmul + bias_add) -> TP all_reduce structure, the bias is handed
            # only to rank 0 so it is not added once per rank after the all_reduce.
            # In the BERT layer, BertSelfOutput and BertOutput have this structure; the
            # corresponding arguments below are self_bias and output_bias.
            if total_bytes < 4194304 or world_size == 1:
                # 1. With TP, an all_reduce input smaller than 4 MB is reduced inside the
                #    fused op below via dsp all_reduce; at 4 MB or larger a limitation
                #    requires calling the vccl all_reduce outside the op instead.
                # 2. The fused op below is also used when TP is disabled.
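                # FullStage: a single fused kernel covers the QKV projection, attention,
                # the self-output and output projections, both residual LayerNorms, the
                # intermediate MLP and (when world_size > 1) the all_reduce itself.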
                output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
                    self_norm_weight=self.attention.output.LayerNorm.weight,
                    self_norm_bias=self.attention.output.LayerNorm.bias,
                    intermediate_weight=self.intermediate.dense.weight,
                    intermediate_bias=self.intermediate.dense.bias,
                    output_weight=self.output.dense.weight,
                    output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
                    output_norm_weight=self.output.LayerNorm.weight,
                    output_norm_bias=self.output.LayerNorm.bias,
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=world_size > 1,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
            else:
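                # Payload is >= 4 MB with TP: split the fused op into two stages and do
                # the all_reduce explicitly (vccl) in between. AttnOutStage covers the QKV
                # projection, attention and the self-output dense projection; the rest of
                # the layer is finished by fuse_mlp_vision or InterOutStage below.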
                attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=self.attention.self.qkv_proj.weight,
                    qkv_bias=self.attention.self.qkv_proj.bias,
                    self_weight=self.attention.output.dense.weight,
                    self_bias=self.attention.output.dense.bias if rank == 0 else torch.Tensor(),
                    self_norm_weight=torch.Tensor(),
                    self_norm_bias=torch.Tensor(),
                    intermediate_weight=torch.Tensor(),
                    intermediate_bias=torch.Tensor(),
                    output_weight=torch.Tensor(),
                    output_bias=torch.Tensor(),
                    output_norm_weight=torch.Tensor(),
                    output_norm_bias=torch.Tensor(),
                    dense_out=torch.Tensor(),
                    seqs=attn_metadata.seq_lens,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
                    sm_scale=self.attention.self.scaling,
                    num_q_heads=self.attention.self.num_heads * world_size,
                    num_kv_heads=self.attention.self.num_kv_heads * world_size,
                    flash_attention=False,
                    reduce_result=False,
                    world_size=world_size,
                    rank=rank,
                    group_id=tp_group.group_id,
                    dev_info=tp_group.rank_device_infos)
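                # Each rank holds a partial self-output (reduce_result=False); sum it
                # across the TP group before the residual add and LayerNorm.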
                if world_size > 1:
                    attn_out_stage_output = tensor_model_parallel_all_reduce(attn_out_stage_output)
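                # MLP half of the layer: either the fused vision MLP kernel (attention
                # LayerNorm applied in eager code first) or the InterOutStage of the fused
                # bert op, which takes the raw attention output through dense_out.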
                if USE_FUSED_MLP_VISION:
                    attn_output = self.attention.output.LayerNorm(attn_out_stage_output + hidden_states)
                    inter_out_stage_output = torch.vacc.fuse_mlp_vision(
                        src=attn_output,
                        weights_13=self.intermediate.dense.weight,
                        weights_2=self.output.dense.weight,
                        weights_13_bias=self.intermediate.dense.bias,
                        weights_2_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
                        act_type=0,  # gelu
                    )
                else:
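                    # InterOutStage: the fused op applies the attention residual + LayerNorm
                    # (hidden_states + dense_out) and then the intermediate and output
                    # projections; reduce_result=False, so the result is again a partial sum.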
                    inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                        hidden_states=hidden_states,
                        qkv_weight=torch.Tensor(),
                        qkv_bias=torch.Tensor(),
                        self_weight=torch.Tensor(),
                        self_bias=torch.Tensor(),
                        self_norm_weight=self.attention.output.LayerNorm.weight,
                        self_norm_bias=self.attention.output.LayerNorm.bias,
                        intermediate_weight=self.intermediate.dense.weight,
                        intermediate_bias=self.intermediate.dense.bias,
                        output_weight=self.output.dense.weight,
                        output_bias=self.output.dense.bias if rank == 0 else torch.Tensor(),
                        output_norm_weight=torch.Tensor(),
                        output_norm_bias=torch.Tensor(),
                        dense_out=attn_out_stage_output,
                        seqs=attn_metadata.seq_lens,
                        vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
                        sm_scale=self.attention.self.scaling,
                        num_q_heads=self.attention.self.num_heads * world_size,
                        num_kv_heads=self.attention.self.num_kv_heads * world_size,
                        flash_attention=False,
                        reduce_result=False,
                        world_size=world_size,
                        rank=rank,
                        group_id=tp_group.group_id,
                        dev_info=tp_group.rank_device_infos)
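                # Sum the partial MLP output across the TP group, then finish with the
                # final residual add and output LayerNorm.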
                if world_size > 1:
                    inter_out_stage_output = tensor_model_parallel_all_reduce(inter_out_stage_output)
                if USE_FUSED_MLP_VISION:
                    output = self.output.LayerNorm(inter_out_stage_output + attn_output)
                else:
                    output = self.output.LayerNorm(inter_out_stage_output + attn_out_stage_output)
        else:
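            # Fused kernels disabled: fall back to the standard eager BertLayer sub-modules.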
            attn_output = self.attention(hidden_states)
            intermediate_output = self.intermediate(attn_output)
            output = self.output(intermediate_output, attn_output)
        return output