初始化项目,由ModelHub XC社区提供模型
Model: ICTNLP/Llama-2-7b-chat-TruthX Source: Original Platform
This commit is contained in:
332
truthx.py
Normal file
332
truthx.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from abc import abstractmethod
|
||||
from torch import tensor as Tensor
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class BaseVAE(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super(BaseVAE, self).__init__()
|
||||
|
||||
def encode(self, input: Tensor) -> List[Tensor]:
|
||||
raise NotImplementedError
|
||||
|
||||
def decode(self, input: Tensor) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def generate(self, x: Tensor, **kwargs) -> Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *inputs: Tensor) -> Tensor:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
|
||||
pass
|
||||
|
||||
|
||||
class MLPAE(BaseVAE):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
semantic_latent_dim: int,
|
||||
truthful_latent_dim: int,
|
||||
semantic_hidden_dims: List = None,
|
||||
truthful_hidden_dims: List = None,
|
||||
decoder_hidden_dims: List = None,
|
||||
**kwargs
|
||||
) -> None:
|
||||
super(MLPAE, self).__init__()
|
||||
|
||||
self.semantic_latent_dim = semantic_latent_dim
|
||||
|
||||
if semantic_hidden_dims is None:
|
||||
semantic_hidden_dims = []
|
||||
|
||||
# Build Semantic Encoder
|
||||
semantic_encoder_modules = []
|
||||
flat_size = in_channels
|
||||
for h_dim in semantic_hidden_dims:
|
||||
semantic_encoder_modules.append(
|
||||
nn.Sequential(
|
||||
nn.Linear(flat_size, h_dim), nn.LayerNorm(h_dim), nn.LeakyReLU()
|
||||
)
|
||||
)
|
||||
flat_size = h_dim
|
||||
semantic_encoder_modules.append(
|
||||
nn.Sequential(
|
||||
nn.Linear(flat_size, semantic_latent_dim),
|
||||
nn.LayerNorm(semantic_latent_dim),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
)
|
||||
|
||||
self.semantic_encoder = nn.Sequential(*semantic_encoder_modules)
|
||||
|
||||
if truthful_hidden_dims is None:
|
||||
truthful_hidden_dims = []
|
||||
|
||||
# Build Truthful Encoder
|
||||
truthful_encoder_modules = []
|
||||
flat_size = in_channels
|
||||
for h_dim in truthful_hidden_dims:
|
||||
truthful_encoder_modules.append(
|
||||
nn.Sequential(
|
||||
(
|
||||
nn.Linear(flat_size, h_dim)
|
||||
if flat_size != h_dim
|
||||
else nn.Identity()
|
||||
),
|
||||
nn.LayerNorm(h_dim),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
)
|
||||
flat_size = h_dim
|
||||
truthful_encoder_modules.append(
|
||||
nn.Sequential(
|
||||
(
|
||||
nn.Linear(flat_size, truthful_latent_dim)
|
||||
if flat_size != truthful_latent_dim
|
||||
else nn.Identity()
|
||||
),
|
||||
nn.LayerNorm(truthful_latent_dim),
|
||||
nn.LeakyReLU(),
|
||||
)
|
||||
)
|
||||
|
||||
self.truthful_encoder = nn.Sequential(*truthful_encoder_modules)
|
||||
|
||||
# Cross-Attention Module
|
||||
self.num_heads = 1
|
||||
self.cross_attention = nn.MultiheadAttention(
|
||||
embed_dim=semantic_latent_dim, num_heads=self.num_heads
|
||||
)
|
||||
|
||||
self.proj = None
|
||||
if semantic_latent_dim != truthful_latent_dim:
|
||||
self.proj = nn.Linear(truthful_latent_dim, semantic_latent_dim, bias=False)
|
||||
|
||||
# Build Decoder
|
||||
decoder_modules = []
|
||||
if len(decoder_hidden_dims) > 0:
|
||||
flat_size = semantic_latent_dim
|
||||
for h_dim in decoder_hidden_dims:
|
||||
decoder_modules.append(
|
||||
nn.Sequential(
|
||||
nn.Linear(flat_size, h_dim), nn.LayerNorm(h_dim), nn.LeakyReLU()
|
||||
)
|
||||
)
|
||||
flat_size = h_dim
|
||||
|
||||
flat_size = decoder_hidden_dims[-1]
|
||||
self.decoder = nn.Sequential(*decoder_modules)
|
||||
else:
|
||||
self.decoder_input = None
|
||||
|
||||
self.decoder = None
|
||||
flat_size = semantic_latent_dim
|
||||
self.final_layer = nn.Sequential(nn.Linear(flat_size, in_channels))
|
||||
|
||||
def encode_semantic(self, input: Tensor) -> List[Tensor]:
|
||||
semantic_latent_rep = self.semantic_encoder(input)
|
||||
return semantic_latent_rep
|
||||
|
||||
def encode_truthful(self, input: Tensor) -> List[Tensor]:
|
||||
truthful_latent_rep = self.truthful_encoder(input)
|
||||
truthful_latent_rep = F.normalize(truthful_latent_rep, p=2, dim=-1)
|
||||
|
||||
return truthful_latent_rep
|
||||
|
||||
def attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
|
||||
if self.proj is not None and query.size(-1) != key.size(-1):
|
||||
key = self.proj(key)
|
||||
value = self.proj(value)
|
||||
query = query.unsqueeze(0)
|
||||
key = key.unsqueeze(0)
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
output, attention_weights = self.cross_attention(query, key, value)
|
||||
|
||||
return output[0]
|
||||
|
||||
def decode(self, z: Tensor) -> Tensor:
|
||||
result = z
|
||||
if self.decoder is not None:
|
||||
result = self.decoder(result)
|
||||
result = self.final_layer(result)
|
||||
return result
|
||||
|
||||
def forward(
|
||||
self, input: Tensor, truthful_latent_rep=None, **kwargs
|
||||
) -> List[Tensor]:
|
||||
semantic_latent_rep = self.encode_semantic(input)
|
||||
if truthful_latent_rep is None:
|
||||
truthful_latent_rep = self.encode_truthful(input)
|
||||
truthful_latent_rep = truthful_latent_rep.reshape(
|
||||
-1, truthful_latent_rep.size(-1)
|
||||
)
|
||||
z = semantic_latent_rep + self.attention(
|
||||
semantic_latent_rep,
|
||||
truthful_latent_rep.contiguous(),
|
||||
truthful_latent_rep.contiguous(),
|
||||
)
|
||||
output = self.decode(z)
|
||||
|
||||
return [output, input, semantic_latent_rep, truthful_latent_rep]
|
||||
|
||||
def forward_decoder(self, input, semantic_latent_rep, truthful_latent_rep):
|
||||
z = semantic_latent_rep + self.attention(
|
||||
semantic_latent_rep, truthful_latent_rep, truthful_latent_rep
|
||||
)
|
||||
output = self.decode(z)
|
||||
return [output, input, semantic_latent_rep, truthful_latent_rep]
|
||||
|
||||
def get_semantic_latent_rep(self, input: Tensor, **kwargs) -> List[Tensor]:
|
||||
semantic_latent_rep = self.encode_semantic(input)
|
||||
return semantic_latent_rep
|
||||
|
||||
def get_truthful_latent_rep(self, input: Tensor, **kwargs) -> List[Tensor]:
|
||||
truthful_latent_rep = self.encode_truthful(input)
|
||||
return truthful_latent_rep
|
||||
|
||||
def loss_function(self, *args, **kwargs) -> dict:
|
||||
recons = args[0]
|
||||
input = args[1]
|
||||
recons_loss = F.mse_loss(recons, input)
|
||||
|
||||
loss = recons_loss
|
||||
return {"loss": loss, "Reconstruction_Loss": recons_loss.detach()}
|
||||
|
||||
|
||||
class TruthX:
|
||||
def __init__(self, model_path, hidden_size, edit_strength=1.0, top_layers=10):
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
checkpoint = torch.load(model_path)
|
||||
args = checkpoint["args"]
|
||||
|
||||
semantic_latent_dim = args.semantic_latent_dim # Adjust as needed
|
||||
truthful_latent_dim = args.truthful_latent_dim
|
||||
semantic_hidden_dims = (
|
||||
[int(_) for _ in args.semantic_hidden_dims.split(",")]
|
||||
if args.semantic_hidden_dims != ""
|
||||
else []
|
||||
)
|
||||
truthful_hidden_dims = (
|
||||
[int(_) for _ in args.truthful_hidden_dims.split(",")]
|
||||
if args.truthful_hidden_dims != ""
|
||||
else []
|
||||
)
|
||||
decoder_hidden_dims = (
|
||||
[int(_) for _ in args.decoder_hidden_dims.split(",")]
|
||||
if args.decoder_hidden_dims != ""
|
||||
else []
|
||||
)
|
||||
|
||||
ae_model = MLPAE(
|
||||
in_channels=hidden_size,
|
||||
semantic_latent_dim=semantic_latent_dim,
|
||||
truthful_latent_dim=truthful_latent_dim,
|
||||
semantic_hidden_dims=semantic_hidden_dims,
|
||||
truthful_hidden_dims=truthful_hidden_dims,
|
||||
decoder_hidden_dims=decoder_hidden_dims,
|
||||
).to(device)
|
||||
|
||||
ae_model.load_state_dict(checkpoint["state_dict"])
|
||||
|
||||
ae_model.pos_center = ((checkpoint["pos_center"])).to(device)
|
||||
ae_model.neg_center = ((checkpoint["neg_center"])).to(device)
|
||||
ae_model.eval()
|
||||
ae_model.to(device)
|
||||
self.ae_model = ae_model
|
||||
|
||||
self.rank = checkpoint["rank"]
|
||||
|
||||
self.top_layers = top_layers
|
||||
self.edit_strength = edit_strength
|
||||
self.cur_layer_id = 0
|
||||
self.prompt_length = None
|
||||
self.mc = False
|
||||
|
||||
@torch.inference_mode()
|
||||
def edit(self, X):
|
||||
layer_id = int(self.cur_layer_id.split(".")[0])
|
||||
if self.cur_layer_id.endswith("attn"):
|
||||
layer_id = 2 * layer_id
|
||||
else:
|
||||
layer_id = 2 * layer_id + 1
|
||||
|
||||
if self.rank[layer_id] > self.top_layers:
|
||||
return X
|
||||
|
||||
bsz, s_len, d = X.size()
|
||||
x = (
|
||||
X.contiguous()
|
||||
.view(-1, d)
|
||||
.type_as(self.ae_model.semantic_encoder[0][0].weight)
|
||||
)
|
||||
x_truthful = self.ae_model.get_truthful_latent_rep(
|
||||
X.type_as(self.ae_model.semantic_encoder[0][0].weight)
|
||||
)
|
||||
|
||||
pos_center = self.ae_model.pos_center[layer_id].unsqueeze(0)
|
||||
neg_center = self.ae_model.neg_center[layer_id].unsqueeze(0)
|
||||
|
||||
delta = (pos_center - neg_center).unsqueeze(0)
|
||||
recon_x_pos = (
|
||||
self.ae_model(
|
||||
x,
|
||||
truthful_latent_rep=F.normalize(
|
||||
x_truthful + delta, p=2, dim=-1
|
||||
).type_as(x),
|
||||
)[0]
|
||||
.contiguous()
|
||||
.view(bsz, s_len, d)
|
||||
)
|
||||
recon_x_neg = (
|
||||
self.ae_model(
|
||||
x,
|
||||
truthful_latent_rep=F.normalize(
|
||||
x_truthful - delta, p=2, dim=-1
|
||||
).type_as(x),
|
||||
)[0]
|
||||
.contiguous()
|
||||
.view(bsz, s_len, d)
|
||||
)
|
||||
Delta = recon_x_pos - recon_x_neg
|
||||
Delta = Delta.contiguous().to(X.dtype)
|
||||
Delta = F.normalize(Delta, p=2, dim=-1).type_as(X) * torch.norm(
|
||||
X, p=2, dim=-1
|
||||
).unsqueeze(2)
|
||||
|
||||
mask = torch.ones((bsz, s_len), device=Delta.device)
|
||||
|
||||
if self.mc:
|
||||
# multiple-choice, only edit the tokens in answer
|
||||
mask[:, : self.prompt_length + 1] = 0
|
||||
# probing those untruthful position
|
||||
probing = (
|
||||
torch.nn.functional.cosine_similarity(
|
||||
x_truthful, neg_center.unsqueeze(1), dim=-1
|
||||
)
|
||||
- torch.nn.functional.cosine_similarity(
|
||||
x_truthful, pos_center.unsqueeze(1), dim=-1
|
||||
)
|
||||
).clamp(0, 999)
|
||||
mask = mask * probing
|
||||
|
||||
else:
|
||||
# open-ended generation, only edit the generated token (i.e., last token)
|
||||
mask[:, :-1] = 0
|
||||
mask[:, -1:] = 1
|
||||
|
||||
new_X = X + (Delta.type_as(X)) * self.edit_strength * mask.unsqueeze(2).type_as(X)
|
||||
return new_X
|
||||
Reference in New Issue
Block a user