138 lines
3.6 KiB
Markdown
138 lines
3.6 KiB
Markdown
|
|
---
|
||
|
|
license: MIT
|
||
|
|
|
||
|
|
frameworks:
|
||
|
|
- Pytorch
|
||
|
|
|
||
|
|
tasks:
|
||
|
|
- text-generation
|
||
|
|
|
||
|
|
model-type:
|
||
|
|
- qwen2
|
||
|
|
|
||
|
|
domain:
|
||
|
|
- nlp
|
||
|
|
|
||
|
|
language:
|
||
|
|
- zh
- en
|
||
|
|
|
||
|
|
tags:
|
||
|
|
- LoRA
|
||
|
|
- GRPO
|
||
|
|
- NL2SQL
|
||
|
|
|
||
|
|
base_model:
|
||
|
|
- Qwen/Qwen2.5-Coder-3B-Instruct
|
||
|
|
|
||
|
|
base_model_relation:
|
||
|
|
- finetune
|
||
|
|
|
||
|
|
datasets:
|
||
|
|
- ruohuaw/sql-cot
|
||
|
|
---
|
||
|
|
notes
|
||
|
|
- We recommend turning off beam search and setting temperature = 0.1 for best performance.
|
||
|
|
|
||
|
|
dataset:
|
||
|
|
- sql-cot: https://modelscope.cn/datasets/ruohuaw/sql-cot
|
||
|
|
|
||
|
|
training logs:
|
||
|
|
- https://wandb.ai/wangruohua1999econ-none/huggingface/runs/ygyz68bd?nw=nwuserwangruohua1999econ
|
||
|
|
|
||
|
|
usage:
|
||
|
|
```python
|
||
|
|
import torch
|
||
|
|
from modelscope import snapshot_download
|
||
|
|
from modelscope import AutoTokenizer, AutoModelForCausalLM
|
||
|
|
model_dir = snapshot_download('ruohuaw/deepquery-3b-sft')
|
||
|
|
|
||
|
|
# 系统提示设置
|
||
|
|
sys_prompt = """You are DeepQuery, a data science expert. Below, you are presented with a database schema, a question and a hint.Your task is to read the schema with annotations of the columns, understand the question and the hint, and generate a valid SQL query to answer the question. You should reason step by step, and includes your reasonings between <think> and </think>."""
|
||
|
|
|
||
|
|
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||
|
|
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
||
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||
|
|
model = model.to(device)
|
||
|
|
model.eval()
|
||
|
|
def generate_response(input_prompt, temperature=0.1, max_length=512 * 4, model=model, tokenizer=tokenizer):
    """Generate the model's completion for an already-formatted prompt.

    Args:
        input_prompt: Prompt text, e.g. the string returned by
            ``tokenizer.apply_chat_template(..., tokenize=False)``.
        temperature: Sampling temperature. A positive value enables sampling
            at that temperature; ``0`` or ``None`` falls back to greedy
            decoding.
        max_length: Upper bound on the total sequence length passed to
            ``model.generate`` (prompt tokens plus generated tokens).
        model: Causal language model used for generation.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        The decoded text of the first returned sequence with special tokens
        removed. The sequence returned by ``generate`` is decoded as-is, so
        the prompt text is part of the returned string.
    """
    # Place the inputs on the device of the model actually being used, so a
    # caller-supplied model does not have to live on the module-level device.
    inputs = tokenizer(input_prompt, return_tensors="pt").to(model.device)

    # Beam search is switched off explicitly, as recommended for this model.
    generation_kwargs = {"max_length": max_length, "num_beams": 1}
    if temperature is not None and temperature > 0:
        # Temperature only affects sampling-based decoding, so sampling has
        # to be enabled for the requested value to take effect.
        generation_kwargs["do_sample"] = True
        generation_kwargs["temperature"] = temperature
    else:
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||
|
|
|
||
|
|
query = """Database Schema
|
||
|
|
###
|
||
|
|
CREATE TABLE Country
|
||
|
|
(
|
||
|
|
CountryCode TEXT not null primary key,
|
||
|
|
LongName TEXT, -- `Long Name` description: long or full name of countries
|
||
|
|
);
|
||
|
|
|
||
|
|
CREATE TABLE Series
|
||
|
|
(
|
||
|
|
SeriesCode TEXT not null primary key,
|
||
|
|
);
|
||
|
|
|
||
|
|
CREATE TABLE SeriesNotes
|
||
|
|
(
|
||
|
|
Seriescode TEXT not null, -- `Series code` description: code identifying the series
|
||
|
|
Year TEXT not null, --
|
||
|
|
Description TEXT, --
|
||
|
|
primary key (Seriescode, Year),
|
||
|
|
foreign key (Seriescode) references Series(SeriesCode),
|
||
|
|
);
|
||
|
|
|
||
|
|
CREATE TABLE CountryNotes
|
||
|
|
(
|
||
|
|
Countrycode TEXT NOT NULL, --
|
||
|
|
Seriescode TEXT NOT NULL, -- `Series code` description: Series code of countries
|
||
|
|
Description TEXT, --
|
||
|
|
primary key (Countrycode, Seriescode),
|
||
|
|
FOREIGN KEY (Seriescode) REFERENCES Series(SeriesCode),
|
||
|
|
FOREIGN KEY (Countrycode) REFERENCES Country(CountryCode),
|
||
|
|
);
|
||
|
|
|
||
|
|
###
|
||
|
|
Question:
|
||
|
|
Please list the full names of any three countries that have their series code with a description of UN Energy Statistics (2014).
|
||
|
|
|
||
|
|
Hint:
|
||
|
|
full name refers to longname
|
||
|
|
"""
|
||
|
|
# Render the system prompt and the user query with the model's chat template.
# tokenize=False returns the prompt as text, and add_generation_prompt=True
# appends the assistant-turn marker so the model answers instead of continuing
# the user message.
input_prompt = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": query},
    ],
    tokenize=False,
    add_generation_prompt=True,
)
|
||
|
|
TRUE_ANSWER="""SELECT DISTINCT T2.LongName FROM CountryNotes AS T1 INNER JOIN Country AS T2 ON T1.Countrycode = T2.CountryCode WHERE T1.Description = 'Sources: UN Energy Statistics (2014)' LIMIT 3"""
|
||
|
|
|
||
|
|
RESPONSE = generate_response(input_prompt)
|
||
|
|
print(f"Response: \n{RESPONSE}\nTrue answer: \n{TRUE_ANSWER}")
|
||
|
|
|
||
|
|
```
|
||
|
|
|
||
|
|
SDK下载
|
||
|
|
```bash
|
||
|
|
#安装ModelScope
|
||
|
|
pip install modelscope
|
||
|
|
```
|
||
|
|
|
||
|
|
```python
|
||
|
|
#SDK模型下载
|
||
|
|
from modelscope import snapshot_download
|
||
|
|
model_dir = snapshot_download('ruohuaw/deepquery-3b-sft')
|
||
|
|
```
|
||
|
|
Git下载
|
||
|
|
```
|
||
|
|
#Git模型下载
|
||
|
|
git clone https://www.modelscope.cn/ruohuaw/deepquery-3b-sft.git
|
||
|
|
```
|