---
base_model: meta-llama/Llama-3.1-8B-Instruct
library_name: transformers
pipeline_tag: text-generation
language: en
---

# LatentSC Llama 3.1 8B with Summary Tokens

This repository contains a Llama 3.1 8B Instruct backbone with trained LatentSC (Latent Self-Consistency) Summary-token embeddings attached. The base model weights are unchanged; only the embeddings of the Summary tokens are added, so that LatentSC inference can use the trained Summary tokens.

## Usage

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

# Sample several candidate answers, embed each one with the trained Summary
# tokens, and keep the answer that is most consistent with the others.

repo = "jeongseokoh/LatentSC_llama3.1_8b_6SummaryTokens"

tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.bfloat16, device_map="auto"
)

# Summary tokens (default: 6)
summary_tokens = [f"<|Summary{i}|>" for i in range(1, 7)]
summary_id_list = tokenizer.convert_tokens_to_ids(summary_tokens)
if None in summary_id_list or tokenizer.unk_token_id in summary_id_list:
    raise ValueError(
        f"Summary tokens missing from the tokenizer vocabulary: {summary_tokens}"
    )
summary_ids = torch.tensor(summary_id_list, device=model.device)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Solve: 17 * 23. Show the final answer only."},
]
# add_generation_prompt=True ends the prompt with the assistant header, so the
# model starts an answer instead of continuing the user turn.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# The chat template already emits the BOS token; letting the tokenizer add
# special tokens again would duplicate it.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
    model.device
)
prompt_ids = inputs["input_ids"][0]
prompt_len = prompt_ids.size(0)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        num_return_sequences=10,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )

# Decode candidates: only the generated continuation, not the echoed prompt.
gen_ids = out.sequences[:, prompt_len:]
answers = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

# Token ids that end a response; generation_config may list several.
stop_ids = model.generation_config.eos_token_id
if not isinstance(stop_ids, (list, tuple)):
    stop_ids = [stop_ids]
stop_ids = {t for t in [*stop_ids, tokenizer.eos_token_id] if t is not None}

# Embeddings: run prompt + response + Summary tokens through the model and keep
# the last-layer hidden state of the final Summary token for each candidate.
# The stop/padding tail of each response is cut first (cf. `lsc_remove_eos`).
# (The hidden states returned by generate() cannot be used here: each later
# generation step only holds the newest position, and the prompt positions are
# identical for every sampled sequence.)
# NOTE(review): the stored `lsc_aggr` config field may call for a different
# aggregation over the Summary-token positions; see the official repo.
emb_rows = []
with torch.no_grad():
    for row in gen_ids:
        tokens = row.tolist()
        end = next((i for i, tok in enumerate(tokens) if tok in stop_ids), len(tokens))
        seq = torch.cat([prompt_ids, row[:end], summary_ids]).unsqueeze(0)
        hidden = model(input_ids=seq, output_hidden_states=True).hidden_states[-1]
        emb_rows.append(hidden[0, -1])
embs = torch.stack(emb_rows)  # (N, D)

# LSC selection (cosine similarity): pick the candidate whose embedding is, on
# average, most similar to all *other* candidates.
embs = F.normalize(embs.float(), p=2, dim=1)
sim = embs @ embs.T
sim.fill_diagonal_(0.0)  # ignore self-similarity
avg_sim = sim.sum(dim=1) / max(sim.size(0) - 1, 1)
best_idx = int(torch.argmax(avg_sim))
best_answer = answers[best_idx]

# Dynamic TopK LSC: re-rank inside the k most central candidates.
def lsc_topk(embs, answers, k):
    """Return the most central answer within the ``k`` most central candidates.

    Candidates are first ranked by their mean cosine similarity to all other
    candidates. The ``k`` best are kept, and the winner is the subset member
    with the highest mean cosine similarity to the *other* subset members.

    Args:
        embs: (N, D) tensor of candidate embeddings, any floating dtype.
        answers: Sequence of N answers aligned row-for-row with ``embs``.
        k: Subset size. Values larger than N are clamped to N.

    Returns:
        ``(answer, score)``. ``score`` is the winner's mean similarity to the
        other ``k - 1`` subset members (0.0 for a one-member subset), so scores
        obtained with different ``k`` are directly comparable.

    Raises:
        ValueError: if ``embs`` has no rows or ``k < 1``.
    """
    n = embs.size(0)
    if n == 0 or k < 1:
        raise ValueError("lsc_topk needs at least one candidate and k >= 1")
    k = min(k, n)  # torch.topk raises when k exceeds the number of rows

    embs = F.normalize(embs.float(), p=2, dim=1)
    sim = embs @ embs.T
    sim.fill_diagonal_(0.0)  # ignore self-similarity
    # Average over the n - 1 *other* rows: dividing by n (as .mean would)
    # shrinks every score by (n - 1) / n because of the zeroed diagonal.
    avg_sim = sim.sum(dim=1) / max(n - 1, 1)
    topk_idx = torch.topk(avg_sim, k=k).indices

    # Pairwise similarities inside the subset; the diagonal is already zero.
    sub_sim = sim[topk_idx][:, topk_idx]
    sub_avg = sub_sim.sum(dim=1) / max(k - 1, 1)
    best_local = int(torch.argmax(sub_avg))
    return answers[int(topk_idx[best_local])], float(sub_avg[best_local])

# Try several subset sizes and keep the answer with the highest within-subset
# consistency score; on ties the earliest k wins.
best, best_score = None, -1e9
for k in (3, 5, 7):
    cand, score = lsc_topk(embs, answers, k)
    if not score > best_score:
        continue
    best, best_score = cand, score
```

### Stored LatentSC config fields

The following fields are saved in the model config (when present) to guide LatentSC inference:

```text
lsc_num_special_tokens
lsc_special_token_prefix
lsc_aggr
lsc_remove_eos
lsc_temp
```

For detailed training/inference scripts and full usage, see the GitHub repository: <https://github.com/jeongseokO/LatentSC_official>