###### Clone with HTTP

```
git clone https://www.modelscope.cn/Nobitaxi/InternLM2-chat-7B-SQL.git
```

# Introduction to InternLM2-chat-7B-SQL

This is a Text-to-SQL model, obtained by fine-tuning InternLM2-chat-7B for 3 epochs on 195,297 examples.

## Model Description

The model generates the corresponding SQL query from a natural-language question and the database table schema.
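
## Training Data

The training data comes from a dataset processed by DB-GPT and open-sourced on Hugging Face. After filtering out multi-turn conversations and normalizing the format, 195,297 examples remained.

The processed format is as follows: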

```
[
    {
        "question": "which states border arizona",
        "context": "CREATE TABLE mountain (mountain_name, mountain_altitude, state_name, country_name); CREATE TABLE city (city_name, state_name, population, country_name); CREATE TABLE road (road_name, state_name); CREATE TABLE border_info (state_name, border); CREATE TABLE river (river_name, length, traverse, country_name); CREATE TABLE state (state_name, capital, population, area, country_name, density); CREATE TABLE highlow (state_name, highest_point, highest_elevation, lowest_point, lowest_elevation); CREATE TABLE lake (lake_name, area, state_name, country_name)",
        "answer": "SELECT border FROM border_info WHERE state_name = 'arizona'"
    },
    ...
    {}
]
```
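
A quick sanity check of records in this format can be sketched as follows. This is a minimal sketch and not part of the original data pipeline: the helper `validate_records` is hypothetical, the `context` in the sample is abbreviated to a single table, and the second record is deliberately malformed for illustration.

```python
import json

# Field names taken from the sample record above.
REQUIRED_KEYS = ("question", "context", "answer")


def validate_records(records):
    """Return (index, problem) pairs; an empty list means every record is well-formed."""
    problems = []
    for i, record in enumerate(records):
        if not isinstance(record, dict):
            problems.append((i, "record is not a JSON object"))
            continue
        for key in REQUIRED_KEYS:
            value = record.get(key)
            # Each field must be a non-empty string.
            if not isinstance(value, str) or not value.strip():
                problems.append((i, f"missing or empty field: {key}"))
    return problems


# Hypothetical two-record sample: the first is valid, the second is not.
sample = json.loads("""
[
    {"question": "which states border arizona",
     "context": "CREATE TABLE border_info (state_name, border)",
     "answer": "SELECT border FROM border_info WHERE state_name = 'arizona'"},
    {"question": "", "context": "CREATE TABLE city (city_name, state_name)"}
]
""")

print(validate_records(sample))
```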

## Training Procedure

This project uses xtuner 0.1.15 to fine-tune internlm2-chat-7b ([model page](https://www.modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-chat-7b/summary)).

1. Train with `xtuner train`:

```
xtuner train ${YOUR_CONFIG} --deepspeed deepspeed_zero2
```

2. Convert the trained .pth checkpoint into a Hugging Face model, **i.e., generate the adapter files**:

```
export MKL_SERVICE_FORCE_INTEL=1
xtuner convert pth_to_hf ${YOUR_CONFIG} ${PTH} ${ADAPTER_PATH}
```

3. Merge the Hugging Face adapter into the base model used for training:

```
xtuner convert merge ${BASE_LLM_PATH} ${ADAPTER_PATH} ${SAVE_PATH} --max-shard-size 2GB
```
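
The three steps above could be chained from a small Python helper, sketched below. The function `build_pipeline` and every path in the example are hypothetical and only illustrate how the placeholders line up. The commands and flags are the ones listed above. The sketch prints the command lines without executing them.

```python
import shlex


def build_pipeline(config, pth, adapter_dir, base_model, save_dir, max_shard_size="2GB"):
    """Return the three xtuner command lines, in order, as argument lists."""
    return [
        # Step 1: fine-tune with DeepSpeed ZeRO-2.
        ["xtuner", "train", config, "--deepspeed", "deepspeed_zero2"],
        # Step 2: convert the .pth checkpoint into Hugging Face adapter files.
        ["xtuner", "convert", "pth_to_hf", config, pth, adapter_dir],
        # Step 3: merge the adapter into the base model.
        ["xtuner", "convert", "merge", base_model, adapter_dir, save_dir,
         "--max-shard-size", max_shard_size],
    ]


# Hypothetical paths, for illustration only.
commands = build_pipeline(
    config="my_config.py",
    pth="work_dirs/my_config/epoch_3.pth",
    adapter_dir="hf_adapter",
    base_model="internlm2-chat-7b",
    save_dir="merged_model",
)

for cmd in commands:
    # To run a step instead of printing it, pass cmd to subprocess.run(cmd, check=True),
    # with MKL_SERVICE_FORCE_INTEL=1 set in the environment for step 2.
    print(shlex.join(cmd))
```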

## Usage

Write a simple `cli_demo.py` script to try the model:

```
import torch
from modelscope import AutoTokenizer, AutoModelForCausalLM

model_name_or_path = "Nobitaxi/InternLM2-chat-7B-SQL"

# trust_remote_code lets the loader use the custom model code shipped in the repository
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map='auto')
model = model.eval()

system_prompt = """If you are an expert in SQL, please generate a good SQL Query for Question based on the CREATE TABLE statement."""

# The system prompt is passed as the first (query, response) pair of the history
messages = [(system_prompt, '')]

print("=============Welcome to InternLM2-chat-7b-sql chatbot, type 'exit' to exit.=============")

while True:
    input_text = input("User >>> ")
    # Trim surrounding whitespace so that "exit " still ends the session
    input_text = input_text.strip()
    if input_text == "exit":
        break
    response, history = model.chat(tokenizer, input_text, history=messages)
    messages.append((input_text, response))
    print(f"robot >>> {response}")
```
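
The demo passes whatever the user types straight to `model.chat`, so each turn has to carry both the question and the table schema. One way to assemble such a turn is sketched below. The helper `build_input` and its layout (schema first, then optional notes, then the question) are assumptions for illustration, not the prompt format used during training. Only the `; `-separated CREATE TABLE statements mirror the `context` field of the training data.

```python
def build_input(question, create_table_statements, extra_context=""):
    """Join the schema, optional notes, and the question into one user turn.

    The layout is an assumption for illustration, not the training format.
    """
    # Mirror the training data: CREATE TABLE statements separated by "; ".
    schema = "; ".join(s.strip().rstrip(";") for s in create_table_statements)
    parts = [schema]
    if extra_context.strip():
        parts.append(extra_context.strip())
    parts.append(f"Question: {question.strip()}")
    return "\n".join(parts)


# Example turn that could be pasted at the "User >>>" prompt of the demo.
example = build_input(
    "which states border arizona",
    ["CREATE TABLE border_info (state_name, border);",
     "CREATE TABLE state (state_name, capital, population)"],
)
print(example)
```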
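
**Note: the input must include both the question and the database schema. In the training data, the schema is expressed as CREATE TABLE statements. When trying the model, you can add extra context to improve the accuracy of the generated SQL.**

---

**More content will be added later.**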