118 lines
4.9 KiB
Python
118 lines
4.9 KiB
Python
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import RobertaConfig
|
|
|
|
from vllm.attention import AttentionMetadata
|
|
from vllm.config import VllmConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
VocabParallelEmbedding)
|
|
from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
|
|
class RobertaEmbedding(nn.Module):
|
|
|
|
def __init__(self, config: RobertaConfig):
|
|
super().__init__()
|
|
self.size = config.hidden_size
|
|
self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
|
|
config.hidden_size)
|
|
self.padding_idx = config.pad_token_id
|
|
self.position_embeddings = nn.Embedding(config.max_position_embeddings,
|
|
config.hidden_size,
|
|
padding_idx=self.padding_idx)
|
|
|
|
self.token_type_embeddings = nn.Embedding(config.type_vocab_size,
|
|
config.hidden_size)
|
|
self.LayerNorm = nn.LayerNorm(config.hidden_size,
|
|
eps=config.layer_norm_eps)
|
|
self.position_ids = nn.Parameter(
|
|
torch.empty((1, config.max_position_embeddings)), )
|
|
|
|
self.position_embedding_type = config.position_embedding_type
|
|
if self.position_embedding_type != "absolute":
|
|
raise ValueError("Only 'absolute' position_embedding_type" +
|
|
" is supported")
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
position_ids: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
input_shape = input_ids.size()
|
|
|
|
# Input embeddings.
|
|
inputs_embeds = self.word_embeddings(input_ids)
|
|
|
|
# TODO: figure out if there is a better way
|
|
# to make to make position ids start at padding_idx + 1
|
|
# References:
|
|
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133
|
|
# - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669
|
|
position_ids += self.padding_idx + 1
|
|
|
|
# Position embeddings.
|
|
position_embeddings = self.position_embeddings(position_ids)
|
|
|
|
# Token type embeddings. (TODO: move off hotpath?)
|
|
token_type_embeddings = self.token_type_embeddings(
|
|
torch.zeros(input_shape,
|
|
dtype=torch.long,
|
|
device=inputs_embeds.device))
|
|
|
|
embeddings = inputs_embeds + token_type_embeddings + position_embeddings
|
|
embeddings = self.LayerNorm(embeddings)
|
|
return embeddings
|
|
|
|
|
|
class RobertaEmbeddingModel(BertEmbeddingModel):
|
|
"""A model that uses Roberta to provide embedding functionalities.
|
|
|
|
This class encapsulates the BertModel and provides an interface for
|
|
embedding operations and customized pooling functions.
|
|
|
|
Attributes:
|
|
model: An instance of BertModel used for forward operations.
|
|
_pooler: An instance of Pooler used for pooling operations.
|
|
"""
|
|
|
|
def _build_model(self,
|
|
vllm_config: VllmConfig,
|
|
prefix: str = "") -> BertModel:
|
|
return BertModel(vllm_config=vllm_config,
|
|
prefix=prefix,
|
|
embedding_class=RobertaEmbedding)
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: Optional[torch.Tensor],
|
|
positions: torch.Tensor,
|
|
kv_caches: List[torch.Tensor],
|
|
attn_metadata: AttentionMetadata,
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
) -> torch.Tensor:
|
|
|
|
# Verify assumption that position are always a sequence from
|
|
# 0 to N. (Actually here we just check 0 and N to simplify).
|
|
# This is important to fix the position which are assumed to
|
|
# start from padding_idx + 1 instead of 0 in the Roberta models.
|
|
assert hasattr(attn_metadata, "seq_lens_tensor")
|
|
cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0)
|
|
start_pos = torch.cat(
|
|
(torch.tensor([0], device=attn_metadata.seq_lens_tensor.device),
|
|
cumulative[:-1]))
|
|
assert len(torch.nonzero(positions[start_pos])) == 0
|
|
end_pos = cumulative - 1
|
|
last_tokens = attn_metadata.seq_lens_tensor - 1
|
|
assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0
|
|
|
|
return super().forward(input_ids=input_ids,
|
|
positions=positions,
|
|
kv_caches=kv_caches,
|
|
attn_metadata=attn_metadata,
|
|
intermediate_tensors=intermediate_tensors,
|
|
inputs_embeds=inputs_embeds)
|