feat: support internlm2 (#636)

This commit is contained in:
zhyncs
2024-07-17 15:40:03 +10:00
committed by GitHub
parent a470e60c97
commit a8552cb18b
3 changed files with 323 additions and 3 deletions

View File

@@ -30,9 +30,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
@torch.inference_mode()
def normal_text(args):
t = AutoTokenizer.from_pretrained(args.model_path)
t = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
m = AutoModelForCausalLM.from_pretrained(
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
args.model_path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
trust_remote_code=True,
)
m.cuda()