[Feat] Add window attention for gemma-2 (#1056)

This commit is contained in:
Ying Sheng
2024-08-13 17:01:26 -07:00
committed by GitHub
parent ad3e4f1619
commit 0909bb0d2f
11 changed files with 320 additions and 127 deletions

View File

@@ -35,18 +35,17 @@ def normal_text(args):
args.model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
device_map="auto",
trust_remote_code=True,
)
m.cuda()
print(m)
prompts = [
"The capital of France is",
"The capital of the United Kindom is",
"Today is a sunny day and I like",
]
max_new_tokens = 32
max_new_tokens = 16
for p in prompts:
if isinstance(p, str):
@@ -58,10 +57,11 @@ def normal_text(args):
input_ids, do_sample=False, max_new_tokens=max_new_tokens
)
output_str = t.decode(output_ids[0])
print(output_str)
prefill_logits = m.forward(input_ids).logits[0][-1]
print("prefill logits", prefill_logits)
print(output_str)
@torch.inference_mode()