feat: support internlm2 (#636)
This commit is contained in:
@@ -30,9 +30,12 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
@torch.inference_mode()
|
||||
def normal_text(args):
|
||||
t = AutoTokenizer.from_pretrained(args.model_path)
|
||||
t = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
|
||||
m = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
|
||||
args.model_path,
|
||||
torch_dtype=torch.float16,
|
||||
low_cpu_mem_usage=True,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
m.cuda()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user