# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
from torch import nn

from vllm.model_executor.models.bert import _decode_token_type_ids


class RobertaEmbedding(nn.Module):
    """RoBERTa embedding forward pass backed by a fused VACC kernel.

    The unfused computation this stands in for is::

        LayerNorm(word_embeddings(input_ids)
                  + token_type_embeddings(token_type_ids)
                  + position_embeddings(position_ids))

    Here only the word-embedding lookup runs in Python. The position and
    token-type lookups, the sum and the LayerNorm are handed to
    ``torch.vacc.fuse_bge_embedding_stage1`` as raw weight tensors.

    This class defines no ``__init__``. ``forward`` reads
    ``word_embeddings``, ``position_embeddings``, ``token_type_embeddings``
    and ``LayerNorm`` from ``self``, so those submodules must be created
    elsewhere. NOTE(review): presumably this ``forward`` is patched onto
    vLLM's own RobertaEmbedding, which builds them — confirm.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Return the fused, normalized embeddings for ``input_ids``.

        Args:
            input_ids: Token ids. Token-type ids are decoded from this
                tensor by ``_decode_token_type_ids``.
            position_ids: Position index for each token. It is passed
                straight through to the fused kernel.

        Returns:
            Whatever ``torch.vacc.fuse_bge_embedding_stage1`` produces for
            these inputs. NOTE(review): assumed to match the unfused
            reference in the class docstring — verify against the kernel.
        """
        # Decode token types before the word-embedding lookup reads
        # input_ids. NOTE(review): the upstream helper may strip encoded
        # type bits from input_ids in place, so keep this ordering.
        type_ids = _decode_token_type_ids(input_ids)
        word_embeds = self.word_embeddings(input_ids)

        norm = self.LayerNorm
        # The kernel takes positional arguments only, in this exact order:
        # (embeds, pos ids, pos table, type ids, type table, ln w, ln b, eps).
        return torch.vacc.fuse_bge_embedding_stage1(
            word_embeds,
            position_ids,
            self.position_embeddings.weight,
            type_ids,
            self.token_type_embeddings.weight,
            norm.weight,
            norm.bias,
            norm.eps,
        )