# SPDX-License-Identifier: Apache-2.0
|
||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
||
import torch
|
||
from torch import nn
|
||
|
||
from vllm.distributed import (get_tp_group, tensor_model_parallel_all_reduce)
|
||
|
||
from vllm.forward_context import ForwardContext, get_forward_context
|
||
|
||
from .vars import *
|
||
|
||
|
||
class BertLayer(nn.Module):
    """BERT encoder layer that dispatches to vendor fused kernels when enabled.

    When ``USE_FUSED_BERT_ATTENTION`` is set, attention and the MLP are run
    through ``torch.vacc.fused_attn_bert_allreduce`` (optionally combined with
    ``torch.vacc.fuse_mlp_vision``); otherwise the layer falls back to the
    regular eager sub-modules (``self.attention``, ``self.intermediate``,
    ``self.output``), which are initialized elsewhere in the model.
    """

    # Threshold below which the fused kernel performs the tensor-parallel
    # all_reduce internally (dsp path). At or above 4 MiB a kernel limitation
    # forces the reduction to be done outside via the vccl all_reduce.
    _DSP_ALLREDUCE_MAX_BYTES = 4 * 1024 * 1024

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run one encoder layer over ``hidden_states`` and return the output.

        ``hidden_states`` is the (TP-sharded, if applicable) activation tensor;
        the returned tensor has the same shape as the eager path's output.
        """
        if not USE_FUSED_BERT_ATTENTION:
            # Eager fallback: plain attention -> intermediate -> output.
            attn_output = self.attention(hidden_states)
            intermediate_output = self.intermediate(attn_output)
            return self.output(intermediate_output, attn_output)

        tp_group = get_tp_group()
        world_size = tp_group.world_size
        rank = tp_group.rank_in_group
        # Size in bytes of the full (all ranks) tensor the all_reduce would
        # operate on; used to pick the internal vs. external reduce path.
        total_bytes = (hidden_states.numel() * hidden_states.element_size()
                       * world_size)

        forward_context: ForwardContext = get_forward_context()
        attn_metadata_all = forward_context.attn_metadata

        if isinstance(attn_metadata_all, dict):
            # Only seq_lens is consumed below; take the first entry's
            # metadata (presumably all entries share seq_lens — confirm).
            attn_metadata = next(iter(attn_metadata_all.values()))
        else:
            attn_metadata = attn_metadata_all

        # Keyword arguments common to every fused_attn_bert_allreduce call.
        fused_common = dict(
            seqs=attn_metadata.seq_lens,
            sm_scale=self.attention.self.scaling,
            num_q_heads=self.attention.self.num_heads * world_size,
            num_kv_heads=self.attention.self.num_kv_heads * world_size,
            flash_attention=False,
            world_size=world_size,
            rank=rank,
            group_id=tp_group.group_id,
            dev_info=tp_group.rank_device_infos,
        )

        # (matmul + bias_add) with TP + all_reduce: to avoid the bias being
        # added once per rank after the reduction, only rank 0 supplies the
        # bias for the bias add. Within the BERT layer, the BertSelfOutput
        # and BertOutput modules have this structure; the matching parameters
        # are self_bias and output_bias below. An empty torch.Tensor() acts
        # as the "not provided" placeholder expected by the fused kernel.
        if total_bytes < self._DSP_ALLREDUCE_MAX_BYTES or world_size == 1:
            # 1. With TP, an all_reduce input under 4 MiB is reduced inside
            #    the fused op via the dsp all_reduce; at or above 4 MiB the
            #    vccl all_reduce must be invoked externally (other branch).
            # 2. The non-TP case also takes this single fused call.
            output = torch.vacc.fused_attn_bert_allreduce(
                hidden_states=hidden_states,
                qkv_weight=self.attention.self.qkv_proj.weight,
                qkv_bias=self.attention.self.qkv_proj.bias,
                self_weight=self.attention.output.dense.weight,
                self_bias=(self.attention.output.dense.bias
                           if rank == 0 else torch.Tensor()),
                self_norm_weight=self.attention.output.LayerNorm.weight,
                self_norm_bias=self.attention.output.LayerNorm.bias,
                intermediate_weight=self.intermediate.dense.weight,
                intermediate_bias=self.intermediate.dense.bias,
                output_weight=self.output.dense.weight,
                output_bias=(self.output.dense.bias
                             if rank == 0 else torch.Tensor()),
                output_norm_weight=self.output.LayerNorm.weight,
                output_norm_bias=self.output.LayerNorm.bias,
                dense_out=torch.Tensor(),
                vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.FullStage,
                reduce_result=world_size > 1,
                **fused_common)
        else:
            # Stage 1: attention + self-output projection only; the
            # reduction happens outside via vccl all_reduce.
            attn_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                hidden_states=hidden_states,
                qkv_weight=self.attention.self.qkv_proj.weight,
                qkv_bias=self.attention.self.qkv_proj.bias,
                self_weight=self.attention.output.dense.weight,
                self_bias=(self.attention.output.dense.bias
                           if rank == 0 else torch.Tensor()),
                self_norm_weight=torch.Tensor(),
                self_norm_bias=torch.Tensor(),
                intermediate_weight=torch.Tensor(),
                intermediate_bias=torch.Tensor(),
                output_weight=torch.Tensor(),
                output_bias=torch.Tensor(),
                output_norm_weight=torch.Tensor(),
                output_norm_bias=torch.Tensor(),
                dense_out=torch.Tensor(),
                vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.AttnOutStage,
                reduce_result=False,
                **fused_common)
            if world_size > 1:
                attn_out_stage_output = tensor_model_parallel_all_reduce(
                    attn_out_stage_output)

            if USE_FUSED_MLP_VISION:
                # Residual + LayerNorm in eager, then the fused MLP kernel.
                attn_output = self.attention.output.LayerNorm(
                    attn_out_stage_output + hidden_states)
                inter_out_stage_output = torch.vacc.fuse_mlp_vision(
                    src=attn_output,
                    weights_13=self.intermediate.dense.weight,
                    weights_2=self.output.dense.weight,
                    weights_13_bias=self.intermediate.dense.bias,
                    weights_2_bias=(self.output.dense.bias
                                    if rank == 0 else torch.Tensor()),
                    act_type=0  # gelu
                )
            else:
                # Stage 2: residual + LayerNorm + intermediate/output matmuls
                # fused; the stage-1 result is fed in through dense_out.
                inter_out_stage_output = torch.vacc.fused_attn_bert_allreduce(
                    hidden_states=hidden_states,
                    qkv_weight=torch.Tensor(),
                    qkv_bias=torch.Tensor(),
                    self_weight=torch.Tensor(),
                    self_bias=torch.Tensor(),
                    self_norm_weight=self.attention.output.LayerNorm.weight,
                    self_norm_bias=self.attention.output.LayerNorm.bias,
                    intermediate_weight=self.intermediate.dense.weight,
                    intermediate_bias=self.intermediate.dense.bias,
                    output_weight=self.output.dense.weight,
                    output_bias=(self.output.dense.bias
                                 if rank == 0 else torch.Tensor()),
                    output_norm_weight=torch.Tensor(),
                    output_norm_bias=torch.Tensor(),
                    dense_out=attn_out_stage_output,
                    vnnlBertKind=torch.vacc.BERT_ATTN_STAGE.InterOutStage,
                    reduce_result=False,
                    **fused_common)
            if world_size > 1:
                inter_out_stage_output = tensor_model_parallel_all_reduce(
                    inter_out_stage_output)
            # Final residual + LayerNorm; the residual source depends on
            # which stage-2 path produced the intermediate output.
            if USE_FUSED_MLP_VISION:
                output = self.output.LayerNorm(
                    inter_out_stage_output + attn_output)
            else:
                output = self.output.LayerNorm(
                    inter_out_stage_output + attn_out_stage_output)
        return output