From 44b05d0eadea8edc0ddb386e98ac35d738b0ebe9 Mon Sep 17 00:00:00 2001 From: ModelHub XC Date: Mon, 8 Jun 2026 22:58:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E9=A1=B9=E7=9B=AE?= =?UTF-8?q?=EF=BC=8C=E7=94=B1ModelHub=20XC=E7=A4=BE=E5=8C=BA=E6=8F=90?= =?UTF-8?q?=E4=BE=9B=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Model: shibing624/vicuna-baichuan-13b-chat Source: Original Platform --- .gitattributes | 35 ++ README.md | 251 +++++++++++++ config.json | 27 ++ configuration_baichuan.py | 46 +++ generation_config.json | 15 + handler.py | 27 ++ modeling_baichuan.py | 607 +++++++++++++++++++++++++++++++ pytorch_model-00001-of-00003.bin | 3 + pytorch_model-00002-of-00003.bin | 3 + pytorch_model-00003-of-00003.bin | 3 + pytorch_model.bin.index.json | 290 +++++++++++++++ quantizer.py | 123 +++++++ special_tokens_map.json | 30 ++ tokenization_baichuan.py | 233 ++++++++++++ tokenizer.model | 3 + tokenizer_config.json | 46 +++ 16 files changed, 1742 insertions(+) create mode 100644 .gitattributes create mode 100644 README.md create mode 100644 config.json create mode 100644 configuration_baichuan.py create mode 100644 generation_config.json create mode 100644 handler.py create mode 100644 modeling_baichuan.py create mode 100644 pytorch_model-00001-of-00003.bin create mode 100644 pytorch_model-00002-of-00003.bin create mode 100644 pytorch_model-00003-of-00003.bin create mode 100644 pytorch_model.bin.index.json create mode 100644 quantizer.py create mode 100644 special_tokens_map.json create mode 100644 tokenization_baichuan.py create mode 100644 tokenizer.model create mode 100644 tokenizer_config.json diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a6344aa --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000..f1b9c25 --- /dev/null +++ b/README.md @@ -0,0 +1,251 @@ +--- +language: +- zh +- en +tags: +- chatglm +- pytorch +- zh +- en +- Text2Text-Generation +license: "apache-2.0" +widget: +- text: "一岁宝宝发烧能吃啥药?" + +--- + +# Vicuna Baichuan-13B Model +基于Baichaun-13B-chat的Vicuna微调对话模型 + +[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat) evaluate test data: + +The overall performance of model on QA **test**: + +|input_text|predict| +|:-- |:--- | +|一岁宝宝发烧能吃啥药?|1岁宝宝发烧时,可以考虑以下几种药物:
1. 退烧药:如布洛芬(Ibuprofen)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助降低体温,缓解发烧症状。
2. 抗生素:如果发烧是由细菌感染引起的,可以考虑使用抗生素。但请务必在医生的建议下使用。
3. 抗病毒药物:如果发烧是由病毒感染引起的,可以考虑使用抗病毒药物。但请务必在医生的建议下使用。
4. 解热镇痛药:如阿司匹林(Aspirin)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助缓解发烧症状。
5. 抗组胺药:如苯海拉明(Diphenhydramine)或非索非那定(Naproxen)。这些药物可以帮助缓解发烧症状。
请务必在医生的建议下使用这些药物,并遵循药物说明书上的用量和用法。同时,保持宝宝充足的水分摄入,并确保宝宝有足够的休息。| + +在中文开放测试集中的表现优异,继承了两方面的优势: +1)微调训练的底座是 [baichuan-inc/Baichuan-13B-Chat](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) 模型,是较强的中英文底座模型, +2)微调使用的是高质量10万条多语言GPT4对话数据集,微调后的模型在通用对话,包括医疗行业答复能力达到领先水平,在通用问题上的答复能力不弱于Baichuan-13B-Chat。 + + +## Training details + +training args: + +``` +{"per_device_train_batch_size": 2, +"per_device_eval_batch_size": 2, +"gradient_accumulation_steps": 1, +"learning_rate": 2e-05, +"num_train_epochs": 3.0, +"max_steps": -1, +"lr_scheduler_type": "linear", +"warmup_ratio": 0.0, +"warmup_steps": 50, +"save_strategy": "steps", +"save_steps": 500, +"save_total_limit": 13, +"seed": 42, +"bf16": true, +"fp16": false, +"eval_steps": 50, +"remove_unused_columns": false, +"load_best_model_at_end": true, +"metric_for_best_model": "loss", +"optim": "paged_adamw_32bit", +"report_to": ["tensorboard"], +} +``` + +- train loss: 1.3209 +- evaluate loss: 1.4953 + +The following `bitsandbytes` quantization config was used during training: +- load_in_8bit: False +- load_in_4bit: True +- llm_int8_threshold: 6.0 +- llm_int8_skip_modules: None +- llm_int8_enable_fp32_cpu_offload: False +- llm_int8_has_fp16_weight: False +- bnb_4bit_quant_type: nf4 +- bnb_4bit_use_double_quant: True +- bnb_4bit_compute_dtype: bfloat16 + + + +## Usage +本项目开源在 github repo: +- [shibing624/textgen](https://github.com/shibing624/textgen) +- [shibing624/MedicalGPT](https://github.com/shibing624/MedicalGPT) + +使用textgen库:[textgen](https://github.com/shibing624/textgen),可调用Baichuan/LLaMA模型: + +Install package: +```shell +pip install -U textgen +``` + +```python +from textgen import GptModel + +def generate_prompt(instruction): + return f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.USER: {instruction} ASSISTANT: """ + +model = GptModel("baichuan", "shibing624/vicuna-baichuan-13b-chat") +predict_sentence = generate_prompt("一岁宝宝发烧能吃啥药?") +r = model.predict([predict_sentence]) +print(r) # ["1、首先大多数小儿退热药中含有解热镇痛成分阿司匹林或布洛芬等,这类药品虽然副作用较少..."] +``` + +## Usage (HuggingFace Transformers) +Without [textgen](https://github.com/shibing624/textgen), you can use the model like this: + +First, you pass your input through the transformer model, then you get the generated sentence. + +Install package: +``` +pip install transformers +``` + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig + + +model = AutoModelForCausalLM.from_pretrained("shibing624/vicuna-baichuan-13b-chat", device_map='auto', torch_dtype=torch.float16, trust_remote_code=True) +model.generation_config = GenerationConfig.from_pretrained("shibing624/vicuna-baichuan-13b-chat", trust_remote_code=True) +tokenizer = AutoTokenizer.from_pretrained("shibing624/vicuna-baichuan-13b-chat", trust_remote_code=True) +device = torch.device(0) if torch.cuda.is_available() else torch.device("cpu") + +def generate_prompt(instruction): + return f"""A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.USER: {instruction} ASSISTANT: """ + + +sents = ['一岁宝宝发烧能吃啥药', "who are you?"] +for s in sents: + q = generate_prompt(s) + inputs = tokenizer(q, return_tensors="pt") + inputs = inputs.to(device) + + generate_ids = model.generate( + **inputs, + max_new_tokens=512, + ) + + output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0] + print(output) + print() +``` + +output: +```shell +一岁宝宝发烧能吃啥药 +孩子的身体是很脆弱的,如果不注意保护好自己就会出现各种疾病。有些家长在看到孩子生病后都非常着急,想要立刻给他们用药治疗,但其实并没有那么简单,因为每个人对于药物的敏感度和耐受性都是不同的,所以我们需要根据具体情况来选择合适的方法进行治疗才可以达到最佳效果。1、首先大多数小儿退热药中含有解热镇痛成分阿司匹林或布洛芬等,这类药品虽然副作用较少,但也存在过量使用引起胃肠道损伤甚至消化道溃疡出血的风险;2、还有部分小儿退热药中含有扑尔敏等抗组胺药,它们本身属于处方药,应该严格按照医嘱服用,切勿擅自加减剂量或者延长用药时间;3、此外,某些小儿退热药中含有维生素C及葡萄糖酸钙等营养补充剂,只有当患儿伴随脱水症状(如口渴、尿色深)时才考虑添加,否则容易造成电解质紊乱。 +``` + + +模型文件组成: +``` +vicuna-baichuan-13b-chat +├── config.json +├── configuration_baichuan.py +├── generation_config.json +├── modeling_baichuan.py +├── pytorch_model-00001-of-00003.bin +├── pytorch_model-00002-of-00003.bin +├── pytorch_model-00003-of-00003.bin +├── pytorch_model.bin.index.json +├── quantizer.py +├── special_tokens_map.json +├── tokenization_baichuan.py +├── tokenizer_config.json +└── tokenizer.model +``` + +- Inference GPU: 27G +### Inference Examples + + + +| Human | Assitant | +|------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| 小孩发烧怎么办 | 发烧是身体对感染或其他疾病的自然反应。以下是一些可以帮助缓解发烧症状的建议:
1. 保持身体水分充足:发烧时,身体会失去水分,因此需要保持水分充足。
2. 使用冰袋:将冰袋放在额头或颈部可以帮助降低体温。
3. 使用退烧药:退烧药可以帮助降低体温。
4. 保持身体温度:保持身体温度可以帮助降低体温。
5. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。
6. 使用退烧药:退烧药可以帮助降低体温。
7. 保持身体温度:保持身体温度可以帮助降低体温。
8. 使用冷水或温水擦拭身体:使用冷水或温水擦拭身体可以帮助降低体温。
9. 使用退烧药:退烧药可以帮助降低体温。
10. 保持身体温度:保持身体温度可以帮助降低体温。 | +| 一岁宝宝发烧能吃啥药? | 1岁宝宝发烧时,可以考虑以下几种药物:
1. 退烧药:如布洛芬(Ibuprofen)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助降低体温,缓解发烧症状。
2. 抗生素:如果发烧是由细菌感染引起的,可以考虑使用抗生素。但请务必在医生的建议下使用。
3. 抗病毒药物:如果发烧是由病毒感染引起的,可以考虑使用抗病毒药物。但请务必在医生的建议下使用。
4. 解热镇痛药:如阿司匹林(Aspirin)或对乙酰氨基酚(Paracetamol)。这些药物可以帮助缓解发烧症状。
5. 抗组胺药:如苯海拉明(Diphenhydramine)或非索非那定(Naproxen)。这些药物可以帮助缓解发烧症状。
请务必在医生的建议下使用这些药物,并遵循药物说明书上的用量和用法。同时,保持宝宝充足的水分摄入,并确保宝宝有足够的休息。 | +| 肛门病变可能是什么疾病的症状 | 肛门病变可能由多种原因引起,包括感染、炎症、肿瘤或其他疾病。一些常见的肛门病变包括:
1. 肛裂:这是一种肛门皮肤撕裂,通常由便秘或其他压力因素引起。
2. 肛瘘:这是一种肛门感染,导致肛门皮肤和直肠之间形成瘘管。
3. 肛门疣:这是一种肛门皮肤上形成的疣状病变,通常由人类乳头状瘤病毒(HPV)引起。
4. 肛门脓肿:这是一种肛门感染,导致肛门皮肤下形成脓肿。
5. 肛门肿瘤:这是一种肛门皮肤或直肠内形成的肿瘤。
6. 肛门溃疡:这是一种肛门皮肤上形成的溃疡,通常由感染或炎症引起。
7. 肛门瘙痒:这是一种肛门皮肤瘙痒的症状,可能由感染、炎症或其他疾病引起。
如果您有肛门病变的症状,请咨询医疗保健专业人员以获得适当的诊断和治疗。 | +| 医生,您好白带增多泛黄怎么回事? | 您好,白带增多和发黄可能是由多种原因引起的,包括感染、荷尔蒙变化、阴道干燥或其他妇科问题。以下是一些可能的原因:
1. 感染:白带增多和发黄可能是由阴道或尿道感染引起的。感染可能由细菌、病毒或真菌引起。
2. 荷尔蒙变化:荷尔蒙变化,如月经周期或更年期,可能会导致白带增多和发黄。
3. 阴道干燥:阴道干燥可能会导致白带增多和发黄。
4. 其他妇科问题:其他妇科问题,如子宫颈炎、子宫内膜炎或子宫内膜异位症,也可能导致白带增多和发黄。
如果您担心自己的白带,建议您咨询妇科医生进行检查和诊断。医生可以根据您的症状和检查结果确定原因并提供适当的治疗。 | +| 药流后有巧克力白带是没流干净吗 | 药流后出现巧克力白带并不一定意味着没有流干净。巧克力白带是由于子宫内膜脱落和出血引起的,通常在药流后的一段时间内会出现。

