27 lines
1.0 KiB
Python
27 lines
1.0 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
from vllm.model_executor.models.bert import _decode_token_type_ids
|
|
|
|
|
|
class RobertaEmbedding(nn.Module):
    """RoBERTa input embedding with a fused accelerator fast path.

    Only ``forward`` is defined here. The submodules it reads
    (``word_embeddings``, ``position_embeddings``, ``token_type_embeddings``
    and ``LayerNorm``) are not created in this class; presumably they are
    provided by the upstream vLLM ``RobertaEmbedding`` that this definition
    overrides -- TODO confirm how this class is wired in.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and return the layer-normalised embedding sum.

        Args:
            input_ids: Token ids. Token-type ids are derived from this tensor
                by ``_decode_token_type_ids`` (NOTE(review): that helper may
                also rewrite ``input_ids`` in place -- confirm against
                ``vllm.model_executor.models.bert``).
            position_ids: Indices looked up in ``self.position_embeddings``.

        Returns:
            ``LayerNorm(word + token_type + position embeddings)``. When
            ``torch.vacc.fuse_bge_embedding_stage1`` is available it is used
            instead; the fused kernel is assumed to compute the same result
            -- TODO confirm against the vacc kernel documentation.
        """
        # Decode first: the token-type ids are extracted from input_ids, and
        # the word-embedding lookup below consumes input_ids afterwards.
        token_type_ids = _decode_token_type_ids(input_ids)

        inputs_embeds = self.word_embeddings(input_ids)

        # Resolve the fused kernel defensively so the module still works on
        # torch builds that have no ``vacc`` backend (or lack this op).
        vacc_backend = getattr(torch, "vacc", None)
        fused_embedding = getattr(vacc_backend, "fuse_bge_embedding_stage1", None)
        if fused_embedding is not None:
            return fused_embedding(
                inputs_embeds,
                position_ids,
                self.position_embeddings.weight,
                token_type_ids,
                self.token_type_embeddings.weight,
                self.LayerNorm.weight,
                self.LayerNorm.bias,
                self.LayerNorm.eps,
            )

        # Unfused reference path: the computation the fused kernel replaces.
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = (
            inputs_embeds + token_type_embeddings + position_embeddings
        )
        return self.LayerNorm(embeddings)