[Feat] Add window attention for gemma-2 (#1056)
This commit is contained in:
@@ -35,18 +35,17 @@ def normal_text(args):
|
||||
args.model_path,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
m.cuda()
|
||||
|
||||
print(m)
|
||||
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
max_new_tokens = 32
|
||||
max_new_tokens = 16
|
||||
|
||||
for p in prompts:
|
||||
if isinstance(p, str):
|
||||
@@ -58,10 +57,11 @@ def normal_text(args):
|
||||
input_ids, do_sample=False, max_new_tokens=max_new_tokens
|
||||
)
|
||||
output_str = t.decode(output_ids[0])
|
||||
print(output_str)
|
||||
|
||||
prefill_logits = m.forward(input_ids).logits[0][-1]
|
||||
|
||||
print("prefill logits", prefill_logits)
|
||||
print(output_str)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
|
||||
Reference in New Issue
Block a user