药流后出现巧克力白带的原因可能包括:

1. 子宫内膜脱落:药流后,子宫内膜会脱落,导致出血。这种出血可能表现为巧克力白带。
2. 感染:药流后,子宫内膜和宫颈口可能存在创伤,容易受到细菌感染。感染可能导致巧克力白带。
3. 激素水平变化:药流后,激素水平可能发生变化,导致巧克力白带。

如果巧克力白带持续时间较长,或者伴有其他症状,如腹痛、发热、分泌物异味等,建议及时就诊,进行相关检查,如妇科检查、B超等,以确定是否存在感染或其他问题。在医生的指导下进行相应的治疗。 | +| 维胺酯维E乳膏能治理什么疾病 | 维胺酯维E乳膏是一种外用药物,主要用于治疗痤疮、脂溢性皮炎等皮肤病。它可以帮助减少皮肤油脂分泌,减轻炎症,促进皮肤修复。

需要注意的是,维胺酯维E乳膏并非万能药,对于某些皮肤病可能并不适用。在使用前,建议咨询医生或药师,了解药物的适应症、禁忌症、用法用量等信息。同时,使用药物时应遵循医嘱,避免自行滥用。 | +| 痤疮是什么原因导致的? | 痤疮,也称为粉刺,是由皮肤毛孔堵塞引起的。堵塞毛孔的原因有很多,包括:
1. 皮脂腺过度活跃:皮脂腺产生过多油脂,导致毛孔堵塞。
2. 毛孔堵塞:皮肤表面的死皮细胞、油脂和细菌堵塞毛孔。
3. 炎症:痤疮通常伴随着炎症,这可能是由于皮肤感染或免疫系统对痤疮的反应引起的。
4. 遗传因素:某些人可能更容易患痤疮,这可能是由于遗传因素引起的。
5. 荷尔蒙变化:痤疮在青春期和月经期间更为常见,这可能是由于荷尔蒙变化引起的。
6. 生活方式因素:某些生活方式因素,如压力、不良饮食和缺乏运动,也可能导致痤疮。
7. 化妆品和护肤品:某些化妆品和护肤品可能导致痤疮,特别是那些含有刺激性成分的产品。
8. 药物:某些药物,如口服避孕药,也可能导致痤疮。
要预防痤疮,可以采取一些措施,如保持皮肤清洁、使用温和的护肤品、避免过度使用化妆品和护肤品、保持健康的生活方式和饮食习惯、避免压力和荷尔蒙变化。如果痤疮严重,可能需要寻求皮肤科医生的帮助。 | +| 脸上长青春痘怎么办,能吃啥药,需要注意啥 | 青春痘,又称痤疮,是一种常见的皮肤病。青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。以下是一些建议:

1. 饮食:
* 避免高糖、高脂、高油的食物,如快餐、甜点、油炸食品等。
* 多吃新鲜蔬菜、水果、全谷类食物,以增加纤维摄入,促进肠道蠕动。
* 避免摄入过多的咖啡因、酒精、辛辣食物等刺激性食物。

2. 生活习惯:
* 保持良好的作息时间,保证充足的睡眠。
* 避免长时间使用电脑、手机等电子产品,以免辐射对皮肤造成刺激。
* 保持皮肤清洁,每天早晚使用温和的洁面产品清洁皮肤。
* 避免挤压青春痘,以免感染和留下疤痕。

3. 药物:
* 外用药物:可以使用含有水杨酸、苯甲酸、壬二酸等成分的护肤品,帮助去除皮肤表面的角质,减少痘痘的形成。
* 口服药物:如果青春痘严重,可以考虑口服抗生素、维生素B族、锌等药物,但需在医生指导下使用。

4. 注意:
* 不要随意使用化妆品、护肤品,以免刺激皮肤。
* 不要过度清洁皮肤,以免破坏皮肤屏障。
* 保持心情舒畅,避免焦虑、抑郁等负面情绪。

总之,青春痘的治疗需要综合考虑饮食、生活习惯、药物等多个方面。在治疗过程中,要保持耐心和信心,遵循医生的建议,逐步改善皮肤状况。 | +| 由于输精管堵塞会不会导致男性不育?,医生您好,我同事这几天不开心,我问她怎么了,她说她老公输精管堵塞,所以一直没孩子,不是不想要,请问输精管堵塞会不会导致男性不育? | 输精管堵塞可能会导致男性不育。输精管是连接睾丸和尿道的一条管道,负责将精子从睾丸输送到尿道。如果输精管堵塞,精子就无法通过输精管到达尿道,从而导致不育。

输精管堵塞的原因有很多,包括先天性发育异常、炎症、肿瘤、外伤等。输精管堵塞的治疗方法包括手术疏通、药物治疗等。如果输精管堵塞严重,可能需要进行手术疏通。

需要注意的是,输精管堵塞并不一定导致不育,因为还有其他途径可以将精子输送到女性体内。例如,精子可以通过附睾、前列腺等组织到达女性体内。因此,如果输精管堵塞,建议及时就医,寻求专业医生的帮助。 | +| 心肌梗塞能治疗吗? | 心肌梗死是一种严重的疾病,如果及时治疗,可以挽救生命并防止长期残疾。治疗心肌梗死的关键是尽早识别和治疗。

心肌梗死的治疗通常包括以下步骤:

1. 立即就医:如果出现胸痛、呼吸急促或晕厥等症状,立即就医。

2. 药物治疗:医生可能会开处方药物,如硝酸甘油或阿司匹林,以帮助缓解症状并防止进一步的损伤。

3. 溶栓治疗:如果患者在到达医院前出现心肌梗死,医生可能会进行溶栓治疗,通过注射溶解血栓的药物来帮助恢复血流。

4. 手术治疗:如果药物治疗或溶栓治疗无效,医生可能会建议进行手术,如冠状动脉搭桥手术或心脏支架植入术。

5. 康复和预防:心肌梗死的治疗通常包括康复计划,以帮助患者恢复身体功能并预防未来的心脏事件。此外,患者还应遵循医生的建议,以降低患心脏病的风险。

