release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
80
test/srt/model/reference_hf.py
Normal file
80
test/srt/model/reference_hf.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def normal_text(args):
|
||||
t = AutoTokenizer.from_pretrained(args.model_path)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
m.cuda()
|
||||
|
||||
print(m)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
max_new_tokens = 32
|
||||
|
||||
for p in prompts:
|
||||
if isinstance(p, str):
|
||||
input_ids = t.encode(p, return_tensors="pt").cuda()
|
||||
else:
|
||||
input_ids = torch.tensor([p], device="cuda")
|
||||
|
||||
output_ids = m.generate(
|
||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = t.decode(output_ids[0])
|
||||
print(output_str)
|
||||
|
||||
prefill_logits = m.forward(input_ids).logits[0][-1]
|
||||
print("prefill logits", prefill_logits)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def synthetic_tokens(args):
|
||||
t = AutoTokenizer.from_pretrained(args.model_path)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
)
|
||||
m.cuda()
|
||||
print(m)
|
||||
|
||||
input_len = 256
|
||||
output_len = 8
|
||||
prompts = [list(range(5, 5 + input_len))]
|
||||
|
||||
for p in prompts:
|
||||
input_ids = p
|
||||
for i in range(output_len + 1):
|
||||
prefill_logits = m.forward(torch.tensor([input_ids], device="cuda")).logits[
|
||||
0
|
||||
][-1]
|
||||
|
||||
if i == 0:
|
||||
print("prefill logits", prefill_logits)
|
||||
else:
|
||||
print("decode", i - 1, prefill_logits)
|
||||
|
||||
input_ids.append(torch.argmax(prefill_logits).item())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="TinyLlama/TinyLlama-1.1B-Chat-v0.4",
|
||||
# default="meta-llama/Llama-2-7b-chat-hf",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
normal_text(args)
|
||||
# synthetic_tokens(args)
|
||||
Reference in New Issue
Block a user