Files
enginex-vastai-va16-vllm/vllm_vacc/vllm/model_executor/models/roberta.py
2026-04-02 04:55:00 +00:00

27 lines
1.0 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional

import torch
from torch import nn
from vllm.model_executor.models.bert import _decode_token_type_ids
class RobertaEmbedding(nn.Module):
    """VACC-accelerated ``forward`` for vLLM's RoBERTa embedding layer.

    This class defines no ``__init__``. ``forward`` reads
    ``self.word_embeddings``, ``self.position_embeddings``,
    ``self.token_type_embeddings`` and ``self.LayerNorm``, which must be
    created elsewhere. Presumably this ``forward`` is patched onto the
    upstream vLLM ``RobertaEmbedding``; TODO confirm against the patching
    code.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute RoBERTa input embeddings with one fused VACC kernel.

        Args:
            input_ids: Token ids. They are passed to
                ``_decode_token_type_ids`` before the word-embedding lookup.
                NOTE(review): upstream vLLM appears to pack token-type ids
                into ``input_ids`` and have that helper strip them in place.
                If so, this call order is required. Confirm for the pinned
                vLLM version.
            position_ids: Positions handed to the fused kernel together
                with ``self.position_embeddings.weight``.
            inputs_embeds: Optional precomputed word embeddings. When given,
                the ``self.word_embeddings`` lookup is skipped. Defaults to
                ``None``, so existing two-argument callers are unaffected.

        Returns:
            The tensor returned by ``torch.vacc.fuse_bge_embedding_stage1``.
        """
        # Decode token types first (see the NOTE on ``input_ids`` above).
        token_type_ids = _decode_token_type_ids(input_ids)
        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        # Fused stand-in for the unfused reference computation:
        #   LayerNorm(inputs_embeds
        #             + token_type_embeddings(token_type_ids)
        #             + position_embeddings(position_ids))
        # The kernel takes the raw embedding tables and LayerNorm
        # parameters rather than the nn.Module objects. It is presumably
        # numerically equivalent up to kernel precision; verify if exact
        # parity matters.
        return torch.vacc.fuse_bge_embedding_stage1(
            inputs_embeds,
            position_ids,
            self.position_embeddings.weight,
            token_type_ids,
            self.token_type_embeddings.weight,
            self.LayerNorm.weight,
            self.LayerNorm.bias,
            self.LayerNorm.eps,
        )