需要注意的是,心肌梗死的治疗取决于患者的具体情况和症状。因此,如果您或您认识的人出现心肌梗死症状,请立即就医。 | +| 怎么让自己精力充沛,列5点建议 | 1. 保持充足的睡眠:睡眠对于保持精力充沛至关重要。确保每晚获得7-9小时的高质量睡眠。
2. 均衡饮食:均衡饮食对于保持能量水平至关重要。确保摄入足够的蛋白质、碳水化合物和健康脂肪。
3. 定期锻炼:定期锻炼可以帮助提高能量水平。尝试每天进行至少30分钟的有氧运动,如散步、跑步或骑自行车。
4. 管理压力:压力会消耗能量,因此管理压力对于保持精力充沛至关重要。尝试找到放松和减压的方法,如冥想、瑜伽或深呼吸练习。
5. 保持水分:脱水会降低能量水平。确保每天喝足够的水,并根据需要补充水分。 | +| 写个春天的诗 | 春天,
是大地苏醒的季节,
是万物复苏的季节,
是生命绽放的季节。
春天,
是阳光明媚的季节,
是鸟语花香的季节,
是绿意盎然的季节,
是希望无限的季节。
春天,
是人们欢笑的季节,
是孩子们嬉戏的季节,
是恋人们相拥的季节,
是人们追求梦想的季节。 | +| 南极洲为啥没有北极熊? | 南极洲没有北极熊,因为北极熊生活在北极地区,而南极洲位于南半球,气候和环境与北极地区有很大不同。南极洲的气候非常寒冷,常年被冰雪覆盖,而北极熊则生活在北极地区,那里有丰富的海洋生物和陆地哺乳动物。此外,南极洲的生态系统也与北极地区不同,没有北极熊所需的生存条件。 | + + +### 训练数据集 + +- 10万条多语言ShareGPT GPT4多轮对话数据集:[shibing624/sharegpt_gpt4](https://huggingface.co/datasets/shibing624/sharegpt_gpt4) + +如果需要训练ChatGLM/LLAMA/BLOOM模型,请参考[shibing624/MedicalGPT](https://github.com/shibing624/MedicalGPT) + +- training script: + +```shell +CUDA_VISIBLE_DEVICES=0,1,2 torchrun --nproc_per_node 3 supervised_finetuning.py \ + --model_type baichuan \ + --model_name_or_path baichuan-inc/Baichuan-13B-Chat \ + --train_file_dir path_of_sharegpt_gpt4 \ + --validation_file_dir data/finetune/ \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 2 \ + --template_name vicuna \ + --do_train \ + --do_eval \ + --use_peft True --bf16 \ + --max_train_samples -1 \ + --max_source_length 1024 --max_target_length 1024 \ + --max_eval_samples 50 \ + --num_train_epochs 3 \ + --learning_rate 2e-5 \ + --warmup_ratio 0.05 \ + --weight_decay 0.05 \ + --logging_strategy steps \ + --logging_steps 10 \ + --eval_steps 50 \ + --evaluation_strategy steps \ + --save_steps 1000 \ + --save_strategy steps \ + --save_total_limit 13 \ + --preprocessing_num_workers 10 \ + --output_dir outputs-sft--baichuan-v3-production \ + --overwrite_output_dir \ + --ddp_timeout 300000 \ + --logging_first_step True \ + --target_modules all \ + --lora_rank 8 \ + --lora_alpha 16 \ + --lora_dropout 0.05 \ + --torch_dtype bfloat16 --optim paged_adamw_32bit \ + --device_map auto \ + --report_to tensorboard \ + --ddp_find_unused_parameters False \ + --gradient_checkpointing True --qlora True --cache_dir ./cache +``` +- 机器:3*A100(40G), 36hours + + + +## Citation + +```latex +@software{textgen, + author = {Ming Xu}, + title = {textgen: Implementation of language model finetune}, + year = {2023}, + url = {https://github.com/shibing624/textgen}, +} +``` diff --git a/config.json b/config.json new file mode 100644 index 0000000..c3c4f43 --- /dev/null +++ b/config.json @@ -0,0 +1,27 @@ +{ + "_from_model_config": true, + "architectures": [ + "BaichuanForCausalLM" + ], + "auto_map": { + "AutoConfig": "configuration_baichuan.BaichuanConfig", + "AutoModelForCausalLM": "modeling_baichuan.BaichuanForCausalLM" + }, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13696, + "model_max_length": 4096, + "model_type": "baichuan", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "pad_token_id": 0, + "rms_norm_eps": 1e-06, + "tie_word_embeddings": false, + "torch_dtype": "float16", + "transformers_version": "4.31.0", + "use_cache": true, + "vocab_size": 64000 +} diff --git a/configuration_baichuan.py b/configuration_baichuan.py new file mode 100644 index 0000000..c5967ae --- /dev/null +++ b/configuration_baichuan.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +from transformers.configuration_utils import PretrainedConfig + +class BaichuanConfig(PretrainedConfig): + model_type = "baichuan" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=64000, + hidden_size=5120, + intermediate_size=13696, + num_hidden_layers=40, + num_attention_heads=40, + hidden_act="silu", + model_max_length=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + gradient_checkpointing=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.model_max_length = model_max_length + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.gradient_checkpointing = gradient_checkpointing, + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + diff --git a/generation_config.json b/generation_config.json new file mode 100644 index 0000000..6ab7a40 --- /dev/null +++ b/generation_config.json @@ -0,0 +1,15 @@ +{ + "_from_model_config": true, + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + "user_token_id": 195, + "assistant_token_id": 196, + "max_new_tokens": 2048, + "temperature": 0.3, + "top_k": 5, + "top_p": 0.85, + "repetition_penalty": 1.1, + "do_sample": true, + "transformers_version": "4.31.0" +} diff --git a/handler.py b/handler.py new file mode 100644 index 0000000..306fecc --- /dev/null +++ b/handler.py @@ -0,0 +1,27 @@ +import torch +from typing import Dict, List, Any +from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline + +# get dtype +dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16 + + +class EndpointHandler: + def __init__(self, path=""): + # load the model + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(path, device_map="auto", torch_dtype=dtype, trust_remote_code=True) + # create inference pipeline + self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer) + + def __call__(self, data: Any) -> List[List[Dict[str, float]]]: + inputs = data.pop("inputs", data) + parameters = data.pop("parameters", None) + + # pass inputs with all kwargs in data + if parameters is not None: + prediction = self.pipeline(inputs, **parameters) + else: + prediction = self.pipeline(inputs) + # postprocess the prediction + return prediction \ No newline at end of file diff --git a/modeling_baichuan.py b/modeling_baichuan.py new file mode 100644 index 0000000..ab227c8 --- /dev/null +++ b/modeling_baichuan.py @@ -0,0 +1,607 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging +from transformers.generation.utils import GenerationConfig + +from .configuration_baichuan import BaichuanConfig + +logger = logging.get_logger(__name__) + + +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = (2 ** (-2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio ** i for i in range(n)] + + if math.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return _get_interleave_power_of_2(closest_power_of_2) + \ + _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2] + +def _fill_with_neg_inf(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(float("-inf")).type_as(t) + +def _gen_alibi_mask(n_head, max_pos): + """used in inference only""" + slopes = torch.Tensor(_get_interleave(n_head)) + alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand( + n_head, -1, -1) + alibi = alibi.view(n_head, 1, max_pos) + alibi_mask = torch.triu( + _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1 + ) + alibi_mask = alibi_mask.unsqueeze(0) + alibi + return alibi_mask + +def _buffered_future_mask(tensor, maxpos, alibi, attn_heads): + """used in training only""" + dim = tensor.size(1) + _future_mask = torch.triu( + _fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1 + ) + _future_mask = _future_mask.unsqueeze(0) + alibi + _future_mask = _future_mask.to(tensor) + return _future_mask[:tensor.shape[0] * attn_heads, :maxpos, :maxpos] + + +class RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, epsilon=1e-6): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(hidden_size)) + self.epsilon = epsilon + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) + + # convert into half-precision + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class MLP(torch.nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class BaichuanAttention(torch.nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.model_max_length + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" + ) + self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + if q_len == 1: # inference with cache + if len(attention_mask.size()) == 4: + attention_mask = attention_mask[:, :, -1:, :] + else: + attention_mask = attention_mask[:, -1:, :] + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1) + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class BaichuanLayer(torch.nn.Module): + def __init__(self, config: BaichuanConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = BaichuanAttention(config=config) + self.mlp = MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BaichuanPreTrainedModel(PreTrainedModel): + config_class = BaichuanConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BaichuanLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, torch.nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, torch.nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BaichuanModel): + module.gradient_checkpointing = value + + +class BaichuanModel(BaichuanPreTrainedModel): + def __init__(self, config: BaichuanConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.n_head = config.num_attention_heads + self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + + self.gradient_checkpointing = config.gradient_checkpointing + self.post_init() + self.max_cache_pos = config.model_max_length + self.first_run = True + self.alibi_mask = None + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def get_alibi_mask(self, tensor, seq_length_with_past): + if self.training: + slopes = torch.Tensor(_get_interleave(self.n_head)) + alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(seq_length_with_past).unsqueeze(0).unsqueeze(0).expand( + self.n_head, + -1, -1) + alibi = alibi.view(self.n_head, 1, seq_length_with_past) + mask = _buffered_future_mask(tensor, seq_length_with_past, alibi, self.n_head) + else: + if self.first_run: + self.first_run = False + self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False) + if seq_length_with_past > self.max_cache_pos: + self.max_cache_pos = seq_length_with_past + self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False) + mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past] + return mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ) -> Union[Tuple, BaseModelOutputWithPast]: + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You need to provide input_ids or inputs_embeds") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + seq_length_with_past = seq_length + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.training: + if self.alibi_mask is None or self.alibi_mask.shape[-1] != seq_length_with_past: + self.alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) + alibi_mask = self.alibi_mask + else: + alibi_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past) + + if attention_mask is not None: + if len(attention_mask.shape) == 2: + expanded_mask = attention_mask.to(alibi_mask.dtype) + expanded_mask = torch.tril(torch.gt(expanded_mask[:, :, None] * expanded_mask[:, None, :], 0) + ) * torch.eq(expanded_mask[:, :, None] - expanded_mask[:, None, :], 0) + else: + expanded_mask = attention_mask + bsz = inputs_embeds.size(0) + src_len, tgt_len = alibi_mask.size()[-2:] + expanded_mask = expanded_mask.unsqueeze(1).expand(bsz, 1, src_len, tgt_len).to(alibi_mask.dtype) + inverted_mask = 1.0 - expanded_mask + inverted_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(alibi_mask.dtype).min) + attention_mask = inverted_mask + alibi_mask.unsqueeze(0) + else: + attention_mask = alibi_mask + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class BaichuanForCausalLM(BaichuanPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = BaichuanModel(config) + self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + return tuple( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) + for layer_past in past_key_values + ) + + def quantize(self, bits: int): + try: + from .quantizer import QLinear + except ImportError: + raise ImportError( + f"Needs QLinear to run quantize." + ) + + for layer in self.model.layers: + layer.self_attn.W_pack = QLinear( + bits=bits, + weight=layer.self_attn.W_pack.weight, + bias = None, + ) + layer.self_attn.o_proj = QLinear( + bits=bits, + weight=layer.self_attn.o_proj.weight, + bias = None, + ) + layer.mlp.gate_proj = QLinear( + bits=bits, + weight=layer.mlp.gate_proj.weight, + bias = None, + ) + layer.mlp.down_proj = QLinear( + bits=bits, + weight=layer.mlp.down_proj.weight, + bias = None, + ) + layer.mlp.up_proj = QLinear( + bits=bits, + weight=layer.mlp.up_proj.weight, + bias = None, + ) + return self + + def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0, system_prompt=""): + max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens + max_input_tokens = self.config.model_max_length - max_new_tokens + max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens) + total_input, round_input = [], [] + for i, message in enumerate(messages): + if message['role'] == 'user': + if i == 0: + content_tokens = tokenizer.encode(system_prompt + "USER: " + message['content'] + " ASSISTANT: ") + else: + content_tokens = tokenizer.encode("USER: " + message['content'] + " ASSISTANT: ") + round_input += content_tokens + if total_input and len(total_input) + len(round_input) > max_input_tokens: + break + else: + total_input += round_input + if len(total_input) >= max_input_tokens: + break + else: + round_input = [] + elif message['role'] == 'assistant': + content_tokens = tokenizer.encode(message['content']) + round_input += content_tokens + [self.generation_config.eos_token_id] + else: + raise ValueError(f"message role not supported yet: {message['role']}") + total_input = total_input[-max_input_tokens:] # truncate left + # total_input.append(self.generation_config.eos_token_id) + total_input = torch.LongTensor([total_input]).to(self.device) + return total_input + + @torch.no_grad() + def chat(self, tokenizer, messages: List[dict], stream=False, system_prompt="", + generation_config: Optional[GenerationConfig]=None): + generation_config = generation_config or self.generation_config + input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens, system_prompt) + if stream: + from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig + self.__class__.generate = NewGenerationMixin.generate + self.__class__.sample_stream = NewGenerationMixin.sample_stream + stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + + def stream_generator(): + outputs = [] + for token in self.generate(input_ids, generation_config=stream_config): + outputs.append(token.item()) + yield tokenizer.decode(outputs, skip_special_tokens=True) + + return stream_generator() + else: + self.__class__.generate = PreTrainedModel.generate # disable stream + outputs = self.generate(input_ids, generation_config=generation_config) + response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) + return response diff --git a/pytorch_model-00001-of-00003.bin b/pytorch_model-00001-of-00003.bin new file mode 100644 index 0000000..1cff8ca --- /dev/null +++ b/pytorch_model-00001-of-00003.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:79792640f77f8b7b79b6c8e061f4211585bfe229991c5b946b29f13aa32d2ad4 +size 9972281188 diff --git a/pytorch_model-00002-of-00003.bin b/pytorch_model-00002-of-00003.bin new file mode 100644 index 0000000..c42aa59 --- /dev/null +++ b/pytorch_model-00002-of-00003.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2a5930165ebcadd847af0768fd9103e0a2da2e45eb374a2fce95eb6b275c68e0 +size 9947421360 diff --git a/pytorch_model-00003-of-00003.bin b/pytorch_model-00003-of-00003.bin new file mode 100644 index 0000000..7785283 --- /dev/null +++ b/pytorch_model-00003-of-00003.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a815409fcbbda9398b0bb5d554558ec7b32696fd7a35db4b3ffe857a413b6fc9 +size 6610200289 diff --git a/pytorch_model.bin.index.json b/pytorch_model.bin.index.json new file mode 100644 index 0000000..7afb49c --- /dev/null +++ b/pytorch_model.bin.index.json @@ -0,0 +1,290 @@ +{ + "metadata": { + "total_size": 26529802240 + }, + "weight_map": { + "lm_head.weight": "pytorch_model-00003-of-00003.bin", + "model.embed_tokens.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.14.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.14.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.14.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.14.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.15.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.15.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.15.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.15.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.15.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.15.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.15.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.16.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.17.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.18.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.19.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.2.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.20.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.20.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.20.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.20.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.20.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.20.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.20.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.21.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.22.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.23.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.3.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.30.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.30.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.30.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.30.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.30.self_attn.W_pack.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00003.bin", + "model.layers.31.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.31.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.31.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.31.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.31.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.31.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.31.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.32.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.33.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.34.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.35.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.36.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.37.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.38.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.input_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.mlp.down_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.mlp.gate_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.mlp.up_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.post_attention_layernorm.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.self_attn.W_pack.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.39.self_attn.o_proj.weight": "pytorch_model-00003-of-00003.bin", + "model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.4.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.self_attn.W_pack.weight": "pytorch_model-00001-of-00003.bin", + "model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00003.bin", + "model.norm.weight": "pytorch_model-00003-of-00003.bin" + } +} diff --git a/quantizer.py b/quantizer.py new file mode 100644 index 0000000..32d9831 --- /dev/null +++ b/quantizer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +import torch +from typing import List +import bz2 +import base64 +import ctypes +from transformers.utils import logging +logger = logging.get_logger(__name__) + +try: + from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up + + class Kernel: + def __init__(self, code: bytes, function_names: List[str]): + self.code = code + self._function_names = function_names + self._cmodule = LazyKernelCModule(self.code) + + for name in self._function_names: + setattr(self, name, KernelFunction(self._cmodule, name)) + quantization_code = "QlpoOTFBWSZTWX/mUzwAK6f///////////////////////////////7f////////////4C5duvi2D0Oj1ppVCJ2zQFYbnbsxmq20pAC7kEDb3Z3nWrextY9NZbavON7nveSRqszudmzAGGgkeh0Pewk881e3Tz13kW9YO7uA9AUUiAWLNW2HHWCE005Mdz3jHs1Ic7QNCQBNGgmE000DRNoGjUYmA0mEmJjIaI9JtT0JoaaMTaQ0aMjTTI1TzKMmETwyaJ6k8p4Ke1T0wk2aE0anpPSHppqNM1HqYzVGj0MpsTTUGpoCAAEyAAAmhpPSYowMk9U8mqb0mJtU8ETwCZT1DQ9R5R6htE9TTyRptQeoyHqA0B6g9T1AD1HpGQGgD1A0NPUAAAA0A1Mg00gmhKPU9E2SekHoJ5QHlNDEPUeoDEaBkAHqBoABoNABoAaGgBoAAAAAAA0AAAAAAAAEmoiIgmiD0maRip+qfpR+k9U/QKaZPUepiGeST1HqeU9TQ9JoANAMhoZPU0AAYnqaBoAANABoAAAADQGgAAADTQ0IAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASJEE0AJo0GkxGJoZNKeBoTCnpNNpU9knqn+ppmUnom1PKZqTaaTTwTTFPNJ6pj1BG0eoaMgwQGkYAGk2gjT0jBqaY0RoDeqZoNEYT1NpsA/+iBrt+OVIiCKqfH7N/e67XZ2Dx9tPHyWbW4gAENNTtyzk+/WdoU604SoXU0JgfqgQxVmzbfdmaFcVxQAYINDyjTKU1FCUUzUuqqptg4SBgwIAHYE4NwQOrbY1bOF26LUVuxYr3Hp4paZXaqKU1UmXO3K+IXn2hURrgAegAaTANS+QBclUN6tpvhn85+uTPCLxzj34YO8MIMg45eRAEy9IYbKxeZTRnTy6GpPLtVGWKKK6iuDLa9wjtSmUQREX6wHfE3JeTVZdoj4Hg/3cHlBdw4c4BdGvigzZsubPr3eTi2hs6tZz3J9zUVm8qH+FPwSx4Tdr6by/OA88iLHk34rWNt7fT7NwqqqqqqqrGMYxjFcdqvY2mXyh42c2ccxhtyvBHojjUlyAKRgbvAB6nhls1wGLTOrfGMBsqRXl9Bl3sOlvafSA7sDrmAQI+mw90af+bvJ8mwjP+RKtjobGNzbfl76iTHMiIIUf9oIoygqSG2NLn0Ys/mZ+hzufu7epmzbvP1t7S0Xo8TKK7q6G5MA8vTgBb7Bf/2kITSLsH7Xmfydz7ahAt4YJbBuAQJI+1M8DLJCQH+UPbv212QWIhcCKhBrR2eryfQYIiIhKE0WtbOQ7OwM7OxtURGbF28NBndi9ejVDVA3dne37uDdzrwINS+O/0AzQTCgUjfCAwkkKFMT4Kr0aV3DicVAelGBesGYoCRcLKq5iBFR6SzOzrAwFWDFVYU2XT1oFaRJk2JBDOwVk1LFZZfwY7tQBYMGdECFA1cLZAg0IlfCTCMgZ4afRQBNvXSuMORVUTxTLSTgMFoUtaGLIr524yIM+INSFFIOHQ4TG5NZbd3Su3Nu9raSLd/ueibSYpAL0D42ZkAtD0pnXrfTxYPBw+mAt1cKPCPmDNMCDYCBiQwmANVhdDjBwsdIKyfH1slCvWbJC4QO8SBxi6A+GEpDBN6UQnPaEvBqFk3TwChKSowEENpyAueDIFs6OxxLRmFSUFpjWgYpECgDgfVBJjhg4GGcI9CD0S3igCrdziS3ZoYHlQE+7AELdvbebTVsdRvrPHCgiAbSYzUN0z0SCshLjaUaREEREQQRHNKAgAS9o0kukdJx0ulaJk0kINzlUYN0wWXLLsmRgSG1BEJNh5sCuVtIybGlKUW29BziJUTpqcA8UCCLtOGU0hH17BYTERfPKhCAwxJqSSSMd+umawlsykXZiKHesslqlVDKEHPzFhIWwJHTfcYCGE9dQK9sKixjNifLkW1iLnyZo57BBx2jksXPYjcaA6Z6rlYTl9ocZHn2URKVXnY/Wsrc5l3aym6Uq7u9eu2szSbJgwhqPqfOR1JCCZl7/AehLVBSIXc9npUk8IDzrRCS9XKMeamSDmFxK6OQDhwNnxubbnQygQb4DEL6oD5qkkG6F03dyDAUJB/awNUoDCa3CmYy2QIsK0Z46BoX1N4kY8aGNFB8WZAfWvaHeUT4gYIjEsZBBARIFAk2jCTxAmpW03GtdW4WCN0bLJiiqY3ixmHAWRqqQKqgS2hlf8mwszkhUy3LDx3GLdo5AHGAgC4BogUAVgH4QM0AGAImwbS6gwANIep0rJIU3hBgaeKAEcnzfs+g/sJZnETvInDcAH5fE7azmr8EyIFx77caxbrDBC64CEU8wCqzAHPgkk4kiPREKYHn2HaoDBWCCrFBrhR+XpeNQkdbzCBHee2hW8EW373k/qd/PxGC2R+IO4vmNEAl1AE0l4bEvmnfd5/JYs5gl9XpgQIS7g/LAK7owBwgso9j0yEB9MRIBjqmkLdG5uED3tICA6PYXe4WItRawAenfJ0lCFupoGvajxuQC/5YQPnwFpgQBMNgBndpgVNJcyw+5vCJgHtWU0EDYk2HsvD8Qkg6ANAd8UQXGH/3X3gXgNDefHyaQ/wd93Xx87hWWtW0kPCQGR+KYiPeMQse27PdNLGwhlz8WJObSnEQyHJw1JmStJXTtIg0ZKEHrLZCXd1ljLGkkxtpsDofXUiBH0LLEM43kb2waJ26KZsJ9sBbxcAqzUgWxzogNFm4vSxjMR58r5Xm8H2+6ItGcNX2AK3GhDIMzSX3YyFsbNG0u0MxvZzGFv19k2E45tXrK+1OKUYRiH2OT2Fs7kqtxMDrANVp2nxreAZg02UaFEsuf6+urQi1PxvNOhuacrStndOnonV3e5Du+Xjp8mjhiHYPNexu7UKSbt0Gs2rPIVVVSFyQ7phtQ0ZOUySoyZA79muzuLBZaLAW20gZIeuJDacErguFE3e70svo0S0mRBMBu33rjqVrNEN9A5PHvOgukEPEgb0tYAMrvcvIXB5ydzJHXQ1n+t7BUI24oJtSCTAUet75rBpXL4ylQ4LGBpbQeQCiOku+8rq90o18ga4WEGBDhvHB0YYd/CDLIMdDh2cO/i/RppcEi3Zd+CCU8OdxAAiOgi5qeghJkUnO6YGZi5LEilo2WhSiEVsU2IK7unV2rXG61Q/LbUqGx72rn2Uzx/q/fzsCWUFCQyAA+XqfGVGvL1kml0MVpjJl1A9vYoYTSatnV1+z2czsdoc4QFWLILHn1S71/r3V1S/fJMgDlXX6DVv8+FeECNi1u8zf8K8r1Khq7twFu5xPfZJT+PLpYUZWgGNDG0Jlq4rsQy86u95xqTdO0TbSGBdDOUSyyGHQAmP5mgNfVvgeY2tPzlKbyrvnaZhgQ7aWeJjzbF4mjPlro1hYjmnWUshKxVsQ6pveK850taANOgIE/aJvr0IAC0g2H2d1agVwnBkAF1kl7IPZc8mBthvlYish4AqABgI9hw2cExRabO+8Xz31+enwlCxSbnfVFlqig3UKGBQiybpEBGQLIxuoUMVYLTt53sY+lPlxSAq9f3lfnVlFmiBFrOhAeAF/0/N6HI6/+rsQ2+D5U5fenadDmtFFgeZLLESwOgWWIlgWFo+uFROhke3lKQ4bf0mLH3XSOgtDGd73hfMwDM2aF7Lonl7AlbiPbV2zY2lvu1Vj7jzlmFYoKieH93wt3fLhBXgYUGJEjga5YWEVyE00qIYWXSKd0ZaZy+vuCQlhaz5ELs9n/pjuFAHpoDCMEEtseECQF+Rk58EyW3nzCdlyCeY5WPItdkDZ4egXmjfZTLSVT29ku6KCGxHbdTBD3z52SxkuXkpoaHyy3t25+JwX5zFdYawDASl7397IB2tunNbt2FygaTBIO5qrG0asQmxEVRGCn26UX6DewTmic/QqkLZjdCTqjQDGlxy4IODucyQlmE0zkwSkR02cZjZcA1MzMczZAf1hfPnZT1IGtWIJGOcpzgYwCGyiNtoxRkupRElCCAgWJcE4igRJEQogPHYVAVBAEYDBkUEBIOSMK3KJNwQllpqWZARLCgMM8TkQoHOSZTDbSrjS6QtkYsQSloWSmQ4BlMjEJuuWh0ERMIVRLbcNDDQalLRQiEoBIUKZaiQpZQ1KoooVlNtjVVGAsG6WkNS84MJcoYIgjBrKaODOaUZG6QUZlCUGKy25MUVYGMWC+95zG4FRE0iyDRISulc0GQJt6m5u8WSQD4NAiDAMD9y0Q4TBGAaAIGe6PfdX9zl9Xginufp+HmPiAGfY8ZoDAarMoQAD9kA2OUJQV3lBq86RzpT8nbXPtqxsvN4YTDyOQgGEarV4Tc5h1yv2Npz+65PJpxO/Tefe5S5U1n8asAC3AQIACrUA5XacxgALbHvUfi9ApR956Do3PCWymCzTo7JjufU9DsGcQWqAFwwZfDzR+m6436pzvncYkARkLKOxX23RuLsQeK067Y/Fq8tB7igBMvb836/03fkV4qZ5YY4pFxADLifQb2iaUAwjesDs8Nhx5vnIw3rZOyb9+jyaYazgr2vbSKuf82URMcyf+99L2sWJHqW/I0PfaMR0KsULcnf9Lx/fJFzattuUwcjv8vdJed+FY1s49FrvJMbRVa82imzbdgSpDhEtleDphWrjgzVu59jsXKG/3f88zolkjqRQUk+Xm8F72190OzfqwfT5XAYbvq8WBzq/B+4rLP8j5PDfiytkicVOAAJ6QOe+hWqqwgfq61qtJ7jrsz89u1dDqsK/9Wur9Po5K1vHsXseRHoyF+LoewZ3uHaanw5S9LCW9Gj8k3e5ObY3NfjabO0cbzotaAPB3XIg+av5zaHst8ijMqapTpVtdwy211QZINMi1UCIHnAB3ZLFDZQuraVlNALggow5ygAhEo9EDHUCSm8+Hhev7eTufm8onZ7pATIUwBEBBUUEPBw/zcrl+pwtDJe2XApoPk8CJjTqtqbv7DYwZWFs/M8EhDcYE8AK8A+GfX/aQkYgSLdftV0Id/5gf3lOuNNC0799E3uYYtpMg6yABaJz5en+HpUfveNBXeYA8Whj8TtZK60F8V863ndv3PwKagCzpXtfv1APjaUgxkGLtptiZPR9vldS2Bfy0pT3RXWJlLCCj+GpAz28S4v0YQrYE7We9WpbVXz7KVTWEtoXM/UPZhYnpzdeokWJdNHQ6JQLxp7bOfci50rBcdOdhOqmyeC7B2rL6rxd969Xxc9L4zMrsqZ0+DoaPeSn8Y5QMLTOLpdvz1qaOO5xT1xPjgKnhTYa5pzi5U+bDcHXzYdxpgAbbhf/e8aBprxka5aM2J3lYXBG5G/r7CunzcPyjz2o79z8eDKkMvdO9WixswXLu3TkpoYcV0465fwUxoxC6L9Zwc+QsLDfqipk3wMSSRkBPM8Bxrwt0Mjr4IWW9Tw+Kw23yTbUyYJqrgNaq7saBKAdzYXMQ6mkrfqt72Lk0YwiZmIKkXUgChISCZMMrwdnjWbJDoR5ZXGxxAX5uRBfHBOk6JS8VVVWd56zxf8v3uR0/zON57e6BDuqIcQDJ7H0q5BNPaWbExYw2Bj4tRM9kB+JfynyyEfR/7ZiPXRFLmwpGGjLF9G6/J65mkUZEaKrUdBZYUxFKqGJL4LAbEfZjLi4GYXhv+x3ZpHkC3YADdMsKeYmfKgtzUd+Y7dVngbdcEFGAL3VqaYfYAYMtY3YKIQumTVXUFTFQyU0bqIeMgV2WOcZFXICpoMvueYVy0mHAiaeyNg1p5/QmSbYgyb7WQdUPfY3QeKc0hewGB2z2vH9t+pvy7B6P21pG+wXCMQHZl30TJonLPhQg8nka+raw1OLPUVWvIidrloKjcLH6/YAwepAoWEykQ9Bw2+YU/N5dbXnsNcPbubOszstYSwQYATYulLN0AHAgwb5t+VfATV6uhICgRgDGUaoVNNLc9ZMMW5+qKVhOyoRMLzJolo17ACLDPes+aoyeD5aIZm46HHKV7KqGX1IGbYEEDaAh0Vj+43wIMep+e+gsP4UEgVjmMAWTPz2XZhQDA6/Vzbk0fK+v0+bNB12LRbfmsufKzRgw7Hp7b+J+N2LqWXdwWTvhQ2rIPjc2cgS2A4Ub7IflPitJFAPyFvbvHK+tXi0Zcbi6mO6HTaIydOeYDmSYUIACAZwJCEgueoJnU7W6WfGdWtl1TdD4WHQ8AgDnmNUD+2YrjxNum3+1R9B+XSiSGrVLcFrVC/Z9R7D8DslIGyMPXbJAFthAMNYs7OdlqPilZtnwtReItC2Ff5vD8mQHwayX/vh1LB+HwoefoZ6LWUKb7WH6D0FmEhEKgwAayAYsoKUCcPepjDQYfA2TMWHoiS1lspYmEi2HdFULic/ucQlrFCCwPxyDeITAUsiAUFggCtZuDuVPLvVtM4WCG6DlrLwBL1JAaQFWuf7/uHZ1WAHEBuz9BMrshS8OhZpwrmYpgUIFoauEJQxtrw2iu9bT1ZLik/F26jhZblz7739qomvexIWc5hKq/GfFAebrnq/23mGuisbZhiROtNdFBDwqCBc2zrTYMfhMPwIF0s37CzzvYKeLjIfQZ3D2N6o+FRgDOkDGFGjCDiy9cJBVMOBWJ1AjDIxTAz/LwSRYuyzhHyDiECf0P53hWshYcMslf0PC0tWfLlUztN1xTxhwgkAudx+IE+NuS3phgEhRBo5lXEG6KhGydUzSU2WphfuFy0VkjH2AIPddbJ679s70tkL1rBEEEEmFgwK5pRCB6ZC5EX7ZCkCTI1pQUDJAwhQoosjBZFAjelFmydnwH9j46Ei5DD9ZaOvgT54UpSh4mD7FR2rjbJjFFdyOauUAjNr/DYBQJkLsUsd2mAXDIMHOuu8ULJhkx21G0UL7fnlqIPfiwdblRpcEaxVjru+6bHpdvj38qAOr1rUACbHrKGDWLFjGCBGYoGREGZBh4aGauRARRTmJdfJBWYoCDdFrBtCgYo6H8NyRIvFfbeTFjxF9riIiIiJABkRljjGMYx1mizcSoJ9AAFqKHXgBBgYnYjs06fFb2fl/bceQ8TeN4h1jrKPd/Pbtl3dl3fnbu7u7u7u7u7u7u7u7u79ZxeoA2gbgjyqd70779v47Lsepzo6y18vJkhQMaDKDNhYbWPpJA6hsD3pzguE4gtOhzrtDoDA3oMbPVBY/3fi0DbkWt7GQwMw2BtpNpeKt+v6KytGxxqCQ8JoLCGKIALFxqwIOeI7fqckjnW8eHjcW3xehEp2SWhvmrtDDdoBSOn6jSjQCgLuhd+EBOwr3q9GbUewJDA4QvH+DpFwt+JbtP30yJTy10KFMLT8MmAGUKkqn3DQHSmTACxjEheIpDhGuZT/WrsHgP+ly7Bsto8UYb2bBvwPRV1O/WaEbmIEMEbQtfphLgUDADF7nayfXs1CXBxYOi1aG36B7rr5EX31tzoym2bTIWw0maxvM3Gs+KAOSMztimS4oGQokBRf5dGKNykDp8tH9chWc9k7/6I+SxG5cZSnx52CFhoDqaZ8wBethxjRVKaRfCZTeBpi6ZNdZFjROy9x6tdgMem0rtuH6wbAz9tKvlhJ0JUP1e+2xVgroJFw8tQxLPdwVnLVMDu+mmfk9b5mK3qMNwiMyBqFaajMIgCDBYUXbdKwwVVhoMXL5YLkI5FFviIkYQTNamuapRILAqCSAYSsIOOVAtAUUrDwBSthRBgyVAM1wBrIQhhTlJKQIwFnj+b+aXuJyerhwx7HxQLofddtH71c6UuefecFIrANhfgkaIt5KL4iV43tMeP17BD8D7Dl8+AQTGQfz/rp3JWOfDodJOcvDAquYl1QQiHknUmAQ3lYpRUtJEUowXnnJnOZjZzdINlj+y7lXBb2uPR6a2E5AC3S6dBaJxYl1qyRXwQ15QflVkAK8AmAwql/n4frTztb/XRXV9J3eXRfv0MuB1OShRrtbrfdudwKxsAYC+QHiNISbAQu46ffUU/Flrw68uJ5L+7p69JjfglHs5PSd0bjADZeFsIWCqy0kQ20m3CskYLPShb0aoDdHoJBUQVEirAUgeRTtUBwAa0INXTIBPMHp9AongtXzSfuWCFQfDtzRuYRVG3WIXUjEg7b2vBZKT4ESq2tTcMyGXlqZN+uJ3CaGHEJB/3Q6/xrGIGIxyzCG5tLlSXx61sy0Bra4IFaYrjF1zJj5JPK/SslbN65uYffnqtyIX9zren+rrSsXVVhq8VZ6DFpnBVlD48AoMeltsyGSZSpdUjR6bM9J+oHRVmhpp2HBv+N4PXeS76ctP4LOLvreBzzyCr2v1K7eBo+dr2gwZ2x9k6EpHd7pNRl6Pv+IgXtj4WmtlEUQxkzWOVcT6jcLrhax5PVvgurz9q7DtdWriVdnpnTlTrQqdvWN6ZNr4OdpMM/T5Gg8irLXS/YOgvhteS49VEj8+IfNiPOf8MfMkUw+lYehdNxKZnNbjIoJiqRY1KVGIOWpRtq4m6GCyiypZKKzWBQq5j8RYJE0NCiyjJmgUmDBi8BoJgMVJYXMF4aGDL2XQ4HDKaRGaGhctNBrShK0bSU1BpFoRaTkkCCUWaDCx1MUXQCaGRhgoqhCHmzrFyZwUFG27KVdmNgbChCbZNAMghZRoXKM0CMEXaUTZswtBpLoCkxONrpa2wL0qn0mw2eV0yXs1MGgGSTcAo/GELIbpoe+8gKSqpV0ZIoIa4UCcM2EdVikuAPuDlU89YsXrb9Zb+Pr/F8NexBBbEwTQs9HmsQGBYPoK6bZKDvj9yyALrlOaMbLpKxRM+njvB4id/1Y1WPm3K2A0BVSlgWJNjYxne6JZ8mZfv7w1Nm3/GFOiwonktduZaRH2loGGhNBUlQiHENkybM8pBim0iaXcpE8dAF4GodlriMfOGH6hHY20huVvSlLDBRKHQ4Y3SyKrmCcy7ZZMDyNqVWWwpS+RHQaYnmEURGCKmQc8ARghpQffVMwK2vz6V97O+59X5foz4jUfN33Z49cKeKObXDE1rNvV2QaDOLOi+R0fl+RM8jVQ7QgNiDMzMgUCLlYO71Vn7X7vF0UcSZX1pu+s+xC4MZXNQCl0/rb68aAY3rOJ/jaw7EOYIIlln6V+oFpwZLOUjUVHfe6pdjXgAqsD219Ri16edZ03hcjePW71C29Wy0nTw5YIfs/Y9sNovb+v8vA1P7beB5bQmvEv59b+BnUs8yqQ5/cLKV0EZRMOGHmpsMrPidWDXTyP3fuO+w/9+kbujeEbdg+n4WXJQBn1kL3Py/M1JnkOu70oufaRPG6bsd6SUhq1TALBZAhKpoyMIvkQGRAzJD+udGR9e+WlVzjlJeqELl+D2smL4vG6BUFpiKHDwqftFBbX+9VV338vNg+5kL11bd1yrZaYZrGW36mrUIRi/MVgrNNITCj++zpFSOrRLE+Prlr3mYOP1TtXvtpOwLP5Kmt+3zZvXSsOXW+ix6mXS5mb1MnTvW0u8yHF356RuzXUyeGiLTe+IvXvKmJrEymIxQT9QMSU8WTHgnJi1BgP/WoqICgO21v9Hiw8IaXJY1619oEj/3cb/7R/nddLm6VA5xoN0t3XY6Hiep4VGnzs/Od0hj8f39YuAC5HvfwvWuOeV5fz820AAGglyrLFDjUrv//M/fwNdsEvj0MrTXrV8vLZfMvKMAzJ0/Sda/28/N0QniGmKhoagYUYMGp8IFDrOoi40L48r/SLxfSSDw9TM4P4vUeHE+iTmchyj7Vmwp7m7dejVSNZx+2Is5jzuf+HmHr2aml3fWein0wnXnxne72A86Cc3hrzXgbfc7lNQiJuGMljn2Y8pgXjrTczIy1teeafy8Tz8vmzBWAAFXfojX/x4Kv/YFNprgURbUBytnsI9/0WeuKmZjrWcumUGQgRDIEUsAwZkQMwPsGTJjpTEw7YAwCs7Oxn2XE+hexXn+z/L7HC65bJhCR3SxMdHngfkGgqJnhYzTGjw9StB6E4VI6SgkdNEdesLFW0cgxeYq7YABEPlMspZSBtZDQYZMvK9Cbu/UzXvja7MLlO4BfVYkMH5dwAfQ3u9WEkCoveLyp86iGmleemxREJQ0NoFyWpMxsNQCuuLGCdP703Uv1a3JeT7vfpxp8J+o/ft+J70dz7dV+1QEcxyT6REE6vsl2+0Yd8ayjKWBg2j8pRTeGhVxiYZDc6/YatrSzsw56wbWzGkp3FLpa8+60pan1LSvb+rcfyjTyEM7yC5BVyZL4r0qVCMZRc+AMHxlyZMP5QQiFATNqpVSdy8i66S7oSIl4APKPMzOTus/KeI8rrY6qBkuRSWT0y7LGvNz4KBjigkR4r0v9/bluxFmxePnvZRhpjgezOiX6bPa5LZkzsaLjmf6NzPP1ZfH9p7j4MsQL0YMETXjeb/5lAYcJWU1RECXppb+33HdO5Etl4xLXPxfV8cGZ43FFYXKVoMFQHssoAIzyiClcZR8W8vqiACqmcw8DAwzLM+FeLFaAYRiJ1DFqKh2Fcs+6Zd6erYKNpF09oZhCZNX4DO1OL94JPGTBXIPMmPjmDb0GlmwFaWG2CUqSjhc20YNd6Wwzu52BklGYvDcMnERi4Yh1wqwcOlqiLatNe4rj8FcXDxqMSsgYP5/FnSoTq2VVKttXQ3Gxq0q0Shp+qCbIAeWxu1Ynpd88H5zJfn/V+v+5/N7nyR7Q+n02bmML7aF1Sg+a32Ud2eQx2a8dQqTABf2SKJgvKADJgAJV8Rd0Wt1oIVj9nr/ZfC7fkbdqnS9R4eIbqH2HVNjOYdggfFeSAHKIkaC5R2rzEzdxs7dDCzizsiB7OluhJplyBBWKXPmS0tsUNnNs2D8zfW/QTSAr0EcsnQ/YPZBD4D0rHa3rkC2DHq+G97XfliTeY63fQow3RQpyKsCFgdUC2sF7aep4TmSDjlnDDpfIUJ3Ne7AMT4D7xpuM+j1hXBxYcyIpO3bvLubMhwY3Lrr6KfLP4PF0tpDjMOew5rBbSSUJPAfRMkDCSBum/B7S97oYaYZS56rtu79Vh408mfXcm6HcL0Qe7fRiqav0GhPcuxMpZIm/WHpICgBUirY8aK56MaW53+L/x+BbXNrjaySqntSLsoHFEiExu5hX7+yaqu7Ss2LrWVpPp9L8fuVDJdVcPqIQRFv/gWlUadkCUYMxFQf26Nlq3czS1/zwLAGILGRazcevp3q9/0O/YUWwXKvQTQghgHliLIIbcY0XxVr/9oV2++gsQ57NkRK084MjYapPJJ6Gd7WONsJRq6iIJo0GH/kO9e74wvERAiMW7UqLI+2obG59Xcazzvdk2UIhBDN4V/KqrwHJ9EpMftxjsugftMee96M9+G1DfnomWt7OmvNC5TP5/Fa50GNfJjieHFJ0mwlIothDYzg3BQyahykpudGZEmgiK9ViiKhI9ypBUuKuau8PitJWe1r0kVIrV4VRDTDa74vSvBytKDcNCzJ66Oq5G+hTTGgbpBMS6pJTOmrIjb0m9HsPvrI3rQhSkRYc1aEmn4+CFS9MpIxTpLccqtp+dpwTDqQfFDvleEeOfwGuSJEiR4QBtGkWjWrKysrJEiRI3Pd252xBk1NTBRRRZZZZZZZZZe4EJvbjqWGaaZgEypipYBc9da7d615Ozv+0TPBMoiPZt+OB7H2evtWBqyXzg9jgyNarCYQHxeABDu8KyT59xFO4fpXed3nMVTnQhwffnGz0DpW+c5RkbdjYgCQgDV6Sk3OZyVhq5u3M66CH4jQq6byDLwIv8D7ipARoPE7/rm7y2+93QALi1QT9F/QCxMDOQkHeUdC+o3NN9GXve/W1Ua/wcVgmxFD1YTuKB+xQIiSdMyXLjSbjWwNfsJH8DqADRWZHIyjHLolbAN4CAMrT3YQqcfwcVf9TtpcgPfzwWRN7XWJzrS1KzOVWXccRQ+9TusY64JEtzfyHJnKixBwcbgCBAgQiIiIiiqp3Pje3Y4/hFGgiIiqrTGMYxtsZSR3dlixYyrLVZTH79fh8yNTc4ezofRU9vjHOIATEYEQNb4IG7bzkD59jIzRNInn9c62cuu1ZkYpfHu7uokt8nd1Hc6ApKjEt2qqbEG2l6oUPERCkrFLjmUay3EPnj2vUe43MqIYdrm3PZT7WrLfnw7y9is1SEtuI3OsO3EW80l8imWVq1Yje2a7qnbRVNK7eZSUzwnE6j9CLm24oqbZ35UTokBKroRjwJNyCBEACLMRjnOy84O5zJREd0g8Xa+y0W7O3tcCI+46EvAjDUyqYnOCQAfEhYjlWVo9HFVl0Fk1g6rWywYXLyW9gmyJHKcFdans6g078Q9ryUjaXacP7/PvwauCguS3VK61FsSTIa5RZd+GJqurSiskfDyz7d0Bd7WxYHfJfTrpTamo87sRYMCEdyYaUdCzhu3027ABTtQCAnwKi9q3KK/rIpk6zEjGHEvADnOwuJ1nOvPr8XZNswFPZ07G/LauwBMG1tOWNT76s7Jw1OxxW1BImaJT6XUIQ/1VPRP6UZLBjAVwit2h7xS6TLbCUnzPvqOrOfrbFh/ZAFnP7jW/zIMkMNMUk5C20iKshen2HLTcv3ge8jBXRbUso7c88qlYXXozqDXWcHg21XXWzupu9YmNN2aY8W/tJ3ru1cs4YtK5b/YBitp4WYoOvZCpCIC0Ju2+xw3MABgLVFBetW9KA2pqTQMLlkKFfMNANN6+JBLD7W6/i0AiMi2fIgslxtlD+bdgBbDk1FxvsbR+npU23xUVtnBjvadzYRwqwnvWSPbrgxgFM01Y2yuGIJh4HBXDlmKSUokWxg39HUAD4u4+D8ivAiXNQkqnkKxTsDkVM+u/s6rx/w/VPZ1yL9nnzJm2YZ9Wl+9izPDiRnfzWU5Eo5duybQnktKu3b+J3pVuuBmmnebBXfiZtkpUjLRKvtuhD3GDAd3t8lPpMQgVQmkICwxxqhUhLQMPWxbwjlswPn5rmN8Fi0j25H0DYQMgIsU4+OvNxfxINfZR+ndisEVJrn6M1cgs+qsqW2AYv5gIBUG2nAI2sRJdPp0pkIFsJQ9DC0Exajuxg+5pGLShRHi9wPxlNGkITynkwYgPc5Bjm1ceZiqsTuXbr2ZrcqBszMKehW3A7cYHig2nqO46ef4275H+NjUxZ7Yxj0XWdJ+CBStOyj3EqZrP6f8049HRTOibY6aHBkysu7Zy/0S6gyH3v1st5NJVth4dqmwuarDr5z62e9OpPUqH6te3WRJmOs5XNggNsBgGGgo4SSlh/wYAXsqj3aHIiODcmQbAbQltCKcIoU5klptJHQ0l2P4Tgjad8WBWp9XyPm/j3QYeU5tV+GSJ4bCaYcK2PA4Spq7rr4bGK2La8fhcB+ZpbeVZdDoKcxwCBZQgvQmADvnSmoonhrOe7esVg+7JS5aUYwMCekjlC6YlQHUxfh1evKIB8OGrutYZ4YX41h6Jq6hHuvnBsJnjhYHY81i95iJiJTU6/T7VS3gB1qH0ACm35YBe58z7ceWShP5goYAvCcHOTphatcimJSi7e8cPtVNlLBeanev47WzlgmaIlrfg8PQALIwuyc+Ce7PTEdI6IMaL62wH5dzYaANEsRgmxYif+uWKupAwqrJ4eXO3BFsHrOiYQRSnB5GwA01qir3ZWamHuBtKIrzLS3by/XYFMY2AJEnhaR7ycHZFV8q2AKplu2J5dsQ24LL0qZisABXaOzHlwBFOQv0vOYWldhDsVt5f3Y4pEAsNwPQChB5QmJB9EYeqbx1Mx3plDVGMY02NMYxjG228wkHXLQBuctwIzDl0DNb2d3Zr2eV57mni8HxuT3pPieEQB9MdPlRq2ASoAJ5D34BKD2+jwhMSM3k9e3pXf6aOC4LK2IgIYJ4xQMEhhPzy+0BRQRAMTrG+uVq2FlPAAWvayCMW6HdOctiAZvYzmADuOlcPkF5QWJAaMRsb5I0Onl1kWwDFstny1tu3cPUt/f34gagGAiIG0z+LwJMwuBjAAO0oXQ+j2OhzkkDWu/H1iOt9LZS2d9xud3NjEIOUBcEGiLbYAIhuk6kG3QiZ7Vx448qOR0823ux6gaDAo/m7VGENCDY55QyihE8PY2c3FAOq0eB5VrR2rVOD8Pk54g10gYFruoShyCA600IlGADNkNWFwSUq26fo1MfJozZb8ivAWwKtUCnsIy1VVc6gilxgZXuOpIn5NqpQ4t1rnTCc+zVGQ8dLhuE4NDF7wA+sXOKNy3yzCWV69Yg3C0AUAEgSDmXcoIVu+dFgcdgdaEhA+iWl1AC/p9ikx5Lmxupjb3zEXwOwav5pXeGFu/i1uQdRtu2CBnIi7j7vIXJ+0+JkKDrtuikSysRrZuAkIPGGIXa2KOvhm+tzKtliPPcIGhgwSePz0mjUO5L7zzmcZMHoTM00cmhmTJXLHXXVL0wJj4s1MzRHFFiZHJnI5xbqYKxtqajjQWsuDBeCnFPf3bjFXVC0XXPfJZnZvcUOvlJ5TfVc9np7+YKcF8Pr101cACqIsDSQrhevDLMRutoELrdyRd4yc4EBhnWVGVUo4LsLWMYimrKjHNShUXacMGzWd1rteL0aqM9Wd9vU8jWwVgD0CDq0ypYdiu5V1wDsEFjDwLXJ6pe46MvOgOONLlAwPQwQmNUX+2AdnCCSJdjtaAefC8AY7bANwtVktFIQWVBQ95dSmjz8VnKFc5xsXgOQl3TQHPvghbPELlyOR3/IjaKbR4oXeqF4EjmEktr0SghMIXS60jhlBQIfEIJnyehMgiETwigxDpiHows1RgnEalhk2EzYwRLmRwajUmIaCFSzCXWStGaaJgaMaFOidK9crUyN2ZuYmDCMxbjQvOVrOaRTDXXVeCjhum+v9g5xzwDtdCQ0k+kA7IgR/IB4DE2B6gEv0Dv6l1YUCwQl4cgIQLDp7+vyQ0Ua6AogR/cA0tRku3sTszsBxdKvDwb0HSuapgWAtRzrmM+GLTWgg8og8IOyt6ZvFLTvQ6TdIU4jAZ9qJLorPPx8ToMIzve9bunjAzUZTwZAuejvlIVhEDGHZ43P+c2vnuH0s6xLjGN5IxE0xoW1w0CkEhDEzZIIIKKKJQkS+HFVRzrtPvD4ASgRgCszCJ7egCW+IZ1AZrFQIbETEL8gYz6s0SYtQwYi6Qsmdq1IQVCNcDQEDNHPNnw9vKmss525+DcQrAWHAQARzWHlAGPJFvL0qtVnM2mDSOxfDb56lUUmGI9SmNfCBxBRJtxwA+2eJCOmpSpXLFbYv8diZyMpTv2LEbyMNcTJr20IxsYzUrvRbyu5dvYHUZsRs8gfCLXUEVYi8a2a9PXF+ZtLPx0ZOLRblX8XTa0QJJSoa+VKRIKD5RCmFKYOIiBoFAUCXYIXCCWZKNExSIoiMUmCpS01EkRLAsoE0NCxCz8oQK0iCYNZrgS0sWA4zJgpKMgxYZxIN0k6OoboxHmMgmKyNy3rUrA2BW11g0yU50ArBdUNYm7rW6l+FmQDmsfUcr8Nxpt6ME1pzmPW2YuvyqQA1FEqGKaOFgPS4YwF0qjqJ96aNghQyxO4ETMPCpx6cPhE1xsRksh7qapVjAG7QQVa6blYCqhJolWKylASeNpfutZRkWEfehrAM1hps1M6VN9y+8pnOeOL3eSrvGKkr3kEDbExtsYADtYMAhLoFzWdZo6F3T89cLurlkYDQ8iWVgjINJHQatNc/BZZPPYhX7J3dX5zJTnZ1pJIV4y+k2MF25BTUhIvz2okmED6ax7KgYdJtMkMMjHiBpMVmJIippQbqyHkJreoQDGrZe8QH4qNpIBqEHFpVTrJVwkLCu5ds3+pbccosPAGFjP4J0AB15EXRr4rcAbXmibqr2600yb4dM8VbMHACFOCBZhZIxpWCMkDUZIBUQoKpooWCkAnBzOK5na/LqSSLTATYIaabQCteZkFlqs0bDPpuWAcNiRn6GWSnwrsatNVFIK0+WUGVX3p1UghXmamW9amFzoPHfP2Z3WLhW9ZEaq0DQiqOJyRC17MYwQA84eUDjyR/GOBNpNoO1pV6NwwsBZoAgBWz+M+YS5GC+Su1IEB0A5in0LwPQxXq7joeDPBdd3DzF6z96RTojxR29u8vE3GnO6jAa0MBmCuoxyYl/SDsbSpYIlMINttOUZndGWJ2JgBs8s7bw1GhnALOxFBnZayRRjt4bSvH+Ma9WNZSaKBoUDtDEQNIMt5XAZJIvEFZSahWUgL7ADIBAjZYJVAK8NHljSCRbLZdxbuCkFfrZVirL+GkBWYaJFCoglTaEWtiguhCVZNjj+c9eMUMbOVJQmcHOmKmRIKboAMkAbohUflNANgubKuhTXDGSlSKY0PetmdL+7bQoIJCVRY+osfasgH1NADQYBBoYd+dccoSIhapDyYkRkhkYGAZDWCMlJReDHnRJZKAxUYiJmPGYriVoGAkdW2QI785BQQakRBFiFEknMOMGpw8jj8a7sLaWrGrZ5gDnB2Ys6AFHfczh5BvVw8R6n1P4QHEbDeIf/i7kinChIP/Mpng=" + kernels = Kernel( + bz2.decompress(base64.b64decode(quantization_code)), + [ + "int4_to_fp16", + "fp16_to_int4", + "int8_to_fp16", + "fp16_to_int8", + "int4_to_bf16", + "bf16_to_int4", + "int8_to_bf16", + "bf16_to_int8", + ], + ) +except Exception as exception: + kernels = None + logger.warning("Failed to load kernels:" + str(exception)) + +def quant4(weight: torch.Tensor, scale: torch.Tensor): + stream = torch.cuda.current_stream() + num_row = weight.size(0) + num_chan_fp16 = weight.size(1) + # 4bit + num_chan_int = num_chan_fp16 // 8 + qweight = torch.zeros((num_row, num_chan_int), dtype=torch.int32, device=weight.device) + intweight = torch.empty(num_row, num_chan_fp16, dtype = torch.int32) + intweight = torch.clip(torch.round(weight.to(scale.dtype) / scale[:, None]),-16, 15).to(dtype=torch.int32) + + for j in range(num_chan_int): + qweight[:, j] = ((intweight[:, j*8+7] & 0x0f) << 28) \ + | ((intweight[:, j*8+6] & 0x0f) << 24) \ + | ((intweight[:, j*8+5] & 0x0f) << 20) \ + | ((intweight[:, j*8+4] & 0x0f) << 16) \ + | ((intweight[:, j*8+3] & 0x0f) << 12) \ + | ((intweight[:, j*8+2] & 0x0f) << 8) \ + | ((intweight[:, j*8+1] & 0x0f) << 4) \ + | ((intweight[:, j*8] & 0x0f)) + return qweight + +def dequant4(qweight: torch.Tensor, scale: torch.Tensor, input: torch.Tensor): + stream = torch.cuda.current_stream() + num_row = qweight.size(0) + num_chan_int = qweight.size(1) + # 4bit + num_chan_fp16 = num_chan_int * 8 + + out = torch.empty((num_row, num_chan_fp16), dtype=input.dtype, device=qweight.device) + + blockDim = (128, 1, 1) + gridDim = ((num_chan_int + blockDim[0] - 1) // blockDim[0], num_row, 1) + if input.dtype == torch.bfloat16: + kernels.int4_to_bf16( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(out.data_ptr()), ctypes.c_void_p(qweight.data_ptr()), + ctypes.c_void_p(scale.data_ptr()), ctypes.c_int32(num_row), ctypes.c_int32(num_chan_int), ctypes.c_int32(num_chan_fp16)], + ) + elif input.dtype == torch.float16: + kernels.int4_to_fp16( + gridDim, + blockDim, + 0, + stream, + [ctypes.c_void_p(out.data_ptr()), ctypes.c_void_p(qweight.data_ptr()), + ctypes.c_void_p(scale.data_ptr()), ctypes.c_int32(num_row), ctypes.c_int32(num_chan_int), ctypes.c_int32(num_chan_fp16)], + ) + return out + +class QLinear(torch.nn.Module): + def __init__(self, bits: int, weight: torch.Tensor, bias=None): + super().__init__() + self.quant_bits = bits + self.scale = weight.abs().max(dim=-1).values / ((2 ** (bits - 1)) - 1) + self.scale = self.scale.to(torch.float32) + if self.quant_bits == 4: + self.weight = quant4(weight, self.scale) + elif self.quant_bits == 8: + self.weight = torch.round(weight.to(self.scale.dtype) / self.scale[:, None]).to(torch.int8) + if self.quant_bits == 8: + self.weight = self.weight.T + self.bias = None + + def forward(self, input): + if self.quant_bits == 4: + assert(input.dtype == torch.bfloat16 or input.dtype == torch.float16) + + if self.weight.device != input.device: + self.weight = self.weight.to(input.device) + self.scale = self.scale.to(input.device) + + if self.quant_bits == 4: + self.scale = self.scale.to(input.dtype) + rweight = dequant4(self.weight, self.scale, input).T + output = torch.matmul(input, rweight) + elif self.quant_bits == 8: + rweight = self.weight.to(input.dtype) * self.scale.to(input.dtype) + output = torch.matmul(input, rweight) + if self.bias is not None: + output = output + self.bias + return output diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 0000000..4835406 --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1,30 @@ +{ + "bos_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + }, + "pad_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + }, + "unk_token": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + } +} diff --git a/tokenization_baichuan.py b/tokenization_baichuan.py new file mode 100644 index 0000000..0bec645 --- /dev/null +++ b/tokenization_baichuan.py @@ -0,0 +1,233 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": {}, + "tokenizer_file": {}, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + + +class BaichuanTokenizer(PreTrainedTokenizer): + """ + Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + diff --git a/tokenizer.model b/tokenizer.model new file mode 100644 index 0000000..0d8c8bb --- /dev/null +++ b/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f7d1ab69d25c74644af5c5e4dcd1cc6e96d33783dbd257b6bdea55b643c72813 +size 1136765 diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000..e441aae --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,46 @@ +{ + "add_bos_token": false, + "add_eos_token": false, + "auto_map": { + "AutoTokenizer": [ + "tokenization_baichuan.BaichuanTokenizer", + null + ] + }, + "bos_token": { + "__type": "AddedToken", + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + }, + "clean_up_tokenization_spaces": false, + "eos_token": { + "__type": "AddedToken", + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + }, + "model_max_length": 4096, + "pad_token": { + "__type": "AddedToken", + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + }, + "sp_model_kwargs": {}, + "tokenizer_class": "BaichuanTokenizer", + "unk_token": { + "__type": "AddedToken", + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": true + } +}