init
This commit is contained in:
27
vllm_vacc/vllm/model_executor/models/roberta.py
Normal file
27
vllm_vacc/vllm/model_executor/models/roberta.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.model_executor.models.bert import _decode_token_type_ids
|
||||
|
||||
|
||||
class RobertaEmbedding(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
token_type_ids = _decode_token_type_ids(input_ids)
|
||||
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
# position_embeddings = self.position_embeddings(position_ids)
|
||||
|
||||
# token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||
# embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
||||
# embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = torch.vacc.fuse_bge_embedding_stage1(inputs_embeds, position_ids, self.position_embeddings.weight, token_type_ids, self.token_type_embeddings.weight, self.LayerNorm.weight, self.LayerNorm.bias, self.LayerNorm.eps)
|
||||
return embeddings
|
||||
Reference in New Issue
Block a user