初始化项目,由ModelHub XC社区提供模型
Model: ruohuaw/deepquery-3b-sft Source: Original Platform
This commit is contained in:
138
README.md
Normal file
138
README.md
Normal file
@@ -0,0 +1,138 @@
|
||||
---
|
||||
license: MIT
|
||||
|
||||
frameworks:
|
||||
- Pytorch
|
||||
|
||||
tasks:
|
||||
- text-generation
|
||||
|
||||
model-type:
|
||||
- qwen2
|
||||
|
||||
domain:
|
||||
- nlp
|
||||
|
||||
language:
|
||||
- zh, en
|
||||
|
||||
tags:
|
||||
- LoRA
|
||||
- GRPO
|
||||
- NL2SQL
|
||||
|
||||
base_model:
|
||||
- Qwen/Qwen2.5-Coder-3B-Instruct
|
||||
|
||||
base_model_relation:
|
||||
- finetune
|
||||
|
||||
datasets:
|
||||
- ruohuaw/sql-cot
|
||||
---
|
||||
notes
|
||||
- we recommand turning off beam search and setting temperature = 0.1 for best performance
|
||||
|
||||
dataset:
|
||||
- sql-cot: https://modelscope.cn/datasets/ruohuaw/sql-cot
|
||||
|
||||
training logs:
|
||||
- https://wandb.ai/wangruohua1999econ-none/huggingface/runs/ygyz68bd?nw=nwuserwangruohua1999econ
|
||||
|
||||
usage:
|
||||
```python
|
||||
import torch
|
||||
from modelscope import snapshot_download
|
||||
from modelscope import AutoTokenizer, AutoModelForCausalLM
|
||||
model_dir = snapshot_download('ruohuaw/deepquery-3b-sft')
|
||||
|
||||
# 系统提示设置
|
||||
sys_prompt = """You are DeepQuery, a data science expert. Below, you are presented with a database schema, a question and a hint.Your task is to read the schema with annotations of the columns, understand the question and the hint, and generate a valid SQL query to answer the question. You should reason step by step, and includes your reasonings between <think> and </think>."""
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_dir)
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
def generate_response(input_prompt, temperature = 0.1, max_length=512*4, model = model, tokenizer = tokenizer):
|
||||
inputs = tokenizer(input_prompt, return_tensors="pt").to(device)
|
||||
with torch.no_grad():
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
max_length = max_length,
|
||||
temperature = temperature
|
||||
)
|
||||
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
||||
return response
|
||||
|
||||
query = """Database Schema
|
||||
###
|
||||
CREATE TABLE Country
|
||||
(
|
||||
CountryCode TEXT not null primary key,
|
||||
LongName TEXT, -- `Long Name` description: long or full name of countries
|
||||
);
|
||||
|
||||
CREATE TABLE Series
|
||||
(
|
||||
SeriesCode TEXT not null primary key,
|
||||
);
|
||||
|
||||
CREATE TABLE SeriesNotes
|
||||
(
|
||||
Seriescode TEXT not null, -- `Series code` description: code identifying the series
|
||||
Year TEXT not null, --
|
||||
Description TEXT, --
|
||||
primary key (Seriescode, Year),
|
||||
foreign key (Seriescode) references Series(SeriesCode),
|
||||
);
|
||||
|
||||
CREATE TABLE CountryNotes
|
||||
(
|
||||
Countrycode TEXT NOT NULL, --
|
||||
Seriescode TEXT NOT NULL, -- `Series code` description: Series code of countries
|
||||
Description TEXT, --
|
||||
primary key (Countrycode, Seriescode),
|
||||
FOREIGN KEY (Seriescode) REFERENCES Series(SeriesCode),
|
||||
FOREIGN KEY (Countrycode) REFERENCES Country(CountryCode),
|
||||
);
|
||||
|
||||
###
|
||||
Question:
|
||||
Please list the full names of any three countries that have their series code with a description of UN Energy Statistics (2014).
|
||||
|
||||
Hint:
|
||||
full name refers to longname
|
||||
"""
|
||||
input_prompt = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "system", "content": SYS},
|
||||
{"role": "user", "content": QUERY},
|
||||
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
TRUE_ANSWER="""SELECT DISTINCT T2.LongName FROM CountryNotes AS T1 INNER JOIN Country AS T2 ON T1.Countrycode = T2.CountryCode WHERE T1.Description = 'Sources: UN Energy Statistics (2014)' LIMIT 3"""
|
||||
|
||||
RESPONSE = generate_response(input_prompt)
|
||||
print(f"Response: \n{RESPONSE}\nTrue answer: \n{TRUE_ANSWER}")
|
||||
|
||||
```
|
||||
|
||||
SDK下载
|
||||
```bash
|
||||
#安装ModelScope
|
||||
pip install modelscope
|
||||
```
|
||||
|
||||
```python
|
||||
#SDK模型下载
|
||||
from modelscope import snapshot_download
|
||||
model_dir = snapshot_download('ruohuaw/deepquery-3b-sft')
|
||||
```
|
||||
Git下载
|
||||
```
|
||||
#Git模型下载
|
||||
git clone https://www.modelscope.cn/ruohuaw/deepquery-3b-sft.git
|
||||
```
|
||||
Reference in New Issue
Block a user