333 lines
11 KiB
Python
333 lines
11 KiB
Python
|
|
import torch
|
||
|
|
from torch import nn
|
||
|
|
import torch.nn.functional as F
|
||
|
|
from abc import abstractmethod
|
||
|
|
from torch import tensor as Tensor
|
||
|
|
from typing import List, Any
|
||
|
|
|
||
|
|
|
||
|
|
class BaseVAE(nn.Module):
    """Shared interface for the autoencoders defined in this module.

    Subclasses are expected to implement ``forward`` and ``loss_function``;
    every other hook raises ``NotImplementedError`` unless it is overridden.

    Note: ``@abstractmethod`` is only enforced by ``abc.ABCMeta``. ``nn.Module``
    does not use that metaclass, so the markers below document intent but do
    not prevent instantiating this class directly.
    """

    def __init__(self) -> None:
        super().__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Subclass hook: encode ``input``. Not implemented in the base class."""
        raise NotImplementedError()

    def decode(self, input: Tensor) -> Any:
        """Subclass hook: decode ``input``. Not implemented in the base class."""
        raise NotImplementedError()

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        """Subclass hook: draw ``batch_size`` samples. Not implemented here."""
        raise NotImplementedError()

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Subclass hook: produce an output for ``x``. Not implemented here."""
        raise NotImplementedError()

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Model forward pass; must be provided by subclasses (returns None here)."""

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        """Training objective; must be provided by subclasses (returns None here)."""
|
||
|
|
|
||
|
|
|
||
|
|
class MLPAE(BaseVAE):
    """Autoencoder with separate semantic / truthful encoders fused by cross-attention.

    The semantic encoder maps an input row to a ``semantic_latent_dim`` vector.
    The truthful encoder maps the same row to an L2-normalised
    ``truthful_latent_dim`` vector. The decoder input is
    ``semantic + CrossAttention(query=semantic, key=value=truthful)``. It passes
    through the optional hidden decoder layers and a final linear layer back
    to ``in_channels`` features.

    Args:
        in_channels: Width of the input (and reconstructed output) features.
        semantic_latent_dim: Width of the semantic latent; also the attention
            embedding size.
        truthful_latent_dim: Width of the truthful latent. When it differs from
            ``semantic_latent_dim`` a bias-free linear ``proj`` maps it to the
            attention width.
        semantic_hidden_dims: Hidden widths of the semantic encoder
            (``None`` -> no hidden layers).
        truthful_hidden_dims: Hidden widths of the truthful encoder
            (``None`` -> no hidden layers).
        decoder_hidden_dims: Hidden widths of the decoder (``None`` -> no
            hidden layers; the final linear layer reads the latent directly).
        **kwargs: Accepted for interface compatibility and ignored.
    """

    def __init__(
        self,
        in_channels: int,
        semantic_latent_dim: int,
        truthful_latent_dim: int,
        semantic_hidden_dims: List = None,
        truthful_hidden_dims: List = None,
        decoder_hidden_dims: List = None,
        **kwargs
    ) -> None:
        super(MLPAE, self).__init__()

        self.semantic_latent_dim = semantic_latent_dim

        # A missing width list means "no hidden layers" for every stack.
        # (decoder_hidden_dims=None, the declared default, used to crash in len(None).)
        semantic_hidden_dims = list(semantic_hidden_dims or [])
        truthful_hidden_dims = list(truthful_hidden_dims or [])
        decoder_hidden_dims = list(decoder_hidden_dims or [])

        # NOTE: module construction order and nesting are kept identical to the
        # original layout so state_dict keys (and seeded init) stay compatible.

        # Semantic encoder: hidden layers, then the same block into the latent space.
        semantic_layers, _ = self._mlp_stack(
            in_channels, semantic_hidden_dims + [semantic_latent_dim]
        )
        self.semantic_encoder = nn.Sequential(*semantic_layers)

        # Truthful encoder: like the semantic one, but a layer whose input and
        # output widths match skips its Linear (nn.Identity instead).
        truthful_layers, _ = self._mlp_stack(
            in_channels,
            truthful_hidden_dims + [truthful_latent_dim],
            identity_if_same=True,
        )
        self.truthful_encoder = nn.Sequential(*truthful_layers)

        # Cross-attention: semantic latent queries the truthful latent.
        self.num_heads = 1
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=semantic_latent_dim, num_heads=self.num_heads
        )

        self.proj = None
        if semantic_latent_dim != truthful_latent_dim:
            self.proj = nn.Linear(truthful_latent_dim, semantic_latent_dim, bias=False)

        # Decoder: optional hidden stack, then a plain linear back to the input width.
        decoder_layers, width = self._mlp_stack(semantic_latent_dim, decoder_hidden_dims)
        self.decoder_input = None  # unused; kept so the attribute always exists
        self.decoder = nn.Sequential(*decoder_layers) if decoder_layers else None
        self.final_layer = nn.Sequential(nn.Linear(width, in_channels))

    @staticmethod
    def _mlp_stack(in_dim: int, widths: List, identity_if_same: bool = False):
        """Build one ``Sequential(Linear|Identity, LayerNorm, LeakyReLU)`` per width.

        Returns the list of layers and the output width of the last layer
        (``in_dim`` when ``widths`` is empty).
        """
        layers = []
        for out_dim in widths:
            if identity_if_same and in_dim == out_dim:
                mix = nn.Identity()
            else:
                mix = nn.Linear(in_dim, out_dim)
            layers.append(nn.Sequential(mix, nn.LayerNorm(out_dim), nn.LeakyReLU()))
            in_dim = out_dim
        return layers, in_dim

    def encode_semantic(self, input: Tensor) -> Tensor:
        """Return the semantic latent representation of ``input``."""
        semantic_latent_rep = self.semantic_encoder(input)
        return semantic_latent_rep

    def encode_truthful(self, input: Tensor) -> Tensor:
        """Return the truthful latent of ``input``, L2-normalised along the last dim."""
        truthful_latent_rep = self.truthful_encoder(input)
        truthful_latent_rep = F.normalize(truthful_latent_rep, p=2, dim=-1)

        return truthful_latent_rep

    def attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        """Cross-attend ``query`` rows to ``key``/``value`` rows.

        ``unsqueeze(0)`` adds a length-1 sequence axis in front, the layout
        ``nn.MultiheadAttention`` uses by default (``batch_first=False``). For
        2-D ``(rows, features)`` inputs, each row is therefore its own batch
        element and attends only to the matching key/value row.
        """
        if self.proj is not None and query.size(-1) != key.size(-1):
            # Map truthful features to the attention (semantic) width.
            key = self.proj(key)
            value = self.proj(value)
        query = query.unsqueeze(0)
        key = key.unsqueeze(0)
        value = value.unsqueeze(0)

        output, attention_weights = self.cross_attention(query, key, value)

        return output[0]

    def decode(self, z: Tensor) -> Tensor:
        """Map a fused latent ``z`` back to ``in_channels`` features."""
        result = z
        if self.decoder is not None:
            result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(
        self, input: Tensor, truthful_latent_rep=None, **kwargs
    ) -> List[Tensor]:
        """Reconstruct ``input``, optionally with a substituted truthful latent.

        Args:
            input: Feature rows, presumably ``(N, in_channels)``. Attention
                requires the semantic latent to be 2-D here.
            truthful_latent_rep: Optional replacement truthful latent. It is
                flattened to ``(-1, last_dim)``. When ``None`` it is computed
                from ``input``.

        Returns:
            ``[reconstruction, input, semantic_latent_rep, truthful_latent_rep]``,
            where the last item is the flattened 2-D truthful latent.
        """
        semantic_latent_rep = self.encode_semantic(input)
        if truthful_latent_rep is None:
            truthful_latent_rep = self.encode_truthful(input)
        truthful_latent_rep = truthful_latent_rep.reshape(
            -1, truthful_latent_rep.size(-1)
        )
        # Residual fusion: semantic latent plus what it retrieves from the truthful one.
        z = semantic_latent_rep + self.attention(
            semantic_latent_rep,
            truthful_latent_rep.contiguous(),
            truthful_latent_rep.contiguous(),
        )
        output = self.decode(z)

        return [output, input, semantic_latent_rep, truthful_latent_rep]

    def forward_decoder(self, input, semantic_latent_rep, truthful_latent_rep):
        """Decode from precomputed latents; same return layout as ``forward``."""
        z = semantic_latent_rep + self.attention(
            semantic_latent_rep, truthful_latent_rep, truthful_latent_rep
        )
        output = self.decode(z)
        return [output, input, semantic_latent_rep, truthful_latent_rep]

    def get_semantic_latent_rep(self, input: Tensor, **kwargs) -> Tensor:
        """Public accessor for ``encode_semantic``."""
        semantic_latent_rep = self.encode_semantic(input)
        return semantic_latent_rep

    def get_truthful_latent_rep(self, input: Tensor, **kwargs) -> Tensor:
        """Public accessor for ``encode_truthful`` (output is L2-normalised)."""
        truthful_latent_rep = self.encode_truthful(input)
        return truthful_latent_rep

    def loss_function(self, *args, **kwargs) -> dict:
        """Mean-squared reconstruction loss.

        Args:
            *args: ``args[0]`` is the reconstruction and ``args[1]`` the target
                (the first two items of ``forward``'s result).

        Returns:
            ``{"loss": mse, "Reconstruction_Loss": mse.detach()}``.
        """
        recons = args[0]
        input = args[1]
        recons_loss = F.mse_loss(recons, input)

        loss = recons_loss
        return {"loss": loss, "Reconstruction_Loss": recons_loss.detach()}
|
||
|
|
|
||
|
|
|
||
|
|
class TruthX:
    """Shift language-model hidden states toward a "truthful" direction via an MLPAE.

    Args:
        model_path: Path to a checkpoint dict with keys ``args``,
            ``state_dict``, ``pos_center``, ``neg_center`` and ``rank``.
        hidden_size: Width of the hidden states to edit (the MLPAE's
            ``in_channels``).
        edit_strength: Scale applied to the edit direction.
        top_layers: Only layers whose ``rank`` entry is ``<= top_layers`` are
            edited.

    Runtime state set by the caller before ``edit``:
        cur_layer_id: String whose text before the first "." is an integer
            layer index; a suffix ending in "attn" selects the attention slot.
        mc: ``True`` for multiple-choice scoring, ``False`` for open-ended
            generation.
        prompt_length: Number of prompt positions to leave unedited when
            ``mc`` is ``True``.
    """

    def __init__(self, model_path, hidden_size, edit_strength=1.0, top_layers=10):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # map_location: without it a checkpoint saved from GPU fails to load on a
        # CPU-only host even though ``device`` has already fallen back to CPU.
        # NOTE(review): torch.load unpickles arbitrary objects (``args`` is a
        # pickled namespace) - only load trusted checkpoint files.
        checkpoint = torch.load(model_path, map_location=device)
        args = checkpoint["args"]

        ae_model = MLPAE(
            in_channels=hidden_size,
            semantic_latent_dim=args.semantic_latent_dim,
            truthful_latent_dim=args.truthful_latent_dim,
            semantic_hidden_dims=self._parse_dims(args.semantic_hidden_dims),
            truthful_hidden_dims=self._parse_dims(args.truthful_hidden_dims),
            decoder_hidden_dims=self._parse_dims(args.decoder_hidden_dims),
        ).to(device)
        ae_model.load_state_dict(checkpoint["state_dict"])

        # Per-layer centres in the truthful latent space, indexed by layer_id in edit().
        ae_model.pos_center = checkpoint["pos_center"].to(device)
        ae_model.neg_center = checkpoint["neg_center"].to(device)
        ae_model.eval()
        self.ae_model = ae_model

        self.rank = checkpoint["rank"]

        self.top_layers = top_layers
        self.edit_strength = edit_strength
        # Placeholder only: edit() needs a string such as "<layer>.<suffix>".
        self.cur_layer_id = 0
        self.prompt_length = None
        self.mc = False

    @staticmethod
    def _parse_dims(spec):
        """Parse a comma-separated width list such as ``"1024,512"``; empty/None -> ``[]``."""
        if not spec:
            return []
        if isinstance(spec, (list, tuple)):
            return [int(dim) for dim in spec]
        return [int(tok) for tok in spec.split(",") if tok.strip()]

    def _reconstruct(self, x, truthful, shape):
        """Decode ``x`` with ``truthful`` as its truthful latent and reshape to ``shape``."""
        out = self.ae_model(x, truthful_latent_rep=truthful.type_as(x))[0]
        return out.contiguous().view(*shape)

    @torch.inference_mode()
    def edit(self, X):
        """Return ``X`` plus a truthfulness edit, or ``X`` unchanged for low-ranked layers.

        Args:
            X: Hidden states of shape ``(bsz, s_len, d)``, where ``d`` must equal
                the ``hidden_size`` given at construction.

        Returns:
            A tensor shaped like ``X``. Each position moves along the normalised
            difference between the reconstructions with the truthful latent
            pushed toward ``pos_center`` and toward ``neg_center``. The step is
            scaled by the position's own L2 norm, ``edit_strength``, and a
            per-position mask.
        """
        # Two slots per transformer layer: attention -> 2*i, otherwise -> 2*i + 1.
        layer_id = int(self.cur_layer_id.split(".")[0])
        if self.cur_layer_id.endswith("attn"):
            layer_id = 2 * layer_id
        else:
            layer_id = 2 * layer_id + 1

        if self.rank[layer_id] > self.top_layers:
            return X

        bsz, s_len, d = X.size()
        # Run the autoencoder in its own parameter dtype.
        ref_weight = self.ae_model.semantic_encoder[0][0].weight
        x = X.contiguous().view(-1, d).type_as(ref_weight)
        x_truthful = self.ae_model.get_truthful_latent_rep(X.type_as(ref_weight))

        pos_center = self.ae_model.pos_center[layer_id].unsqueeze(0)
        neg_center = self.ae_model.neg_center[layer_id].unsqueeze(0)

        delta = (pos_center - neg_center).unsqueeze(0)
        shape = (bsz, s_len, d)
        recon_x_pos = self._reconstruct(
            x, F.normalize(x_truthful + delta, p=2, dim=-1), shape
        )
        recon_x_neg = self._reconstruct(
            x, F.normalize(x_truthful - delta, p=2, dim=-1), shape
        )

        Delta = (recon_x_pos - recon_x_neg).contiguous().to(X.dtype)
        # Unit direction per position, rescaled to that position's own norm.
        Delta = F.normalize(Delta, p=2, dim=-1).type_as(X) * torch.norm(
            X, p=2, dim=-1
        ).unsqueeze(2)

        mask = torch.ones((bsz, s_len), device=Delta.device)

        if self.mc:
            # Multiple-choice: leave the first prompt_length + 1 positions untouched
            # (requires prompt_length to have been set by the caller).
            mask[:, : self.prompt_length + 1] = 0
            # Weight positions by how much closer they are to the negative centre
            # than to the positive one; positions closer to the positive centre get 0.
            probing = (
                F.cosine_similarity(x_truthful, neg_center.unsqueeze(1), dim=-1)
                - F.cosine_similarity(x_truthful, pos_center.unsqueeze(1), dim=-1)
            ).clamp(0, 999)
            mask = mask * probing
        else:
            # Open-ended generation: only edit the last position (the new token).
            mask[:, :-1] = 0
            mask[:, -1:] = 1

        new_X = X + (Delta.type_as(X)) * self.edit_strength * mask.unsqueeze(2).type_as(X)
        return new_X
|