初始化项目,由ModelHub XC社区提供模型
Model: OpenBMB/MiniCPM4-Survey Source: Original Platform
This commit is contained in:
35
.gitattributes
vendored
Normal file
35
.gitattributes
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
200
README.md
Normal file
200
README.md
Normal file
@@ -0,0 +1,200 @@
|
|||||||
|
---
|
||||||
|
language:
|
||||||
|
- zh
|
||||||
|
- en
|
||||||
|
library_name: transformers
|
||||||
|
license: apache-2.0
|
||||||
|
pipeline_tag: text-generation
|
||||||
|
---
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="https://github.com/OpenBMB/MiniCPM/blob/main/assets/minicpm_logo.png?raw=true" width="500em" ></img>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p align="center">
|
||||||
|
<a href="https://github.com/OpenBMB/MiniCPM/\" target="_blank">GitHub Repo</a> |
|
||||||
|
<a href="https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf" target="_blank">Technical Report</a> |
|
||||||
|
<a href="https://huggingface.co/papers/2506.07900" target="_blank">Paper</a>
|
||||||
|
</p>
|
||||||
|
<p align="center">
|
||||||
|
👋 Join us on <a href="https://discord.gg/3cGQn9b3YM" target="_blank">Discord</a> and <a href="https://github.com/OpenBMB/MiniCPM/blob/main/assets/wechat.jpg" target="_blank">WeChat</a>
|
||||||
|
</p>
|
||||||
|
|
||||||
|
This repository contains the model described in the paper [MiniCPM4: Ultra-Efficient LLMs on End Devices](https://huggingface.co/papers/2506.07900).
|
||||||
|
|
||||||
|
## What's New
|
||||||
|
|
||||||
|
* [2025-06-05] 🚀🚀🚀 We have open-sourced **MiniCPM4-Survey**, a model built upon MiniCPM4-8B that is capable of generating trustworthy, long-form survey papers while maintaining competitive performance relative to significantly larger models.
|
||||||
|
|
||||||
|
## MiniCPM4 Series
|
||||||
|
MiniCPM4 series are highly efficient large language models (LLMs) designed explicitly for end-side devices, which achieves this efficiency through systematic innovation in four key dimensions: model architecture, training data, training algorithms, and inference systems.
|
||||||
|
- [MiniCPM4-8B](https://huggingface.co/openbmb/MiniCPM4-8B): The flagship of MiniCPM4, with 8B parameters, trained on 8T tokens.
|
||||||
|
- [MiniCPM4-0.5B](https://huggingface.co/openbmb/MiniCPM4-0.5B): The small version of MiniCPM4, with 0.5B parameters, trained on 1T tokens.
|
||||||
|
- [MiniCPM4-8B-Eagle-FRSpec](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec): Eagle head for FRSpec, accelerating speculative inference for MiniCPM4-8B.
|
||||||
|
- [MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu): Eagle head trained with QAT for FRSpec, efficiently integrate speculation and quantization to achieve ultra acceleration for MiniCPM4-8B.
|
||||||
|
- [MiniCPM4-8B-Eagle-vLLM](https://huggingface.co/openbmb/MiniCPM4-8B-Eagle-vLLM): Eagle head in vLLM format, accelerating speculative inference for MiniCPM4-8B.
|
||||||
|
- [MiniCPM4-8B-marlin-Eagle-vLLM](https://huggingface.co/openbmb/MiniCPM4-8B-marlin-Eagle-vLLM): Quantized Eagle head for vLLM format, accelerating speculative inference for MiniCPM4-8B.
|
||||||
|
- [BitCPM4-0.5B](https://huggingface.co/openbmb/BitCPM4-0.5B): Extreme ternary quantization applied to MiniCPM4-0.5B compresses model parameters into ternary values, achieving a 90% reduction in bit width.
|
||||||
|
- [BitCPM4-1B](https://huggingface.co/openbmb/BitCPM4-1B): Extreme ternary quantization applied to MiniCPM3-1B compresses model parameters into ternary values, achieving a 90% reduction in bit width.
|
||||||
|
- [MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey): Based on MiniCPM4-8B, accepts users' quiries as input and autonomously generate trustworthy, long-form survey papers. (**<-- you are here**)
|
||||||
|
- [MiniCPM4-MCP](https://huggingface.co/openbmb/MiniCPM4-MCP): Based on MiniCPM4-8B, accepts users' queries and available MCP tools as input and autonomously calls relevant MCP tools to satisfy users' requirements.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
**MiniCPM4-Survey** is an open-source LLM agent model jointly developed by [THUNLP](https://nlp.csai.tsinghua.edu.cn), Renmin University of China and [ModelBest](https://modelbest.cn/en). Built on [MiniCPM4](https://github.com/OpenBMB/MiniCPM4) with 8 billion parameters, it accepts users' quiries as input and autonomously generate trustworthy, long-form survey papers.
|
||||||
|
|
||||||
|
Key features include:
|
||||||
|
|
||||||
|
- **Plan-Retrieve-Write Survey Generation Framework** — We propose a multi-agent generation framework, which operates through three core stages: planning (defining the overall structure of the survey), retrieval (generating appropriate retrieval keywords), and writing (synthesizing the retrieved information to generate coherent section-level content).
|
||||||
|
|
||||||
|
- **High-Quality Dataset Construction** — We gather and process lots of expert-written survey papers to construct a high-quality training dataset. Meanwhile, we collect a large number of research papers to build a retrieval database.
|
||||||
|
|
||||||
|
- **Multi-Aspect Reward Design** — We carefully design a reward system with three aspects (structure, content, and citations) to evaluate the quality of the surveys, which is used as the reward function in the RL training stage.
|
||||||
|
|
||||||
|
- **Multi-Step RL Training Strategy** — We propose a *Context Manager* to ensure retention of essential information while facilitating efficient reasoning, and we construct *Parallel Environment* to maintain efficient RL training cycles.
|
||||||
|
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
### Download the model
|
||||||
|
|
||||||
|
Download [MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey) from Hugging Face and place it in `model/MiniCPM4-Survey`.
|
||||||
|
We recommend using [MiniCPM-Embedding-Light](https://huggingface.co/openbmb/MiniCPM-Embedding-Light) as the embedding model, which can be downloaded from Hugging Face and placed in `model/MiniCPM-Embedding-Light`.
|
||||||
|
### Prepare the environment
|
||||||
|
|
||||||
|
You can download the [paper data](https://www.kaggle.com/datasets/Cornell-University/arxiv) from Kaggle, then extract it. You can run `python data_process.py` to process the data and generate the retrieval database. Then you can run `python build_index.py` to build the retrieval database.
|
||||||
|
|
||||||
|
```
|
||||||
|
cd ./code
|
||||||
|
curl -L -o ~/Downloads/arxiv.zip\
|
||||||
|
https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
|
||||||
|
unzip ~/Downloads/arxiv.zip -d .
|
||||||
|
mkdir data
|
||||||
|
python ./src/preprocess/data_process.py
|
||||||
|
mkdir index
|
||||||
|
python ./src/preprocess/build_index.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Model Inference
|
||||||
|
|
||||||
|
You can run the following command to build the retrieval environment and start the inference:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ./code
|
||||||
|
python ./src/retriever.py
|
||||||
|
bash ./scripts/run.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
If you want to run with the frontend, you can run the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd ./code
|
||||||
|
python ./src/retriever.py
|
||||||
|
bash ./scripts/run_with_frontend.sh
|
||||||
|
cd frontend/minicpm4-survey
|
||||||
|
npm install
|
||||||
|
npm run dev
|
||||||
|
```
|
||||||
|
|
||||||
|
Then you can visit `http://localhost:5173` in your browser to use the model.
|
||||||
|
|
||||||
|
## Performance Evaluation
|
||||||
|
|
||||||
|
| Method | Relevance | Coverage | Depth | Novelty | Avg. | Fact Score |
|
||||||
|
|---------------------------------------------|-----------|----------|-------|---------|-------|------------|
|
||||||
|
| Naive RAG (driven by G2FT) | 3.25 | 2.95 | 3.35 | 2.60 | 3.04 | 43.68 |
|
||||||
|
| AutoSurvey (driven by G2FT) | 3.10 | 3.25 | 3.15 | **3.15**| 3.16 | 46.56 |
|
||||||
|
| Webthinker (driven by WTR1-7B) | 3.30 | 3.00 | 2.75 | 2.50 | 2.89 | -- |
|
||||||
|
| Webthinker (driven by QwQ-32B) | 3.40 | 3.30 | 3.30 | 2.50 | 3.13 | -- |
|
||||||
|
| OpenAI Deep Research (driven by GPT-4o) | 3.50 |**3.95** | 3.55 | 3.00 | **3.50** | -- |
|
||||||
|
| MiniCPM4-Survey | 3.45 | 3.70 | **3.85** | 3.00 | **3.50** | **68.73** |
|
||||||
|
| *w/o* RL | **3.55** | 3.35 | 3.30 | 2.25 | 3.11 | 50.24 |
|
||||||
|
|
||||||
|
*Performance comparison of the survey generation systems. "G2FT" stands for Gemini-2.0-Flash-Thinking, and "WTR1-7B" denotes Webthinker-R1-7B. FactScore evaluation was omitted for Webthinker, as it does not include citation functionality, and for OpenAI Deep Research, which does not provide citations when exporting the results.*
|
||||||
|
|
||||||
|
## Statement
|
||||||
|
- As a language model, MiniCPM generates content by learning from a vast amount of text.
|
||||||
|
- However, it does not possess the ability to comprehend or express personal opinions or value judgments.
|
||||||
|
- Any content generated by MiniCPM does not represent the viewpoints or positions of the model developers.
|
||||||
|
- Therefore, when using content generated by MiniCPM, users should take full responsibility for evaluating and verifying it on their own.
|
||||||
|
|
||||||
|
## LICENSE
|
||||||
|
- This repository and MiniCPM models are released under the [Apache-2.0](https://github.com/OpenBMB/MiniCPM/blob/main/LICENSE) License.
|
||||||
|
|
||||||
|
## Citation
|
||||||
|
- Please cite our [paper](https://github.com/OpenBMB/MiniCPM/tree/main/report/MiniCPM_4_Technical_Report.pdf) if you find our work valuable.
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{minicpm4,
|
||||||
|
title={{MiniCPM4}: Ultra-Efficient LLMs on End Devices},
|
||||||
|
author={MiniCPM Team},
|
||||||
|
year={2025}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
# 中文
|
||||||
|
## News
|
||||||
|
|
||||||
|
* [2025-06-05] 🚀🚀🚀我们开源了基于MiniCPM4-8B构建的MiniCPM4-Survey,能够生成可信的长篇调查报告,性能比肩更大模型。
|
||||||
|
|
||||||
|
## 概览
|
||||||
|
|
||||||
|
MiniCPM4-Survey是由[THUNLP](https://nlp.csai.tsinghua.edu.cn)、中国人民大学和[ModelBest](https://modelbest.cn)联合开发的开源大语言模型智能体。它基于[MiniCPM4](https://github.com/OpenBMB/MiniCPM4) 80亿参数基座模型,接受用户质量作为输入,自主生成可信的长篇综述论文。
|
||||||
|
|
||||||
|
主要特性包括:
|
||||||
|
- 计划-检索-写作生成框架 — 我们提出了一个多智能体生成框架,包含三个核心阶段:计划(定义综述的整体结构)、检索(生成合适的检索关键词)和写作(利用检索到的信息,生成连贯的段落)。
|
||||||
|
- 高质量数据集构建——我们收集并处理大量人类专家写作的综述论文,构建高质量训练集。同时,我们收集大量研究论文,构建检索数据库。
|
||||||
|
- 多方面奖励设计 — 我们精心设计了包含结构、内容和引用的奖励,用于评估综述的质量,在强化学习训练阶段作奖励函数。
|
||||||
|
- 多步强化学习训练策略 — 我们提出了一个上下文管理器,以确保在促进有效推理的同时保留必要的信息,并构建了并行环境,维持强化学习训练高效。
|
||||||
|
|
||||||
|
|
||||||
|
## 使用
|
||||||
|
|
||||||
|
### 下载模型
|
||||||
|
从 Hugging Face 下载[MiniCPM4-Survey](https://huggingface.co/openbmb/MiniCPM4-Survey)并将其放在model/MiniCPM4-Survey中。
|
||||||
|
我们建议使用[MiniCPM-Embedding-Light](https://huggingface.co/openbmb/MiniCPM-Embedding-Light)作为表征模型,放在model/MiniCPM-Embedding-Light中。
|
||||||
|
|
||||||
|
### 准备环境
|
||||||
|
从 Kaggle 下载论文数据,然后解压。运行`python data_process.py`,处理数据并生成检索数据库。然后运行`python build_index.py`,构建检索数据库。
|
||||||
|
``` bash
|
||||||
|
cd ./code
|
||||||
|
curl -L -o ~/Downloads/arxiv.zip\
|
||||||
|
https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
|
||||||
|
unzip ~/Downloads/arxiv.zip -d .
|
||||||
|
mkdir data
|
||||||
|
python ./src/preprocess/data_process.py
|
||||||
|
mkdir index
|
||||||
|
python ./src/preprocess/build_index.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 模型推理
|
||||||
|
运行以下命令来构建检索环境并开始推理:
|
||||||
|
``` bash
|
||||||
|
cd ./code
|
||||||
|
python ./src/retriever.py
|
||||||
|
bash ./scripts/run.sh
|
||||||
|
```
|
||||||
|
如果您想使用前端运行,可以运行以下命令:
|
||||||
|
``` bash
|
||||||
|
cd ./code
|
||||||
|
python ./src/retriever.py
|
||||||
|
bash ./scripts/run_with_frontend.sh
|
||||||
|
cd frontend/minicpm4-survey
|
||||||
|
npm install
|
||||||
|
npm run dev
|
||||||
|
```
|
||||||
|
然后你可以在浏览器中访问`http://localhost:5173`使用。
|
||||||
|
|
||||||
|
## 性能
|
||||||
|
|
||||||
|
| Method | Relevance | Coverage | Depth | Novelty | Avg. | Fact Score |
|
||||||
|
|---------------------------------------------|-----------|----------|-------|---------|-------|------------|
|
||||||
|
| Naive RAG (driven by G2FT) | 3.25 | 2.95 | 3.35 | 2.60 | 3.04 | 43.68 |
|
||||||
|
| AutoSurvey (driven by G2FT) | 3.10 | 3.25 | 3.15 | **3.15**| 3.16 | 46.56 |
|
||||||
|
| Webthinker (driven by WTR1-7B) | 3.30 | 3.00 | 2.75 | 2.50 | 2.89 | -- |
|
||||||
|
| Webthinker (driven by QwQ-32B) | 3.40 | 3.30 | 3.30 | 2.50 | 3.13 | -- |
|
||||||
|
| OpenAI Deep Research (driven by GPT-4o) | 3.50 |**3.95** | 3.55 | 3.00 | **3.50** | -- |
|
||||||
|
| MiniCPM4-Survey | 3.45 | 3.70 | **3.85** | 3.00 | **3.50** | **68.73** |
|
||||||
|
| *w/o* RL | **3.55** | 3.35 | 3.30 | 2.25 | 3.11 | 50.24 |
|
||||||
|
|
||||||
|
*GPT-4o对综述生成系统的性能比较。“G2FT”代表Gemini-2.0-Flash-Thinking,“WTR1-7B”代表Webthinker-R1-7B。由于Webthinker不包括引用功能,OpenAI Deep Research在导出结果时不提供引用,因此省略了对它们的FactScore评估。我们的技术报告中包含评测的详细信息。*
|
||||||
10
added_tokens.json
Normal file
10
added_tokens.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"<|execute_end|>": 73444,
|
||||||
|
"<|execute_start|>": 73443,
|
||||||
|
"<|fim_middle|>": 73446,
|
||||||
|
"<|fim_prefix|>": 73445,
|
||||||
|
"<|fim_suffix|>": 73447,
|
||||||
|
"<|im_end|>": 73440,
|
||||||
|
"<|im_start|>": 73441,
|
||||||
|
"<|tool_call|>": 73442
|
||||||
|
}
|
||||||
24
code/frontend/minicpm4-survey/.gitignore
vendored
Normal file
24
code/frontend/minicpm4-survey/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Logs
|
||||||
|
logs
|
||||||
|
*.log
|
||||||
|
npm-debug.log*
|
||||||
|
yarn-debug.log*
|
||||||
|
yarn-error.log*
|
||||||
|
pnpm-debug.log*
|
||||||
|
lerna-debug.log*
|
||||||
|
|
||||||
|
node_modules
|
||||||
|
dist
|
||||||
|
dist-ssr
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor directories and files
|
||||||
|
.vscode/*
|
||||||
|
!.vscode/extensions.json
|
||||||
|
.idea
|
||||||
|
.DS_Store
|
||||||
|
*.suo
|
||||||
|
*.ntvs*
|
||||||
|
*.njsproj
|
||||||
|
*.sln
|
||||||
|
*.sw?
|
||||||
33
code/frontend/minicpm4-survey/eslint.config.js
Normal file
33
code/frontend/minicpm4-survey/eslint.config.js
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import js from '@eslint/js'
|
||||||
|
import globals from 'globals'
|
||||||
|
import reactHooks from 'eslint-plugin-react-hooks'
|
||||||
|
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||||
|
|
||||||
|
export default [
|
||||||
|
{ ignores: ['dist'] },
|
||||||
|
{
|
||||||
|
files: ['**/*.{js,jsx}'],
|
||||||
|
languageOptions: {
|
||||||
|
ecmaVersion: 2020,
|
||||||
|
globals: globals.browser,
|
||||||
|
parserOptions: {
|
||||||
|
ecmaVersion: 'latest',
|
||||||
|
ecmaFeatures: { jsx: true },
|
||||||
|
sourceType: 'module',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
plugins: {
|
||||||
|
'react-hooks': reactHooks,
|
||||||
|
'react-refresh': reactRefresh,
|
||||||
|
},
|
||||||
|
rules: {
|
||||||
|
...js.configs.recommended.rules,
|
||||||
|
...reactHooks.configs.recommended.rules,
|
||||||
|
'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
|
||||||
|
'react-refresh/only-export-components': [
|
||||||
|
'warn',
|
||||||
|
{ allowConstantExport: true },
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
]
|
||||||
13
code/frontend/minicpm4-survey/index.html
Normal file
13
code/frontend/minicpm4-survey/index.html
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8" />
|
||||||
|
<link rel="icon" type="image/png" href="/openbmb.svg" />
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
|
<title>MiniCPM4-Survey</title>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div id="root"></div>
|
||||||
|
<script type="module" src="/src/main.jsx"></script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
2822
code/frontend/minicpm4-survey/package-lock.json
generated
Normal file
2822
code/frontend/minicpm4-survey/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
29
code/frontend/minicpm4-survey/package.json
Normal file
29
code/frontend/minicpm4-survey/package.json
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
{
|
||||||
|
"name": "minicpm4-survey",
|
||||||
|
"private": true,
|
||||||
|
"version": "0.0.0",
|
||||||
|
"type": "module",
|
||||||
|
"scripts": {
|
||||||
|
"dev": "vite",
|
||||||
|
"build": "vite build",
|
||||||
|
"lint": "eslint .",
|
||||||
|
"preview": "vite preview"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"dompurify": "^3.2.6",
|
||||||
|
"marked": "^15.0.12",
|
||||||
|
"react": "^19.1.0",
|
||||||
|
"react-dom": "^19.1.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@eslint/js": "^9.25.0",
|
||||||
|
"@types/react": "^19.1.2",
|
||||||
|
"@types/react-dom": "^19.1.2",
|
||||||
|
"@vitejs/plugin-react": "^4.4.1",
|
||||||
|
"eslint": "^9.25.0",
|
||||||
|
"eslint-plugin-react-hooks": "^5.2.0",
|
||||||
|
"eslint-plugin-react-refresh": "^0.4.19",
|
||||||
|
"globals": "^16.0.0",
|
||||||
|
"vite": "^6.3.5"
|
||||||
|
}
|
||||||
|
}
|
||||||
BIN
code/frontend/minicpm4-survey/public/openbmb.png
Normal file
BIN
code/frontend/minicpm4-survey/public/openbmb.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 629 B |
318
code/frontend/minicpm4-survey/src/App.css
Normal file
318
code/frontend/minicpm4-survey/src/App.css
Normal file
@@ -0,0 +1,318 @@
|
|||||||
|
:root {
|
||||||
|
--neon-blue: #00f3ff;
|
||||||
|
--neon-purple: #ffffff;
|
||||||
|
--dark-panel: rgba(10, 10, 30, 0.95);
|
||||||
|
--glass-color: rgba(255, 255, 255, 0.05);
|
||||||
|
}
|
||||||
|
|
||||||
|
* {
|
||||||
|
box-sizing: border-box;
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
font-family: 'Orbitron', monospace;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
.cyber-container {
|
||||||
|
position: relative;
|
||||||
|
display: flex;
|
||||||
|
min-height: 100vh;
|
||||||
|
background: radial-gradient(circle at center, #0a0a1f 0%, #000000 100%);
|
||||||
|
padding: 2rem;
|
||||||
|
gap: 2rem;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 全息背景 */
|
||||||
|
.hologram-bg {
|
||||||
|
position: fixed;
|
||||||
|
top: -50%;
|
||||||
|
left: -50%;
|
||||||
|
width: 200%;
|
||||||
|
height: 200%;
|
||||||
|
background: repeating-linear-gradient(
|
||||||
|
45deg,
|
||||||
|
transparent,
|
||||||
|
transparent 5px,
|
||||||
|
rgba(0, 255, 255, 0.05) 5px,
|
||||||
|
rgba(0, 255, 255, 0.05) 10px
|
||||||
|
);
|
||||||
|
animation: scan 20s linear infinite;
|
||||||
|
z-index: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes scan {
|
||||||
|
0% { transform: translateY(-50%) rotate(0deg); }
|
||||||
|
100% { transform: translateY(50%) rotate(360deg); }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 面板样式 */
|
||||||
|
.tech-panel {
|
||||||
|
position: relative;
|
||||||
|
width: 220px;
|
||||||
|
padding: 1.5rem;
|
||||||
|
background: var(--glass-color);
|
||||||
|
backdrop-filter: blur(10px);
|
||||||
|
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 0 20px rgba(0, 255, 255, 0.2);
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tech-panel::before {
|
||||||
|
content: '';
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
height: 2px;
|
||||||
|
background: linear-gradient(90deg, var(--neon-blue), var(--neon-purple), var(--neon-blue));
|
||||||
|
animation: pulse 3s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulse {
|
||||||
|
0%, 100% { opacity: 0.5; }
|
||||||
|
50% { opacity: 1; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 输入框 */
|
||||||
|
.input-wrapper {
|
||||||
|
position: relative;
|
||||||
|
margin-bottom: 1.2rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.neon-input {
|
||||||
|
width: 100%;
|
||||||
|
padding: 0.8rem 1rem;
|
||||||
|
background: rgba(255, 255, 255, 0.03);
|
||||||
|
border: 1px solid rgba(0, 255, 255, 0.3);
|
||||||
|
border-radius: 6px;
|
||||||
|
color: #fff;
|
||||||
|
font-size: 1rem;
|
||||||
|
outline: none;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.neon-input::placeholder {
|
||||||
|
color: rgba(255, 255, 255, 0.3);
|
||||||
|
}
|
||||||
|
|
||||||
|
.neon-input:focus {
|
||||||
|
border-color: var(--neon-blue);
|
||||||
|
box-shadow: 0 0 10px var(--neon-blue);
|
||||||
|
}
|
||||||
|
|
||||||
|
.input-glow {
|
||||||
|
position: absolute;
|
||||||
|
bottom: -5px;
|
||||||
|
left: 0;
|
||||||
|
width: 100%;
|
||||||
|
height: 2px;
|
||||||
|
background: linear-gradient(90deg, transparent, var(--neon-purple), transparent);
|
||||||
|
animation: glowPulse 2s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes glowPulse {
|
||||||
|
0%, 100% { opacity: 0; }
|
||||||
|
50% { opacity: 1; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 核心区域 */
|
||||||
|
.core-module {
|
||||||
|
position: relative;
|
||||||
|
flex: 1;
|
||||||
|
z-index: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
.quantum-textarea {
|
||||||
|
flex: 1;
|
||||||
|
padding: 2rem;
|
||||||
|
font-size: 1.2rem;
|
||||||
|
background: rgba(0, 0, 20, 0.7);
|
||||||
|
border: 2px solid var(--neon-blue);
|
||||||
|
border-radius: 12px;
|
||||||
|
color: #fff;
|
||||||
|
outline: none;
|
||||||
|
resize: both;
|
||||||
|
min-height: 300px;
|
||||||
|
backdrop-filter: blur(5px);
|
||||||
|
font-family: 'Orbitron', monospace;
|
||||||
|
transition: all 0.3s ease;
|
||||||
|
}
|
||||||
|
|
||||||
|
.quantum-textarea::placeholder {
|
||||||
|
color: rgba(255, 255, 255, 0.2);
|
||||||
|
}
|
||||||
|
|
||||||
|
.quantum-textarea:focus {
|
||||||
|
border-color: var(--neon-purple);
|
||||||
|
box-shadow: 0 0 20px var(--neon-purple);
|
||||||
|
}
|
||||||
|
|
||||||
|
.core-glow {
|
||||||
|
position: absolute;
|
||||||
|
top: 50%;
|
||||||
|
left: 50%;
|
||||||
|
width: 300%;
|
||||||
|
height: 300%;
|
||||||
|
background: radial-gradient(circle, var(--neon-blue) 0%, transparent 70%);
|
||||||
|
opacity: 0.1;
|
||||||
|
transform: translate(-50%, -50%);
|
||||||
|
z-index: 0;
|
||||||
|
pointer-events: none;
|
||||||
|
animation: pulseGlow 5s infinite alternate;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulseGlow {
|
||||||
|
0% { transform: translate(-50%, -50%) scale(1); }
|
||||||
|
100% { transform: translate(-50%, -50%) scale(1.2); }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 数据流动画 */
|
||||||
|
.data-stream {
|
||||||
|
position: absolute;
|
||||||
|
bottom: 1rem;
|
||||||
|
right: 2rem;
|
||||||
|
display: flex;
|
||||||
|
gap: 0.5rem;
|
||||||
|
z-index: 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
.stream-pulse {
|
||||||
|
width: 6px;
|
||||||
|
height: 6px;
|
||||||
|
background: var(--neon-blue);
|
||||||
|
border-radius: 50%;
|
||||||
|
animation: pulseDot 1.5s infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.delay-1 {
|
||||||
|
animation-delay: 0.3s;
|
||||||
|
}
|
||||||
|
|
||||||
|
.delay-2 {
|
||||||
|
animation-delay: 0.6s;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes pulseDot {
|
||||||
|
0% { transform: scale(1); opacity: 1; }
|
||||||
|
70% { transform: scale(1.5); opacity: 0.3; }
|
||||||
|
100% { transform: scale(1); opacity: 1; }
|
||||||
|
}
|
||||||
|
|
||||||
|
/* 响应式设计 */
|
||||||
|
@media (max-width: 1024px) {
|
||||||
|
.cyber-container {
|
||||||
|
flex-direction: column;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tech-panel {
|
||||||
|
width: 100%;
|
||||||
|
margin-bottom: 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.core-module {
|
||||||
|
min-height: 400px;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
.markdown-editor {
|
||||||
|
display: flex;
|
||||||
|
gap: 20px;
|
||||||
|
height: 100%;
|
||||||
|
z-index: 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-input,
|
||||||
|
.markdown-preview {
|
||||||
|
flex: 1;
|
||||||
|
padding: 15px;
|
||||||
|
font-size: 16px;
|
||||||
|
border: 2px solid var(--neon-blue);
|
||||||
|
border-radius: 8px;
|
||||||
|
background: rgba(0, 0, 20, 0.7);
|
||||||
|
color: #fff;
|
||||||
|
font-family: 'Orbitron', monospace;
|
||||||
|
resize: both;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-input {
|
||||||
|
min-height: 300px;
|
||||||
|
outline: none;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-input:focus {
|
||||||
|
box-shadow: 0 0 10px var(--neon-blue);
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-preview {
|
||||||
|
max-height: 800px; /* 设置最大高度,超出后滚动 */
|
||||||
|
width: 100%;
|
||||||
|
overflow-y: auto; /* 垂直方向溢出时显示滚动条 */
|
||||||
|
padding: 10px;
|
||||||
|
border: 1px solid #333;
|
||||||
|
background-color: #111;
|
||||||
|
color: #eee;
|
||||||
|
font-family: monospace;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Markdown 内容增强样式 */
|
||||||
|
.markdown-preview h1, h2, h3 {
|
||||||
|
color: var(--neon-purple);
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-preview pre {
|
||||||
|
background: #111;
|
||||||
|
padding: 10px;
|
||||||
|
border-radius: 6px;
|
||||||
|
color: #00ffcc;
|
||||||
|
overflow-x: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-preview code {
|
||||||
|
background: rgba(255, 255, 255, 0.1);
|
||||||
|
padding: 2px 4px;
|
||||||
|
border-radius: 4px;
|
||||||
|
color: var(--neon-blue);
|
||||||
|
}
|
||||||
|
|
||||||
|
.core-module {
|
||||||
|
flex: 1;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 1rem;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-toolbar {
|
||||||
|
display: flex;
|
||||||
|
gap: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.neon-button {
|
||||||
|
background-color: #000;
|
||||||
|
color: #0f0;
|
||||||
|
border: 2px solid #0f0;
|
||||||
|
padding: 6px 12px;
|
||||||
|
font-size: 14px;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: all 0.2s ease-in-out;
|
||||||
|
}
|
||||||
|
|
||||||
|
.neon-button:hover {
|
||||||
|
background-color: #0f0;
|
||||||
|
color: #000;
|
||||||
|
}
|
||||||
|
|
||||||
|
.markdown-editor {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 10px;
|
||||||
|
padding: 1rem;
|
||||||
|
}
|
||||||
259
code/frontend/minicpm4-survey/src/App.jsx
Normal file
259
code/frontend/minicpm4-survey/src/App.jsx
Normal file
@@ -0,0 +1,259 @@
|
|||||||
|
import React, { useState, useEffect, useMemo, useRef } from 'react';
|
||||||
|
import './App.css';
|
||||||
|
import DOMPurify from 'dompurify';
|
||||||
|
import { marked } from 'marked';
|
||||||
|
|
||||||
|
// 自定义 hook:防抖
|
||||||
|
function useDebounce(value, delay) {
|
||||||
|
const [debouncedValue, setDebouncedValue] = useState(value);
|
||||||
|
|
||||||
|
React.useEffect(() => {
|
||||||
|
const handler = setTimeout(() => {
|
||||||
|
setDebouncedValue(value);
|
||||||
|
}, delay);
|
||||||
|
|
||||||
|
return () => clearTimeout(handler);
|
||||||
|
}, [value, delay]);
|
||||||
|
|
||||||
|
return debouncedValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
function MarkdownEditor({ value }) {
|
||||||
|
const containerRef = useRef(null);
|
||||||
|
|
||||||
|
const htmlContent = marked(value || '');
|
||||||
|
const sanitizedHtml = DOMPurify.sanitize(htmlContent);
|
||||||
|
|
||||||
|
const [userScrolled, setUserScrolled] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const container = containerRef.current;
|
||||||
|
if (container && !userScrolled) {
|
||||||
|
requestAnimationFrame(() => {
|
||||||
|
container.scrollTop = container.scrollHeight;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}, [value, userScrolled]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const container = containerRef.current;
|
||||||
|
if (container) {
|
||||||
|
const handleScroll = () => {
|
||||||
|
const atBottom = container.scrollTop + container.clientHeight >= container.scrollHeight - 10;
|
||||||
|
setUserScrolled(!atBottom);
|
||||||
|
};
|
||||||
|
container.addEventListener('scroll', handleScroll);
|
||||||
|
return () => container.removeEventListener('scroll', handleScroll);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// 复制 Markdown 内容
|
||||||
|
const handleCopy = () => {
|
||||||
|
navigator.clipboard.writeText(value || '')
|
||||||
|
.then(() => alert('Markdown 已复制到剪贴板'))
|
||||||
|
.catch(err => console.error('复制失败:', err));
|
||||||
|
};
|
||||||
|
|
||||||
|
// 下载 Markdown 文件
|
||||||
|
const handleDownload = () => {
|
||||||
|
const blob = new Blob([value || ''], { type: 'text/markdown;charset=utf-8' });
|
||||||
|
const url = URL.createObjectURL(blob);
|
||||||
|
const a = document.createElement('a');
|
||||||
|
a.href = url;
|
||||||
|
a.download = 'document.md';
|
||||||
|
document.body.appendChild(a);
|
||||||
|
a.click();
|
||||||
|
document.body.removeChild(a);
|
||||||
|
URL.revokeObjectURL(url);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="markdown-editor">
|
||||||
|
{/* <div className="markdown-toolbar">
|
||||||
|
<button className="neon-button" onClick={handleCopy}>复制 Markdown</button>
|
||||||
|
<button className="neon-button" onClick={handleDownload}>下载 Markdown</button>
|
||||||
|
</div> */}
|
||||||
|
<div
|
||||||
|
ref={containerRef}
|
||||||
|
className="markdown-preview"
|
||||||
|
dangerouslySetInnerHTML={{ __html: sanitizedHtml }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
function SendRequestToBackend() {
|
||||||
|
const [inputValue, setInputValue] = useState('');
|
||||||
|
|
||||||
|
const handleSendRequest = async () => {
|
||||||
|
try {
|
||||||
|
const response = await fetch('http://localhost:8001/generate_survey', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
body: JSON.stringify({ query: inputValue }),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to send request');
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
console.log('Response from backend:', data);
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error sending request:', error);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="request-panel" style={{ flexDirection: 'column', alignItems: 'center' }}>
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
value={inputValue}
|
||||||
|
onChange={(e) => setInputValue(e.target.value)}
|
||||||
|
className="neon-input"
|
||||||
|
placeholder="Enter text to send"
|
||||||
|
rows={3}
|
||||||
|
/>
|
||||||
|
<button onClick={handleSendRequest} className="neon-button">
|
||||||
|
Go!
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
function App() {
|
||||||
|
const [inputs, setInputs] = useState({
|
||||||
|
query: { title: 'Query', displayText: '', targetText: '', isTyping: false },
|
||||||
|
nowUpdate: { title: 'Now Update', displayText: '', targetText: '', isTyping: false },
|
||||||
|
nextUpdate: { title: 'Next Update', displayText: '', targetText: '', isTyping: false },
|
||||||
|
searchKeywords: { title: 'Search Keywords', displayText: '', targetText: '', isTyping: false },
|
||||||
|
papers: { title: 'Papers', displayText: '', targetText: '', isTyping: false },
|
||||||
|
});
|
||||||
|
|
||||||
|
const [markdownContent, setMarkdownContent] = useState('');
|
||||||
|
|
||||||
|
const inputKeyMap = {
|
||||||
|
query: inputs.query,
|
||||||
|
nowUpdate: inputs.nowUpdate,
|
||||||
|
nextUpdate: inputs.nextUpdate,
|
||||||
|
searchKeywords: inputs.searchKeywords,
|
||||||
|
papers: inputs.papers,
|
||||||
|
markdown: markdownContent
|
||||||
|
};
|
||||||
|
|
||||||
|
const updateInputsFromPostData = (postData) => {
|
||||||
|
let newMarkdownContent = markdownContent;
|
||||||
|
|
||||||
|
Object.entries(postData).forEach(([key, value]) => {
|
||||||
|
if (key in inputKeyMap) {
|
||||||
|
if (key === 'markdown') {
|
||||||
|
if (markdownContent !== value) {
|
||||||
|
newMarkdownContent = value;
|
||||||
|
setMarkdownContent(newMarkdownContent);
|
||||||
|
}
|
||||||
|
} else if (inputKeyMap[key] && inputKeyMap[key].targetText !== value) {
|
||||||
|
const updatedInput = {
|
||||||
|
...inputKeyMap[key],
|
||||||
|
targetText: value,
|
||||||
|
isTyping: true,
|
||||||
|
};
|
||||||
|
setInputs((prevInputs) => ({
|
||||||
|
...prevInputs,
|
||||||
|
[key]: updatedInput,
|
||||||
|
}));
|
||||||
|
|
||||||
|
// startTypingAnimationForTextbox(value, (newText) => {
|
||||||
|
// setInputs((prevInputs) => ({
|
||||||
|
// ...prevInputs,
|
||||||
|
// [key]: {
|
||||||
|
// ...prevInputs[key],
|
||||||
|
// displayText: newText,
|
||||||
|
// },
|
||||||
|
// }));
|
||||||
|
// });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
// const startTypingAnimationForTextbox = (text, setText) => {
|
||||||
|
// setText('');
|
||||||
|
// let charIndex = 0;
|
||||||
|
// const timer = setInterval(() => {
|
||||||
|
// if (charIndex < text.length) {
|
||||||
|
// setText((prev) => prev + text[charIndex]);
|
||||||
|
// charIndex++;
|
||||||
|
// } else {
|
||||||
|
// clearInterval(timer);
|
||||||
|
// }
|
||||||
|
// }, 50); // Reduced interval for faster typing animation
|
||||||
|
// };
|
||||||
|
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const ws = new WebSocket('ws://localhost:8001/ws');
|
||||||
|
ws.onmessage = (event) => {
|
||||||
|
try {
|
||||||
|
const data = JSON.parse(event.data);
|
||||||
|
updateInputsFromPostData(data);
|
||||||
|
console.log('Received data:', data);
|
||||||
|
} catch (e) {
|
||||||
|
console.error('Invalid WebSocket message:', e);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
ws.onerror = (err) => {
|
||||||
|
console.error('WebSocket error:', err);
|
||||||
|
};
|
||||||
|
// return () => ws.close();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const leftInputs = [inputs.nowUpdate, inputs.nextUpdate,inputs.searchKeywords];
|
||||||
|
const rightInputs = [inputs.papers];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="cyber-container">
|
||||||
|
<div className="tech-panel left-panel">
|
||||||
|
{leftInputs.map((input, index) => (
|
||||||
|
<div key={`left-${index}`} className="input-wrapper">
|
||||||
|
<h3 className="input-title" style={{ fontSize: '14px' }}>{input.title}</h3>
|
||||||
|
<textarea
|
||||||
|
value={input.targetText}
|
||||||
|
readOnly
|
||||||
|
className="neon-input"
|
||||||
|
rows={Math.max(10, input.targetText.split('\n').length)}
|
||||||
|
cols={50}
|
||||||
|
style={{ resize: 'none', fontSize: '12px' }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="core-module">
|
||||||
|
<SendRequestToBackend />
|
||||||
|
<MarkdownEditor value={markdownContent} />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="tech-panel right-panel">
|
||||||
|
{rightInputs.map((input, index) => (
|
||||||
|
<div key={`right-${index}`} className="input-wrapper">
|
||||||
|
<h3 className="input-title" style={{ fontSize: '14px' }}>{input.title}</h3>
|
||||||
|
<textarea
|
||||||
|
value={input.targetText}
|
||||||
|
readOnly
|
||||||
|
className="neon-input"
|
||||||
|
rows={100}
|
||||||
|
cols={50}
|
||||||
|
style={{ resize: 'none', fontSize: '12px' }}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export default App;
|
||||||
68
code/frontend/minicpm4-survey/src/index.css
Normal file
68
code/frontend/minicpm4-survey/src/index.css
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
:root {
|
||||||
|
font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
|
||||||
|
line-height: 1.5;
|
||||||
|
font-weight: 400;
|
||||||
|
|
||||||
|
color-scheme: light dark;
|
||||||
|
color: rgba(255, 255, 255, 0.87);
|
||||||
|
background-color: #242424;
|
||||||
|
|
||||||
|
font-synthesis: none;
|
||||||
|
text-rendering: optimizeLegibility;
|
||||||
|
-webkit-font-smoothing: antialiased;
|
||||||
|
-moz-osx-font-smoothing: grayscale;
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
font-weight: 500;
|
||||||
|
color: #646cff;
|
||||||
|
text-decoration: inherit;
|
||||||
|
}
|
||||||
|
a:hover {
|
||||||
|
color: #535bf2;
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
margin: 0;
|
||||||
|
display: flex;
|
||||||
|
place-items: center;
|
||||||
|
min-width: 320px;
|
||||||
|
min-height: 100vh;
|
||||||
|
}
|
||||||
|
|
||||||
|
h1 {
|
||||||
|
font-size: 3.2em;
|
||||||
|
line-height: 1.1;
|
||||||
|
}
|
||||||
|
|
||||||
|
button {
|
||||||
|
border-radius: 8px;
|
||||||
|
border: 1px solid transparent;
|
||||||
|
padding: 0.6em 1.2em;
|
||||||
|
font-size: 1em;
|
||||||
|
font-weight: 500;
|
||||||
|
font-family: inherit;
|
||||||
|
background-color: #1a1a1a;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: border-color 0.25s;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
border-color: #646cff;
|
||||||
|
}
|
||||||
|
button:focus,
|
||||||
|
button:focus-visible {
|
||||||
|
outline: 4px auto -webkit-focus-ring-color;
|
||||||
|
}
|
||||||
|
|
||||||
|
@media (prefers-color-scheme: light) {
|
||||||
|
:root {
|
||||||
|
color: #213547;
|
||||||
|
background-color: #ffffff;
|
||||||
|
}
|
||||||
|
a:hover {
|
||||||
|
color: #747bff;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
background-color: #f9f9f9;
|
||||||
|
}
|
||||||
|
}
|
||||||
10
code/frontend/minicpm4-survey/src/main.jsx
Normal file
10
code/frontend/minicpm4-survey/src/main.jsx
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
import { StrictMode } from 'react'
|
||||||
|
import { createRoot } from 'react-dom/client'
|
||||||
|
import './index.css'
|
||||||
|
import App from './App.jsx'
|
||||||
|
|
||||||
|
createRoot(document.getElementById('root')).render(
|
||||||
|
<StrictMode>
|
||||||
|
<App />
|
||||||
|
</StrictMode>,
|
||||||
|
)
|
||||||
7
code/frontend/minicpm4-survey/vite.config.js
Normal file
7
code/frontend/minicpm4-survey/vite.config.js
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
import { defineConfig } from 'vite'
|
||||||
|
import react from '@vitejs/plugin-react'
|
||||||
|
|
||||||
|
// https://vite.dev/config/
|
||||||
|
export default defineConfig({
|
||||||
|
plugins: [react()],
|
||||||
|
})
|
||||||
8
code/requirements.txt
Normal file
8
code/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
openai
|
||||||
|
vllm
|
||||||
|
jsonlines
|
||||||
|
faiss-cpu
|
||||||
|
# faiss-gpu
|
||||||
|
fastapi
|
||||||
|
uvicorn
|
||||||
|
yarl
|
||||||
4
code/scripts/run.sh
Normal file
4
code/scripts/run.sh
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
python ./src/generation/run.py \
|
||||||
|
--model_path "openbmb/MiniCPM4-Survey" \
|
||||||
|
--query "Please design a survey that assesses the performance of language processing systems on unseen data to measure their robustness in natural language understanding tasks." \
|
||||||
|
--output_file "test.md" \
|
||||||
5
code/scripts/run_with_frontend.sh
Normal file
5
code/scripts/run_with_frontend.sh
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
python ./src/generation/run.py \
|
||||||
|
--model_path "openbmb/MiniCPM4-Survey" \
|
||||||
|
--query "Please design a survey that assesses the performance of language processing systems on unseen data to measure their robustness in natural language understanding tasks." \
|
||||||
|
--output_file "test.md" \
|
||||||
|
--port 8001 \
|
||||||
810
code/src/generation/buffer.py
Normal file
810
code/src/generation/buffer.py
Normal file
@@ -0,0 +1,810 @@
|
|||||||
|
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
|
||||||
|
# BASE_SURVEY_STRUCTURE = """
|
||||||
|
# # Title: A survey of ...
|
||||||
|
# # Introduction: None.
|
||||||
|
# # Section 1: None.
|
||||||
|
# ## Subsection 1 (if needed): None.
|
||||||
|
# ## Subsection 2 (if needed): None.
|
||||||
|
# ### Subsubsection 1 (if needed): None.
|
||||||
|
# ### Subsubsection 2 (if needed): None.
|
||||||
|
# ### ...
|
||||||
|
# # Section 2: None.
|
||||||
|
# # ...
|
||||||
|
# # Conclusion: None.
|
||||||
|
# """
|
||||||
|
|
||||||
|
|
||||||
|
class SurveyManager:
|
||||||
|
BASE_SURVEY_STRUCTURE = {
|
||||||
|
"title": "",
|
||||||
|
"abstract": "",
|
||||||
|
"introduction": {
|
||||||
|
"content": ""
|
||||||
|
},
|
||||||
|
"sections": [],
|
||||||
|
"conclusion": ""
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_update_pos(update_pos):
|
||||||
|
"""
|
||||||
|
(1) "title", "abstract", "introduction", or "conclusion"
|
||||||
|
(2) "section-i/subsection-j/..."
|
||||||
|
|
||||||
|
"""
|
||||||
|
if update_pos in ["title", "abstract", "introduction", "conclusion","plan"]:
|
||||||
|
return update_pos
|
||||||
|
else:
|
||||||
|
keys = update_pos.split("/")
|
||||||
|
if len(keys) == 1: # Section-?
|
||||||
|
i = int(keys[0].lower().split("section-")[-1])
|
||||||
|
return f"section-{i}"
|
||||||
|
elif len(keys) == 2: # Section-?/Subsection-?
|
||||||
|
i = int(keys[0].lower().split("section-")[-1])
|
||||||
|
j = int(keys[1].lower().split("subsection-")[-1])
|
||||||
|
return f"section-{i}/subsection-{j}"
|
||||||
|
elif len(keys) == 3: # Section-?/Subsection-?/Subsubsection-?
|
||||||
|
i = int(keys[0].lower().split("section-")[-1])
|
||||||
|
j = int(keys[1].lower().split("subsection-")[-1])
|
||||||
|
k = int(keys[2].lower().split("subsubsection-")[-1])
|
||||||
|
return f"section-{i}/subsection-{j}/subsubsection-{k}"
|
||||||
|
else:
|
||||||
|
raise ValueError("unsupported update_pos keys")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_one_line(string):
|
||||||
|
if isinstance(string, dict):
|
||||||
|
if "content" in string and string["content"]:
|
||||||
|
return SurveyManager._to_one_line(string["content"])
|
||||||
|
# return SurveyManager._to_one_line(string["content"])
|
||||||
|
else:
|
||||||
|
return "[PLAN] " + string.get("plan", "").replace("\n", " ").strip()
|
||||||
|
if not string:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return string#.replace("\n", " ")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_survey_dict_to_str(current_survey):
|
||||||
|
string = ""
|
||||||
|
if current_survey == {}:
|
||||||
|
return "There is no survey."
|
||||||
|
# title
|
||||||
|
try:
|
||||||
|
content = SurveyManager._to_one_line(current_survey["title"])
|
||||||
|
string += f"# {content}\n"
|
||||||
|
except:
|
||||||
|
string += f"# Title: None\n"
|
||||||
|
|
||||||
|
# abstract
|
||||||
|
try:
|
||||||
|
content = SurveyManager._to_one_line(current_survey["abstract"])
|
||||||
|
string += f"## Abstract\n{content}\n"
|
||||||
|
except:
|
||||||
|
string += f"## Abstract\nNone\n"
|
||||||
|
|
||||||
|
# introduction
|
||||||
|
try:
|
||||||
|
content = SurveyManager._to_one_line(current_survey["introduction"])
|
||||||
|
string += f"## Introduction\n{content}\n"
|
||||||
|
except:
|
||||||
|
string += f"## Introduction\nNone\n"
|
||||||
|
|
||||||
|
# sections
|
||||||
|
if "sections" in current_survey:
|
||||||
|
for i, section in enumerate(current_survey["sections"]):
|
||||||
|
title_key = "name" if "name" in section else "title"
|
||||||
|
name, content = section[title_key], SurveyManager._to_one_line(section)
|
||||||
|
# string += f"# Section-{i+1} [{name}]: {content}\n"
|
||||||
|
string += f"## {name}\n{content}\n"
|
||||||
|
|
||||||
|
if "subsections" in section:
|
||||||
|
for j, subsection in enumerate(section["subsections"]):
|
||||||
|
name, content = subsection[title_key], SurveyManager._to_one_line(subsection)
|
||||||
|
# string += f" ## Subsection-{j+1} [{name}]: {content}\n"
|
||||||
|
string += f"### {name}\n{content}\n"
|
||||||
|
|
||||||
|
if "subsubsections" in subsection:
|
||||||
|
for k, subsubsection in enumerate(subsection["subsubsections"]):
|
||||||
|
name, content = subsubsection[title_key], SurveyManager._to_one_line(subsubsection)
|
||||||
|
# string += f" ### Subsubsection-{k+1} [{name}]: {content}\n"
|
||||||
|
string += f"#### {name}\n{content}\n"
|
||||||
|
|
||||||
|
|
||||||
|
# conclusion
|
||||||
|
try:
|
||||||
|
content = SurveyManager._to_one_line(current_survey["conclusion"])
|
||||||
|
string += f"## Conclusion\n{content}\n"
|
||||||
|
except:
|
||||||
|
string += f"## Conclusion:\nNone\n"
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _abbr_one_line(string, abbr=True):
|
||||||
|
if isinstance(string, dict):
|
||||||
|
if "content" in string and string["content"]:
|
||||||
|
return SurveyManager._abbr_one_line(string["content"], abbr=abbr)
|
||||||
|
elif "plan" in string:
|
||||||
|
return "[PLAN] " + string["plan"].replace("\n", " ").strip()
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
if not string:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
if abbr and len(string) > 50:
|
||||||
|
return "[OK] " + string.replace("\n", " ").strip()[:50] + "..."
|
||||||
|
else:
|
||||||
|
return "[OK] " + string.replace("\n", " ").strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_survey_dict_to_abbr_str(current_survey):
|
||||||
|
string = ""
|
||||||
|
if current_survey == {}:
|
||||||
|
return "There is no survey."
|
||||||
|
# title
|
||||||
|
try:
|
||||||
|
content = SurveyManager._abbr_one_line(current_survey["title"], abbr=False)
|
||||||
|
string += f"# Title: {content}\n"
|
||||||
|
except:
|
||||||
|
string += f"# Title: None\n"
|
||||||
|
# abstract
|
||||||
|
try:
|
||||||
|
content = SurveyManager._abbr_one_line(current_survey["abstract"], abbr=False)
|
||||||
|
string += f"# Abstract: {content}\n"
|
||||||
|
except:
|
||||||
|
string += f"# Abstract: None\n"
|
||||||
|
|
||||||
|
# introduction
|
||||||
|
try:
|
||||||
|
content = SurveyManager._abbr_one_line(current_survey["introduction"])
|
||||||
|
string += f"# Introduction: {content}\n"
|
||||||
|
except:
|
||||||
|
string += f"# Introduction: None\n"
|
||||||
|
|
||||||
|
# sections
|
||||||
|
if "sections" in current_survey:
|
||||||
|
for i, section in enumerate(current_survey["sections"]):
|
||||||
|
title_key = "name" if "name" in section else "title"
|
||||||
|
name, content = section[title_key], SurveyManager._abbr_one_line(section)
|
||||||
|
string += f"# Section-{i+1} [{name}]: {content}\n"
|
||||||
|
|
||||||
|
if "subsections" in section:
|
||||||
|
for j, subsection in enumerate(section["subsections"]):
|
||||||
|
name, content = subsection[title_key], SurveyManager._abbr_one_line(subsection)
|
||||||
|
string += f" ## Subsection-{j+1} [{name}]: {content}\n"
|
||||||
|
|
||||||
|
if "subsubsections" in subsection:
|
||||||
|
for k, subsubsection in enumerate(subsection["subsubsections"]):
|
||||||
|
name, content = subsubsection[title_key], SurveyManager._abbr_one_line(subsubsection)
|
||||||
|
string += f" ### Subsubsection-{k+1} [{name}]: {content}\n"
|
||||||
|
|
||||||
|
# conclusion
|
||||||
|
try:
|
||||||
|
content = SurveyManager._abbr_one_line(current_survey["conclusion"])
|
||||||
|
string += f"# Conclusion: {content}\n"
|
||||||
|
except:
|
||||||
|
string += f"# Conclusion: None\n"
|
||||||
|
|
||||||
|
return string
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_one_section(sections, i, content):
|
||||||
|
# i -= 1
|
||||||
|
if i >= 0 and i <= (len(sections)-1):
|
||||||
|
sections[i]["content"] = content
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
# print("update fail!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def update_current_survey(current_survey, answer) -> bool:
|
||||||
|
"""
|
||||||
|
update_pos: "section-i/subsection-j/subsubsection-k"
|
||||||
|
"""
|
||||||
|
# if answer == {}:
|
||||||
|
# return True
|
||||||
|
try:
|
||||||
|
update_pos, content = answer["update"], answer["content"]
|
||||||
|
|
||||||
|
if update_pos == "plan":
|
||||||
|
# current_survey = content
|
||||||
|
if current_survey == {}:
|
||||||
|
for k,v in content.items():
|
||||||
|
current_survey[k] = copy.deepcopy(v)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
elif update_pos in ["conclusion", "abstract"]:
|
||||||
|
if update_pos not in current_survey:
|
||||||
|
# print("update fail!")
|
||||||
|
return False
|
||||||
|
current_survey[update_pos] = content
|
||||||
|
|
||||||
|
elif update_pos == "introduction":
|
||||||
|
if update_pos not in current_survey:
|
||||||
|
# print("update fail!")
|
||||||
|
return False
|
||||||
|
current_survey[update_pos] = {"content": content}
|
||||||
|
|
||||||
|
else:
|
||||||
|
keys = update_pos.split("/")
|
||||||
|
if len(keys) == 1: # Section-?
|
||||||
|
i = int(keys[0].lower().split("section-")[-1])-1
|
||||||
|
return SurveyManager.update_one_section(current_survey["sections"], i, content)
|
||||||
|
|
||||||
|
elif len(keys) == 2: # Section-?/Subsection-?
|
||||||
|
i = int(keys[0].lower().split("section-")[-1])-1
|
||||||
|
j = int(keys[1].lower().split("subsection-")[-1])-1
|
||||||
|
try:
|
||||||
|
return SurveyManager.update_one_section(current_survey["sections"][i]["subsections"], j, content)
|
||||||
|
except:
|
||||||
|
# print("update fail!")
|
||||||
|
return False
|
||||||
|
|
||||||
|
elif len(keys) == 3: # Section-?/Subsection-?/Subsubsection-?
|
||||||
|
i = int(keys[0].lower().split("section-")[-1])-1
|
||||||
|
j = int(keys[1].lower().split("subsection-")[-1])-1
|
||||||
|
k = int(keys[2].lower().split("subsubsection-")[-1])-1
|
||||||
|
try:
|
||||||
|
return SurveyManager.update_one_section(current_survey["sections"][i]["subsections"][j]["subsubsections"], k, content) # 禁用第四级
|
||||||
|
except:
|
||||||
|
# print("update fail!")
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# print("update fail!")
|
||||||
|
# print("unsupported update_pos keys")
|
||||||
|
return False
|
||||||
|
# raise ValueError("unsupported update_pos keys")
|
||||||
|
except:
|
||||||
|
# print("update fail!")
|
||||||
|
return False
|
||||||
|
# print("answer is not a valid json object.")
|
||||||
|
# print(answer)
|
||||||
|
# raise ValueError("answer is not a valid json object.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
from prompts import *
|
||||||
|
class PromptManger:
|
||||||
|
system_prompt = SYSTEM_PROMPT_0415_BUFFER
|
||||||
|
user_prompt_v0 = USER_PROMPT_v0_0424_BUFFER
|
||||||
|
user_prompt = USER_PROMPT_0415_BUFFER
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BufferManager:
|
||||||
|
"""
|
||||||
|
Used to manage prompts/responses generated during the Rollout phase, providing data support for subsequent training.
|
||||||
|
batch_rollout_data = [
|
||||||
|
{
|
||||||
|
query (or env_id): # Uniquely identifies a query or environment, [input parameter].
|
||||||
|
*running_id: # Uniquely identifies a single rollout. For cases where a query or environment is repeated multiple times, the query can be the same, but running_id will not repeat.
|
||||||
|
state: { # Indicates whether the process is finished.
|
||||||
|
"score": 0.0,
|
||||||
|
"done": True / False
|
||||||
|
"current_survey": dict # Structured data.
|
||||||
|
}
|
||||||
|
trajectory: [ # Organizes all data into a multi-turn interaction format.
|
||||||
|
{
|
||||||
|
step: int, 0~?, # The first step, usually includes some init_info or plan.
|
||||||
|
original_response: str, The raw output from the model, which may have various formatting issues.
|
||||||
|
answer_thought: str, # Encapsulated using the <think>...</think> block.
|
||||||
|
answer: {
|
||||||
|
"original_str": str
|
||||||
|
"update": str,
|
||||||
|
"name": str,
|
||||||
|
"content": str,
|
||||||
|
"inclusions": list, # Extracted independently?
|
||||||
|
}
|
||||||
|
tool_call_thought: str, # Encapsulated using the <think>...</think> block.
|
||||||
|
tool_call: {
|
||||||
|
"original_str": str, # Encapsulated using the <tool_call>...</tool_call> block, used for tool invocation. In the survey setting, it is either "done" to end the task or "search".
|
||||||
|
"tool_name": str # done or search.
|
||||||
|
"keywords": list[str], Extracted search keywords from tool_call, otherwise none.
|
||||||
|
}
|
||||||
|
*papers: list[str], # Top-n papers retrieved via the search engine. Required if using the Agent-Summary-1 for collaborative optimization; otherwise, not needed.
|
||||||
|
cites: list[str], # References cited by the model, which may include multiple citations.
|
||||||
|
summarys: list[str], # Summaries of papers generated using Agent-Summary-1. Must include BIBKEY.
|
||||||
|
*prompt_for_generator: str, # The prompt input to the generator at the current step. Required if using Agent-Summary-2 for generation and collaborative optimization; otherwise, not needed.
|
||||||
|
},
|
||||||
|
...
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self, prompts, repeat_n: int=1):
|
||||||
|
# self.config = config
|
||||||
|
self.step = 0
|
||||||
|
self.batch_rollout_data = []
|
||||||
|
self.running_ids = [] # active envs
|
||||||
|
batch_size = prompts.batch['input_ids'].size(0)
|
||||||
|
uids = prompts.non_tensor_batch['uid']
|
||||||
|
querys = prompts.non_tensor_batch['raw_prompt'].copy()
|
||||||
|
ground_truths = prompts.non_tensor_batch['ground_truth']
|
||||||
|
# print(querys)
|
||||||
|
new_querys = []
|
||||||
|
for i_batch in range(batch_size):
|
||||||
|
raw_prompt_i_batch = querys[i_batch][-1]["content"]
|
||||||
|
new_querys.append(raw_prompt_i_batch)
|
||||||
|
querys = new_querys
|
||||||
|
|
||||||
|
assert len(querys) == len(uids)
|
||||||
|
for query, uid, ground_truth in zip(querys, uids, ground_truths):
|
||||||
|
|
||||||
|
now_survey = {}
|
||||||
|
|
||||||
|
for _ in range(repeat_n):
|
||||||
|
self.batch_rollout_data.append({
|
||||||
|
"query": query,
|
||||||
|
"uid": uid,
|
||||||
|
"state": {
|
||||||
|
# "score": 0.0, # only for debug
|
||||||
|
# "format_score": None, # will update at last step
|
||||||
|
"done": False,
|
||||||
|
"current_survey": {}
|
||||||
|
},
|
||||||
|
"trajectory": [],
|
||||||
|
"history_messages": [],
|
||||||
|
})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_system_prompt():
|
||||||
|
prompt = PromptManger.system_prompt
|
||||||
|
return prompt
|
||||||
|
@staticmethod
|
||||||
|
def _build_user_prompt_v0(query, current_survey):
|
||||||
|
# query
|
||||||
|
prompt = PromptManger.user_prompt_v0.replace("<user_query>", query)
|
||||||
|
|
||||||
|
# add template
|
||||||
|
prompt = prompt.replace("<init_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_user_prompt(query, current_survey, trajs):
|
||||||
|
last_traj = trajs[-1]
|
||||||
|
# query
|
||||||
|
prompt = PromptManger.user_prompt.replace("<user_query>", query)
|
||||||
|
|
||||||
|
# add current survey
|
||||||
|
prompt = prompt.replace("<current_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))
|
||||||
|
|
||||||
|
# current plan
|
||||||
|
if last_traj["tool_call_thought"] == "":
|
||||||
|
prompt = prompt.replace("<last_step_thought>", "Your last thought is not available, please give new plan")
|
||||||
|
else:
|
||||||
|
prompt = prompt.replace("<last_step_thought>", last_traj["tool_call_thought"])
|
||||||
|
prompt = prompt.replace("<last_step_tool_call>", json.dumps(last_traj["tool_call"]))
|
||||||
|
|
||||||
|
# summarys
|
||||||
|
for traj in reversed(trajs):
|
||||||
|
if len(traj["summarys"]) > 0:
|
||||||
|
break
|
||||||
|
summary_num = len(traj["summarys"])
|
||||||
|
|
||||||
|
if summary_num == 0:
|
||||||
|
prompt = prompt.replace("<summarys>", "There is no result.")
|
||||||
|
else:
|
||||||
|
prompt = prompt.replace("<summarys>", f"There are {summary_num} results:\n\n" + "\n\n".join(traj["summarys"]))
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_user_prompt_force_correct(query, current_survey, trajs):
|
||||||
|
if current_survey == {}:
|
||||||
|
# gen plan
|
||||||
|
now_section = "plan"
|
||||||
|
# trajs[-1]["tool_call_thought"] = "Next I will provide the plan. "
|
||||||
|
else:
|
||||||
|
now_section = ""
|
||||||
|
if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]:
|
||||||
|
now_section = "abstract"
|
||||||
|
elif "content" not in current_survey["introduction"]:
|
||||||
|
now_section = "introduction"
|
||||||
|
elif "sections" in current_survey:
|
||||||
|
for section in current_survey["sections"]:
|
||||||
|
if "content" not in section:
|
||||||
|
now_section = "section-{}".format(current_survey["sections"].index(section) + 1)
|
||||||
|
break
|
||||||
|
elif "subsections" in section:
|
||||||
|
for subsection in section["subsections"]:
|
||||||
|
if "content" not in subsection:
|
||||||
|
now_section = "section-{}/subsection-{}".format(
|
||||||
|
current_survey["sections"].index(section) + 1,
|
||||||
|
section["subsections"].index(subsection) + 1
|
||||||
|
)
|
||||||
|
break
|
||||||
|
elif "subsubsections" in subsection:
|
||||||
|
for subsubsection in subsection["subsubsections"]:
|
||||||
|
if "content" not in subsubsection:
|
||||||
|
now_section = "section-{}/subsection-{}/subsubsection-{}".format(
|
||||||
|
current_survey["sections"].index(section) + 1,
|
||||||
|
section["subsections"].index(subsection) + 1,
|
||||||
|
subsection["subsubsections"].index(subsubsection) + 1
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if now_section:
|
||||||
|
break
|
||||||
|
if now_section:
|
||||||
|
break
|
||||||
|
|
||||||
|
elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]:
|
||||||
|
now_section = "conclusion"
|
||||||
|
else:
|
||||||
|
trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."
|
||||||
|
if now_section != "":
|
||||||
|
trajs[-1]["tool_call_thought"] = f"Next I will provide {now_section}"
|
||||||
|
for traj in reversed(trajs):
|
||||||
|
if len(traj["summarys"]) > 0:
|
||||||
|
break
|
||||||
|
summary_num = len(traj["summarys"])
|
||||||
|
if now_section == "plan" and summary_num == 0:
|
||||||
|
trajs[-1]["tool_call_thought"] = "I need to get enough information."
|
||||||
|
|
||||||
|
return BufferManager._build_user_prompt(query, current_survey, trajs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_finalize(query, current_survey, trajs):
|
||||||
|
if current_survey == {}:
|
||||||
|
# gen plan
|
||||||
|
return False
|
||||||
|
# trajs[-1]["tool_call_thought"] = "Next I will provide the plan. "
|
||||||
|
else:
|
||||||
|
now_section = ""
|
||||||
|
if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]:
|
||||||
|
now_section = "abstract"
|
||||||
|
elif "content" not in current_survey["introduction"]:
|
||||||
|
now_section = "introduction"
|
||||||
|
elif "sections" in current_survey:
|
||||||
|
for section in current_survey["sections"]:
|
||||||
|
if "content" not in section:
|
||||||
|
now_section = "section-{}".format(current_survey["sections"].index(section) + 1)
|
||||||
|
break
|
||||||
|
elif "subsections" in section:
|
||||||
|
for subsection in section["subsections"]:
|
||||||
|
if "content" not in subsection:
|
||||||
|
now_section = "section-{}/subsection-{}".format(
|
||||||
|
current_survey["sections"].index(section) + 1,
|
||||||
|
section["subsections"].index(subsection) + 1
|
||||||
|
)
|
||||||
|
break
|
||||||
|
elif "subsubsections" in subsection:
|
||||||
|
for subsubsection in subsection["subsubsections"]:
|
||||||
|
if "content" not in subsubsection:
|
||||||
|
now_section = "section-{}/subsection-{}/subsubsection-{}".format(
|
||||||
|
current_survey["sections"].index(section) + 1,
|
||||||
|
section["subsections"].index(subsection) + 1,
|
||||||
|
subsection["subsubsections"].index(subsubsection) + 1
|
||||||
|
)
|
||||||
|
break
|
||||||
|
if now_section:
|
||||||
|
break
|
||||||
|
if now_section:
|
||||||
|
break
|
||||||
|
|
||||||
|
elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]:
|
||||||
|
now_section = "conclusion"
|
||||||
|
# else:
|
||||||
|
# trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."
|
||||||
|
if now_section != "":
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# rule-based method: query, plan, paragraphs -> prompt -> thought, paragraph, action
|
||||||
|
def build_prompt_for_generator(self):
|
||||||
|
total_messages = []
|
||||||
|
self.running_ids = []
|
||||||
|
for running_id, data in enumerate(self.batch_rollout_data):
|
||||||
|
if data["state"]["done"]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if len(data["trajectory"]) == 0: # first prompt
|
||||||
|
user_prompt = BufferManager._build_user_prompt_v0(data["query"],
|
||||||
|
data["state"]["current_survey"])
|
||||||
|
else:
|
||||||
|
if data["trajectory"][-1]["update_success"]:
|
||||||
|
user_prompt = BufferManager._build_user_prompt(data["query"],
|
||||||
|
data["state"]["current_survey"],
|
||||||
|
data["trajectory"])
|
||||||
|
else:
|
||||||
|
# user_prompt = data["history_messages"][-1][1]["content"]
|
||||||
|
user_prompt = BufferManager._build_user_prompt_force_correct(data["query"],
|
||||||
|
data["state"]["current_survey"],
|
||||||
|
data["trajectory"])
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": BufferManager._build_system_prompt(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": user_prompt,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
data["history_messages"].append(messages)
|
||||||
|
total_messages.append(messages)
|
||||||
|
self.running_ids.append(running_id) # update running ids
|
||||||
|
return total_messages
|
||||||
|
|
||||||
|
def update_all_scores(self, scores):
|
||||||
|
assert len(scores) == len(self.batch_rollout_data)
|
||||||
|
for score, log in zip(scores, self.batch_rollout_data):
|
||||||
|
log["state"]["score"] = score
|
||||||
|
|
||||||
|
def update_all_format_scores(self, scores):
|
||||||
|
assert len(scores) == len(self.batch_rollout_data)
|
||||||
|
for score, log in zip(scores, self.batch_rollout_data):
|
||||||
|
log["state"]["format_score"] = score
|
||||||
|
|
||||||
|
|
||||||
|
def update_trajectory(self, model_responses, env_feedbacks):
|
||||||
|
"""
|
||||||
|
model_response: original_response, thought, paragraph, tool_call, format_reward
|
||||||
|
env_feedback: done, search_keywards, abstracts, outcome_reward
|
||||||
|
"""
|
||||||
|
assert len(self.running_ids) == len(model_responses)
|
||||||
|
assert len(self.running_ids) == len(env_feedbacks)
|
||||||
|
|
||||||
|
for running_id, response, feedback in zip(self.running_ids, model_responses, env_feedbacks):
|
||||||
|
# update state
|
||||||
|
self.batch_rollout_data[running_id]["state"]["done"] = feedback["done"] # if True, finalize the task
|
||||||
|
|
||||||
|
update_success = False
|
||||||
|
if response["true"]:
|
||||||
|
if self.batch_rollout_data[running_id]["state"]["current_survey"] != {}:
|
||||||
|
if len(response["answer"]) != 0: # no empty dict or start
|
||||||
|
update_success = SurveyManager.update_current_survey(
|
||||||
|
self.batch_rollout_data[running_id]["state"]["current_survey"],
|
||||||
|
response["answer"])
|
||||||
|
else:
|
||||||
|
# Search Then Write
|
||||||
|
if len(response["answer"]) != 0 and "There is no result" not in self.batch_rollout_data[running_id]["history_messages"][-1][1]["content"]:
|
||||||
|
update_success = SurveyManager.update_current_survey(
|
||||||
|
self.batch_rollout_data[running_id]["state"]["current_survey"],
|
||||||
|
response["answer"])
|
||||||
|
elif "There is no result" in self.batch_rollout_data[running_id]["history_messages"][-1][1]["content"] and len(response["answer"]) == 0:
|
||||||
|
update_success = True
|
||||||
|
|
||||||
|
|
||||||
|
self.batch_rollout_data[running_id]["trajectory"].append({
|
||||||
|
"step": self.step,
|
||||||
|
"original_response": response["original_response"],
|
||||||
|
"answer_thought": response["answer_thought"],
|
||||||
|
"answer": response["answer"],
|
||||||
|
"tool_call_thought": response["tool_call_thought"],
|
||||||
|
"tool_call": response["tool_call"],
|
||||||
|
"search_keywords": feedback["search_keywords"],
|
||||||
|
"summarys": feedback["summarys"],
|
||||||
|
"update_success": update_success and response["true"],
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
self.batch_rollout_data[running_id]["history_messages"][-1].append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": response["original_response"],
|
||||||
|
})
|
||||||
|
|
||||||
|
if self.batch_rollout_data[running_id]["state"]["done"]:
|
||||||
|
real_done = BufferManager._check_finalize(self.batch_rollout_data[running_id]["query"],
|
||||||
|
self.batch_rollout_data[running_id]["state"]["current_survey"],
|
||||||
|
self.batch_rollout_data[running_id]["trajectory"])
|
||||||
|
if not real_done:
|
||||||
|
self.batch_rollout_data[running_id]["state"]["done"] = False
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def match_reference(text:str):
|
||||||
|
reg = r"\\\w*cite(?!style)\w*\{(.+?)\}"
|
||||||
|
placeholder_reg = re.compile(r"^#\d+$")
|
||||||
|
reg_bibkeys = re.findall(reg, text)
|
||||||
|
bibkeys = set()
|
||||||
|
for bibkey in reg_bibkeys:
|
||||||
|
single_bib = bibkey.split(",")
|
||||||
|
for bib in single_bib:
|
||||||
|
if not placeholder_reg.match(bib):
|
||||||
|
bib = bib.strip()
|
||||||
|
if bib and bib != "*":
|
||||||
|
bibkeys.add(bib)
|
||||||
|
|
||||||
|
reg = r"\\nocite{(.+?)\}"
|
||||||
|
reg_bibkeys = re.findall(reg, text)
|
||||||
|
for bibkey in reg_bibkeys:
|
||||||
|
single_bib = bibkey.split(",")
|
||||||
|
for bib in single_bib:
|
||||||
|
if not placeholder_reg.match(bib):
|
||||||
|
bib = bib.strip()
|
||||||
|
if bib and bib != "*":
|
||||||
|
bibkeys.remove(bib)
|
||||||
|
|
||||||
|
ref_key_list = list(bibkeys)
|
||||||
|
return ref_key_list
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_generator_response(response):
|
||||||
|
"""
|
||||||
|
1. 解析失败: step + 1, 重新生成, 给出提示
|
||||||
|
2. 解析成功:
|
||||||
|
2.1 tool_call == search(keywords) 发送post请求
|
||||||
|
2.2 tool_call == done 结束任务
|
||||||
|
|
||||||
|
**standard format**
|
||||||
|
|
||||||
|
Current Update:
|
||||||
|
<think> [Your Thoughts]: str </think>
|
||||||
|
<answer> {"update": str, "content": str}: dict </answer>
|
||||||
|
|
||||||
|
Next Plan:
|
||||||
|
<think> [Your Thoughts]: str </think>
|
||||||
|
<tool_call> {"tool": "search", "arguments": {}}: dict</tool_call>
|
||||||
|
"""
|
||||||
|
extracted_result = {
|
||||||
|
"original_response": response
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
current_update = response.split("Current Update:")[-1].split("Next Plan:")[0]
|
||||||
|
except:
|
||||||
|
current_update = response
|
||||||
|
|
||||||
|
# pattern
|
||||||
|
think_pattern = r"<think>(.*?)</think>"
|
||||||
|
answer_pattern = r"<answer>(.*?)</answer>"
|
||||||
|
tool_pattern = r"<tool_call>(.*?)</tool_call>"
|
||||||
|
|
||||||
|
# extract information from current_update
|
||||||
|
|
||||||
|
think_match = re.search(think_pattern, current_update, re.DOTALL) # 多行提取
|
||||||
|
if think_match:
|
||||||
|
think = think_match.group(1)
|
||||||
|
think = think.strip()
|
||||||
|
else:
|
||||||
|
think = ""
|
||||||
|
extracted_result["answer_thought"] = think
|
||||||
|
|
||||||
|
answer_match = re.search(answer_pattern, current_update, re.DOTALL) # 多行提取
|
||||||
|
has_answer = False
|
||||||
|
if answer_match:
|
||||||
|
answer = answer_match.group(1)
|
||||||
|
answer = answer.strip()
|
||||||
|
try:
|
||||||
|
answer = json.loads(answer)
|
||||||
|
if not answer == {}:
|
||||||
|
assert isinstance(answer["update"], str)
|
||||||
|
answer["update"] = SurveyManager.parse_update_pos(answer["update"])
|
||||||
|
if answer["update"] == "plan":
|
||||||
|
|
||||||
|
assert isinstance(answer["content"], dict)
|
||||||
|
plan = answer["content"]
|
||||||
|
assert isinstance(plan, dict)
|
||||||
|
plan.pop("instruction",None)
|
||||||
|
keys = ["abstract", "introduction", "conclusion","sections","title"]
|
||||||
|
for key in keys:
|
||||||
|
assert key in plan
|
||||||
|
for key in plan:
|
||||||
|
assert key in keys
|
||||||
|
if key == "sections":
|
||||||
|
assert isinstance(plan[key], list)
|
||||||
|
for section in plan[key]:
|
||||||
|
assert isinstance(section, dict)
|
||||||
|
assert "plan" in section
|
||||||
|
assert "title" in section
|
||||||
|
assert isinstance(section["plan"], str)
|
||||||
|
assert isinstance(section["title"], str)
|
||||||
|
assert section["title"] != "Methodology" # 不能是Methodology,WIP
|
||||||
|
if "subsections" in section:
|
||||||
|
assert isinstance(section["subsections"], list)
|
||||||
|
for subsection in section["subsections"]:
|
||||||
|
assert isinstance(subsection, dict)
|
||||||
|
assert "plan" in subsection
|
||||||
|
assert "title" in subsection
|
||||||
|
assert isinstance(subsection["plan"], str)
|
||||||
|
assert isinstance(subsection["title"], str)
|
||||||
|
if "subsubsections" in section:
|
||||||
|
assert isinstance(subsection["subsubsections"], list)
|
||||||
|
for subsubsection in subsection["subsubsections"]:
|
||||||
|
assert isinstance(subsubsection, dict)
|
||||||
|
assert "plan" in subsubsection
|
||||||
|
assert "title" in subsubsection
|
||||||
|
assert isinstance(subsubsection["plan"], str)
|
||||||
|
assert isinstance(subsubsection["title"], str)
|
||||||
|
elif key == "title":
|
||||||
|
assert isinstance(plan[key], str)
|
||||||
|
else:
|
||||||
|
assert isinstance(plan[key], dict)
|
||||||
|
assert "plan" in plan[key]
|
||||||
|
if key not in ["abstract", "conclusion", "introduction"]:
|
||||||
|
assert "title" in plan[key]
|
||||||
|
else:
|
||||||
|
assert isinstance(answer["content"], str)
|
||||||
|
has_answer = True
|
||||||
|
except:
|
||||||
|
answer = {}
|
||||||
|
else:
|
||||||
|
answer = {}
|
||||||
|
extracted_result["answer"] = answer
|
||||||
|
|
||||||
|
# extract information from next_plan
|
||||||
|
|
||||||
|
try:
|
||||||
|
next_plan = response.split("Next Plan:")[1]
|
||||||
|
except:
|
||||||
|
try:
|
||||||
|
next_plan = response.split("</answer>")[1]
|
||||||
|
except:
|
||||||
|
next_plan = response
|
||||||
|
|
||||||
|
think_match = re.search(think_pattern, next_plan, re.DOTALL) # 多行提取
|
||||||
|
if think_match:
|
||||||
|
think = think_match.group(1)
|
||||||
|
think = think.strip()
|
||||||
|
else:
|
||||||
|
think = ""
|
||||||
|
extracted_result["tool_call_thought"] = think
|
||||||
|
|
||||||
|
tool_match = re.search(tool_pattern, next_plan, re.DOTALL) # 多行提取
|
||||||
|
has_tool_call = False
|
||||||
|
if tool_match:
|
||||||
|
tool_text = tool_match.group(1)
|
||||||
|
tool_text = tool_text.strip()
|
||||||
|
try:
|
||||||
|
tool_call = json.loads(tool_text)
|
||||||
|
assert tool_call["name"] in ["search_engine", "finalize"]
|
||||||
|
if tool_call["name"] == "search_engine":
|
||||||
|
assert isinstance(tool_call["arguments"]["query"], list)
|
||||||
|
has_tool_call = True
|
||||||
|
except:
|
||||||
|
tool_call = {}
|
||||||
|
else:
|
||||||
|
|
||||||
|
tool_call = {}
|
||||||
|
|
||||||
|
extracted_result["tool_call"] = tool_call
|
||||||
|
|
||||||
|
extracted_result["true"] = has_answer and has_tool_call
|
||||||
|
reg = r"[\u4e00-\u9fa5]"
|
||||||
|
has_chinese = re.search(reg, response) is not None
|
||||||
|
extracted_result["true"] = extracted_result["true"] and not has_chinese
|
||||||
|
|
||||||
|
return extracted_result
|
||||||
|
|
||||||
|
|
||||||
|
class BufferManager_V2(BufferManager):
|
||||||
|
|
||||||
|
def __init__(self, querys, repeat_n=1):
|
||||||
|
# self.config = config
|
||||||
|
self.step = 0
|
||||||
|
self.batch_rollout_data = []
|
||||||
|
self.running_ids = [] # active envs
|
||||||
|
|
||||||
|
for uid, query in enumerate(querys):
|
||||||
|
print("CURRENT QUERY: ", query)
|
||||||
|
for _ in range(repeat_n):
|
||||||
|
self.batch_rollout_data.append({
|
||||||
|
"query": query,
|
||||||
|
"uid": f"query_{uid}",
|
||||||
|
"state": {
|
||||||
|
# "score": 0.0, # only for debug
|
||||||
|
# "format_score": None, # will update at last step
|
||||||
|
"done": False,
|
||||||
|
"current_survey": {}
|
||||||
|
},
|
||||||
|
"trajectory": [],
|
||||||
|
"history_messages": []
|
||||||
|
})
|
||||||
|
|
||||||
81
code/src/generation/prompts.py
Normal file
81
code/src/generation/prompts.py
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
|
||||||
|
|
||||||
|
SYSTEM_PROMPT_0415_BUFFER = """You are a survey writer. You are asked to write a survey follow the instruction, refered as "Query" or "User's Query". You will finish the survey by multi-step updating.
|
||||||
|
|
||||||
|
Usually, you need to do two things:
|
||||||
|
(1) First, you need to update the survey using the retrieved information according to the current plan, refered as "Current Update". You MUST think inside <think>...</think> before you give your <answer>...</answer> action, mainly about "How to write paragraphs with citations based on retrieved information to complete the current plan?". If the current plan is None, or you think the current plan is not good, or you think the retrieved information is not enough for you to finish the plan, you can jump the "Answer" action by giving "{}" as answer. Please give the citation in \\cite{}.
|
||||||
|
|
||||||
|
(2) Then, you need decide what part of the survey needs to be updated, refered as "Next Plan". You MUST think inside <think>...</think> before you give your <tool_call>...</tool_call> action. If you think the current retrieved information is enough to finish your next plan, you can jump the "Tool Call" action by giving "{}" as tool call.
|
||||||
|
|
||||||
|
## Answer
|
||||||
|
You can give one answer to update the survey.
|
||||||
|
<answer>
|
||||||
|
{"update": <section-pos>, "content": paragraph }
|
||||||
|
</answer>
|
||||||
|
|
||||||
|
There are two parameters in <answer> action.
|
||||||
|
* update: string, which position you want to update, such as "title", "abstract", "introduction", "section-1", "section-1/subsction-1", "section-1/subsction-1/subsection-1", and "conclusion".
|
||||||
|
* content: string, the update content for the position of the survey, please give the faithful citation in \\cite{}. . Or dict, only when you give the plan of the paper, the values including the section title and a simple plan of it.
|
||||||
|
|
||||||
|
## Tool Call
|
||||||
|
You can call one function to assist the survey writing.
|
||||||
|
|
||||||
|
You are provided with function signatures within <tools></tools> XML tags:
|
||||||
|
<tools>
|
||||||
|
{"type": "function", "function": {"name": "search_engine", "description": "Search reasearch papers.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "description": "The words to search for in quotes."}, "required": ["query"]}}}}
|
||||||
|
{"type": "function", "function": {"name": "finalize", "description": "Finalize the survey.", "parameters": {}, "required": []}}
|
||||||
|
</tools>
|
||||||
|
|
||||||
|
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||||
|
<tool_call> {"name": <function-name>, "arguments": <args-json-object>} </tool_call>
|
||||||
|
|
||||||
|
For example, You can call the search engine by using:
|
||||||
|
<tool_call> {"name": "search_engine", "arguments": {"query": ["keyword-1", "keyword-2", ...]} </tool_call>
|
||||||
|
If you think the survey is finished, please call:
|
||||||
|
<tool_call> {"name": "finalize", "arguments": {} </tool_call>
|
||||||
|
|
||||||
|
** Attention **
|
||||||
|
You must use correct JSON format inside <answer>...</answer> and <tool_call>...</tool_call>, otherwise we can't extract the corrent content.
|
||||||
|
|
||||||
|
**Output format**
|
||||||
|
(1) Current Update:
|
||||||
|
<think> How to write paragraphs with citations based on retrieved information to complete the plan? </think>
|
||||||
|
<answer> Please provide your answer here. (JSON format) </answer>
|
||||||
|
(2) Next Plan:
|
||||||
|
<think> Which part of the survey needs to be updated? What information needs to be queried? </think>
|
||||||
|
<tool_call> Please call a tool here. (JSON format) </tool_call>
|
||||||
|
"""
|
||||||
|
|
||||||
|
USER_PROMPT_v0_0424_BUFFER = """Please update the survey depending on the insturctions.
|
||||||
|
**User's Query**
|
||||||
|
<user_query>
|
||||||
|
|
||||||
|
**Current Survey**
|
||||||
|
<init_survey>
|
||||||
|
|
||||||
|
**Current Plan**
|
||||||
|
<think>I need to get enough information.</think>
|
||||||
|
<tool_call>{}</tool_call>
|
||||||
|
|
||||||
|
**Retrieved Information**
|
||||||
|
There is no results.
|
||||||
|
|
||||||
|
Please give your response following the output format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
USER_PROMPT_0415_BUFFER = """Please update the survey depending on the insturctions.
|
||||||
|
**User's Query**
|
||||||
|
<user_query>
|
||||||
|
|
||||||
|
**Current Survey**
|
||||||
|
<current_survey>
|
||||||
|
|
||||||
|
**Current Plan**
|
||||||
|
<think> <last_step_thought> </think>
|
||||||
|
<tool_call> <last_step_tool_call> </tool_call>
|
||||||
|
|
||||||
|
**Retrieved Information**
|
||||||
|
<summarys>
|
||||||
|
|
||||||
|
Please give your response following the output format.
|
||||||
|
"""
|
||||||
338
code/src/generation/run.py
Normal file
338
code/src/generation/run.py
Normal file
@@ -0,0 +1,338 @@
|
|||||||
|
from contextlib import contextmanager
|
||||||
|
from codetiming import Timer
|
||||||
|
@contextmanager
|
||||||
|
def _timer(name: str, timing_raw):
|
||||||
|
with Timer(name=name, logger=None) as timer:
|
||||||
|
yield
|
||||||
|
timing_raw[name] = timer.last
|
||||||
|
|
||||||
|
from buffer import SurveyManager
|
||||||
|
from buffer import BufferManager_V2 as BufferManager
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
import re
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import json
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# 允许跨域(如果前端和后端端口不同需要加上)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
active_connections = set()
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
await websocket.accept()
|
||||||
|
active_connections.add(websocket)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await websocket.receive_text()
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
active_connections.remove(websocket)
|
||||||
|
|
||||||
|
async def post_to_frontend(payload):
|
||||||
|
print(f"Sending payload to frontend: {payload}") # Log the payload being sent
|
||||||
|
for ws in list(active_connections):
|
||||||
|
try:
|
||||||
|
await ws.send_text(payload)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error sending to WebSocket: {e}")
|
||||||
|
active_connections.remove(ws)
|
||||||
|
|
||||||
|
|
||||||
|
def write_to_json(data, path):
|
||||||
|
with open(path, 'w', encoding='utf8') as f:
|
||||||
|
f.write(json.dumps(data, ensure_ascii=False, indent=4))
|
||||||
|
|
||||||
|
class OriginalvLLMRollout:
|
||||||
|
def __init__(self, model_name_or_path):
|
||||||
|
# init vLLM
|
||||||
|
self.rollout_model = LLM(
|
||||||
|
model=model_name_or_path,
|
||||||
|
tokenizer=model_name_or_path,
|
||||||
|
gpu_memory_utilization=0.95,
|
||||||
|
trust_remote_code=True,
|
||||||
|
)
|
||||||
|
self.sampling_params = SamplingParams(
|
||||||
|
temperature=0.7,
|
||||||
|
top_p=0.8,
|
||||||
|
repetition_penalty=1.05,
|
||||||
|
top_k=20,
|
||||||
|
max_tokens=2748,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate(self, input_texts):
|
||||||
|
generated_texts = []
|
||||||
|
completions = self.rollout_model.generate(input_texts, self.sampling_params, use_tqdm=False)
|
||||||
|
for output in completions:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
return generated_texts
|
||||||
|
|
||||||
|
def chat(self, input_messages):
|
||||||
|
generated_texts = []
|
||||||
|
completions = self.rollout_model.chat(input_messages, self.sampling_params, use_tqdm=False)
|
||||||
|
for output in completions:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
generated_texts.append(generated_text)
|
||||||
|
return generated_texts
|
||||||
|
|
||||||
|
async def rollout_with_env(querys, batch_size, max_turns, model_path, url,
|
||||||
|
deploy_port=None):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
querys: [string]
|
||||||
|
"""
|
||||||
|
###############################
|
||||||
|
#### splited by batch size ####
|
||||||
|
###############################
|
||||||
|
n = len(querys) // batch_size
|
||||||
|
batch_querys = []
|
||||||
|
for i in range(n+1):
|
||||||
|
temp_data = querys[i*batch_size: (i+1)*batch_size]
|
||||||
|
if len(temp_data) > 0:
|
||||||
|
batch_querys.append(temp_data)
|
||||||
|
print("QUERY NUMBER with BATCH: ", [len(x) for x in batch_querys])
|
||||||
|
|
||||||
|
###################
|
||||||
|
#### init vllm ####
|
||||||
|
###################
|
||||||
|
vllm_manager = OriginalvLLMRollout(model_path)
|
||||||
|
|
||||||
|
############################
|
||||||
|
#### init Format Reward ####
|
||||||
|
############################
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||||
|
|
||||||
|
total_rollout_data = []
|
||||||
|
for querys in batch_querys:
|
||||||
|
###########################################
|
||||||
|
#### acquire env configs and init envs ####
|
||||||
|
###########################################
|
||||||
|
buffer_manager = BufferManager(querys)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# Break at max-turns
|
||||||
|
if buffer_manager.step >= max_turns:
|
||||||
|
break
|
||||||
|
|
||||||
|
###############################
|
||||||
|
#### prepare input prompts ####
|
||||||
|
###############################
|
||||||
|
messagess_todo = buffer_manager.build_prompt_for_generator()
|
||||||
|
# breakpoint()
|
||||||
|
|
||||||
|
# Break when no tasks
|
||||||
|
if len(messagess_todo) == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
##########################
|
||||||
|
#### generate by vLLM ####
|
||||||
|
##########################
|
||||||
|
timing_raw = {}
|
||||||
|
with _timer('vllm sampling', timing_raw):
|
||||||
|
# response_texts = vllm_manager.chat(messagess_todo)
|
||||||
|
response_texts = await asyncio.to_thread(vllm_manager.chat, messagess_todo)
|
||||||
|
|
||||||
|
##################################
|
||||||
|
#### preprocess the responses ####
|
||||||
|
##################################
|
||||||
|
# 对response的详细处理可以集成到环境类中,因环境而异, 先对Response进行预处理
|
||||||
|
extracted_results = []
|
||||||
|
for response_text in response_texts:
|
||||||
|
result = BufferManager.parse_generator_response(response_text)
|
||||||
|
extracted_results.append(result)
|
||||||
|
|
||||||
|
#################################################
|
||||||
|
#### execute in environment and get feedback ####
|
||||||
|
#################################################
|
||||||
|
payload = {
|
||||||
|
"tool_calls": [x["tool_call"] for x in extracted_results]
|
||||||
|
}
|
||||||
|
if buffer_manager.step <=2:
|
||||||
|
payload["topk"] = 20
|
||||||
|
with _timer('get env feedback', timing_raw):
|
||||||
|
# env_response_batched = requests.post(url, json=payload).json()
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(url, json=payload) as resp:
|
||||||
|
env_response_batched = await resp.json()
|
||||||
|
|
||||||
|
###################################
|
||||||
|
#### postprocess the feedbacks ####
|
||||||
|
###################################
|
||||||
|
with _timer('postprocessing', timing_raw):
|
||||||
|
buffer_manager.update_trajectory(extracted_results, env_response_batched)
|
||||||
|
buffer_manager.step += 1
|
||||||
|
|
||||||
|
print(timing_raw)
|
||||||
|
|
||||||
|
if deploy_port is not None:
|
||||||
|
now_text = json_to_markdown(buffer_manager.batch_rollout_data[-1])
|
||||||
|
now_search_keywords= buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["search_keywords"]
|
||||||
|
now_update = buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["answer_thought"]
|
||||||
|
next_update = buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["tool_call_thought"]
|
||||||
|
now_query = buffer_manager.batch_rollout_data[-1]["query"]
|
||||||
|
trajs = buffer_manager.batch_rollout_data[-1]["trajectory"]
|
||||||
|
updated_success = buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["update_success"]
|
||||||
|
if updated_success:
|
||||||
|
for traj in reversed(trajs):
|
||||||
|
if len(traj["summarys"]) > 0:
|
||||||
|
break
|
||||||
|
summary_num = len(traj["summarys"])
|
||||||
|
if summary_num == 0:
|
||||||
|
summary_text = "No summaries yet."
|
||||||
|
else:
|
||||||
|
summary_text = "\n".join(traj["summarys"])
|
||||||
|
frontend_payload = {
|
||||||
|
"markdown": now_text,
|
||||||
|
"searchKeywords": now_search_keywords,
|
||||||
|
"nowUpdate": now_update,
|
||||||
|
"nextUpdate": next_update,
|
||||||
|
"query": now_query,
|
||||||
|
"papers": summary_text
|
||||||
|
}
|
||||||
|
frontend_payload = json.dumps(frontend_payload, ensure_ascii=False)
|
||||||
|
try:
|
||||||
|
await post_to_frontend(frontend_payload)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error posting to frontend: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for item in buffer_manager.batch_rollout_data:
|
||||||
|
item["survey_text"] = SurveyManager.convert_survey_dict_to_str(item["state"]["current_survey"])
|
||||||
|
|
||||||
|
total_rollout_data.extend(buffer_manager.batch_rollout_data)
|
||||||
|
#####################################
|
||||||
|
#### clear all envs and shutdown ####
|
||||||
|
#####################################
|
||||||
|
del buffer_manager
|
||||||
|
|
||||||
|
return total_rollout_data
|
||||||
|
|
||||||
|
|
||||||
|
def json_to_markdown(json_data):
|
||||||
|
text = SurveyManager.convert_survey_dict_to_str(json_data["state"]["current_survey"])
|
||||||
|
all_summarys = {}
|
||||||
|
for traj in json_data["trajectory"]:
|
||||||
|
for item in traj["summarys"]:
|
||||||
|
split_text = item.split("\n")
|
||||||
|
bibkey = split_text[0].split(":")[1].strip()
|
||||||
|
title_begin_index = item.find("Title:") + len("Title:")
|
||||||
|
title_end_index = item.find("Abstract:")
|
||||||
|
title = item[title_begin_index:title_end_index].strip()
|
||||||
|
arxivid = bibkey.split("arxivid")[-1].strip()
|
||||||
|
html = f"arxiv.org/abs/{arxivid}"
|
||||||
|
all_summarys[bibkey] = f"[{title}](https://{html})"
|
||||||
|
|
||||||
|
reg = r"\\cite\{(.+?)\}"
|
||||||
|
placeholder_reg = re.compile(r"^#\d+$")
|
||||||
|
reg_bibkeys = re.findall(reg, text)
|
||||||
|
bibkeys = []
|
||||||
|
for bibkey in reg_bibkeys:
|
||||||
|
single_bib = bibkey.split(",")
|
||||||
|
for bib in single_bib:
|
||||||
|
if not placeholder_reg.match(bib):
|
||||||
|
bib = bib.strip()
|
||||||
|
if bib and bib != "*" and bib not in bibkeys:
|
||||||
|
bibkeys.append(bib)
|
||||||
|
|
||||||
|
bibkeys_index = {bibkey: i+1 for i, bibkey in enumerate(bibkeys)}
|
||||||
|
|
||||||
|
def replace_bibkey(bibkey):
|
||||||
|
bibkey = bibkey.group(1)
|
||||||
|
single_bib = bibkey.split(",")
|
||||||
|
new_bibs = []
|
||||||
|
for bib in single_bib:
|
||||||
|
if not placeholder_reg.match(bib):
|
||||||
|
bib = bib.strip()
|
||||||
|
if bib and bib != "*":
|
||||||
|
if bib in bibkeys_index:
|
||||||
|
new_bibs.append(f"{bibkeys_index[bib]}")
|
||||||
|
else:
|
||||||
|
print(f"Warning: {bib} not found in bibkeys")
|
||||||
|
if len(new_bibs) > 0:
|
||||||
|
return "[" + ",".join(new_bibs) + "]"
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
text = re.sub(reg, replace_bibkey, text)
|
||||||
|
reference_text = "\n\n".join([f"[{i}] {all_summarys[bibkey]}" for bibkey, i in bibkeys_index.items()])
|
||||||
|
text += "\n## References\n" + reference_text
|
||||||
|
return text
|
||||||
|
|
||||||
|
async def test_surveyGen(model_path, out_path,querys, url, deploy_port=None):
|
||||||
|
|
||||||
|
total_rollout_data = await rollout_with_env(querys, 1, 1000, model_path, url, deploy_port)
|
||||||
|
all_md_texts = []
|
||||||
|
for json_data in total_rollout_data:
|
||||||
|
md_text = json_to_markdown(json_data)
|
||||||
|
all_md_texts.append(md_text)
|
||||||
|
|
||||||
|
all_md_texts = "\n\n".join(all_md_texts)
|
||||||
|
with open(out_path, 'w', encoding='utf8') as f:
|
||||||
|
f.write(all_md_texts)
|
||||||
|
|
||||||
|
# with jsonlines.open(out_path, 'w') as writer:
|
||||||
|
# for item in total_rollout_data:
|
||||||
|
# writer.write(item)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class QueryRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
|
||||||
|
@app.post("/generate_survey")
|
||||||
|
async def generate_survey(request: QueryRequest):
|
||||||
|
global args # Ensure args is accessible
|
||||||
|
# 这里可以根据需要处理查询
|
||||||
|
model_path = args.model_path
|
||||||
|
out_path = args.output_file
|
||||||
|
query = request.query
|
||||||
|
querys = [query] # 将查询转换为列表
|
||||||
|
url = args.retriver_url
|
||||||
|
deploy_port = args.port if args.port is not None else None
|
||||||
|
try:
|
||||||
|
await test_surveyGen(model_path, out_path, querys, url, deploy_port)
|
||||||
|
return {"status": "success", "message": "Survey generated successfully."}
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error generating survey: {e}")
|
||||||
|
return {"status": "error", "message": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run survey generation with vLLM.")
|
||||||
|
parser.add_argument("--model_path", type=str, required=True, help="Path to the model.")
|
||||||
|
parser.add_argument("--query", type=str, required=True, help="Query to generate survey.")
|
||||||
|
parser.add_argument("--output_file", type=str, required=True, help="Path to the output Markdown file.")
|
||||||
|
parser.add_argument("--retriver_url", type=str, default="http://localhost:8400", help="URL of the retriever service.")
|
||||||
|
parser.add_argument("--port", type=str, default=None, help="Deploy port, default is None, which means not deploy.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.port is not None:
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="localhost", port=int(args.port))# log_level="debug")
|
||||||
|
|
||||||
|
# Run the survey generation
|
||||||
|
else:
|
||||||
|
asyncio.run(
|
||||||
|
test_surveyGen(
|
||||||
|
model_path=args.model_path,
|
||||||
|
out_path=args.output_file,
|
||||||
|
querys=[args.query],
|
||||||
|
url=args.retriver_url
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
55
code/src/preprocess/build_index.py
Normal file
55
code/src/preprocess/build_index.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import torch.distributed
|
||||||
|
import faiss
|
||||||
|
import pandas as pd
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
import jsonlines, json
|
||||||
|
from transformers import AutoModel
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
'''
|
||||||
|
data format:
|
||||||
|
{
|
||||||
|
"bibkey": "some_bibkey",
|
||||||
|
"text": "The abstract or text of the paper."
|
||||||
|
}
|
||||||
|
example:
|
||||||
|
{
|
||||||
|
"bibkey": "arxivid1234.5678",
|
||||||
|
"text": "Title: A Study on Something\nAbstract: This paper discusses the findings of a study on something important in the field of research.\nAuthors: John Doe"
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
|
||||||
|
model_name = "openbmb/MiniCPM-Embedding-Light"
|
||||||
|
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
|
||||||
|
|
||||||
|
input_path = "./data/arxiv.jsonl"
|
||||||
|
|
||||||
|
with jsonlines.open(input_path) as f:
|
||||||
|
survey_data = list(f)
|
||||||
|
|
||||||
|
|
||||||
|
xids = [item["bibkey"] for item in survey_data]
|
||||||
|
passages = [item["text"] for item in survey_data]
|
||||||
|
|
||||||
|
embeddings_doc_dense, _ = model.encode_corpus(passages, max_length=1024)
|
||||||
|
|
||||||
|
|
||||||
|
# faiss save index
|
||||||
|
index = faiss.IndexFlatIP(embeddings_doc_dense.shape[1])
|
||||||
|
id_map_index = faiss.IndexIDMap(index)
|
||||||
|
index = faiss.index_cpu_to_all_gpus(id_map_index)
|
||||||
|
|
||||||
|
x_ids_int = np.array(np.arange(len(xids)))
|
||||||
|
|
||||||
|
str_int_ids = {}
|
||||||
|
for i in range(len(xids)):
|
||||||
|
str_int_ids[xids[i]] = x_ids_int[i]
|
||||||
|
str_int_ids_df = pd.DataFrame(str_int_ids, index=[0]).T.reset_index()
|
||||||
|
str_int_ids_df.columns = ["str_id", "int_id"]
|
||||||
|
str_int_ids_df.to_csv("./index/str_int_ids_abstract.csv", index=False)
|
||||||
|
|
||||||
|
index.add_with_ids(embeddings_doc_dense, x_ids_int)
|
||||||
|
|
||||||
|
index = faiss.index_gpu_to_cpu(index)
|
||||||
|
faiss.write_index(index, "./index/index_abstract.faiss")
|
||||||
21
code/src/preprocess/data_process.py
Normal file
21
code/src/preprocess/data_process.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# curl -L -o ~/Downloads/arxiv.zip\
|
||||||
|
# https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
|
||||||
|
|
||||||
|
|
||||||
|
import jsonlines
|
||||||
|
|
||||||
|
input_path = './data/arxiv-metadata-oai-snapshot.json'
|
||||||
|
output_path = './data/arxiv.jsonl'
|
||||||
|
|
||||||
|
new_data = []
|
||||||
|
with jsonlines.open(input_path, 'r') as reader:
|
||||||
|
for item in reader:
|
||||||
|
new_item = {
|
||||||
|
'bibkey': f"arxivid{item['id']}",
|
||||||
|
'text': f"Title: {item['title']}\nAbstract: {item['abstract']}\nAuthors: {item['authors']}",
|
||||||
|
}
|
||||||
|
new_data.append(new_item)
|
||||||
|
|
||||||
|
with jsonlines.open(output_path, 'w') as writer:
|
||||||
|
for item in new_data:
|
||||||
|
writer.write(item)
|
||||||
175
code/src/retriever/retriever.py
Normal file
175
code/src/retriever/retriever.py
Normal file
@@ -0,0 +1,175 @@
|
|||||||
|
import faiss
|
||||||
|
from fastapi import FastAPI
|
||||||
|
import torch
|
||||||
|
import pandas as pd
|
||||||
|
from collections import defaultdict
|
||||||
|
import pandas as pd
|
||||||
|
import jsonlines
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
import uvicorn
|
||||||
|
import asyncio
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Optional
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
model_name = "openbmb/MiniCPM-Embedding-Light"
|
||||||
|
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
|
||||||
|
co = faiss.GpuMultipleClonerOptions()
|
||||||
|
co.shard = True
|
||||||
|
co.useFloat16 = True
|
||||||
|
|
||||||
|
faiss_index_path = "./index/index_abstract.faiss" # Replace with your FAISS index path"
|
||||||
|
faiss_index = faiss.read_index(faiss_index_path)
|
||||||
|
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index,co=co)
|
||||||
|
|
||||||
|
corpus_path = "./data/arxiv.jsonl"
|
||||||
|
with jsonlines.open(corpus_path) as f:
|
||||||
|
paper_data = list(f)
|
||||||
|
paper_dict = {}
|
||||||
|
item_key = "text"
|
||||||
|
|
||||||
|
|
||||||
|
index_path = "./index/str_int_ids_abstract.csv"
|
||||||
|
index_df = pd.read_csv(index_path,converters={0: lambda x: str(x),1: lambda x: int(x)})
|
||||||
|
index_df.columns = ["str_id", "int_id"]
|
||||||
|
index_dict = index_df.set_index("int_id")["str_id"].to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
for item in paper_data:
|
||||||
|
paper_dict[item["bibkey"]] = item[item_key]
|
||||||
|
|
||||||
|
class QueryRequest(BaseModel):
|
||||||
|
queries: List[str]
|
||||||
|
topk: Optional[int] = None
|
||||||
|
return_scores: bool = False
|
||||||
|
|
||||||
|
class MessageRequest(BaseModel):
|
||||||
|
tool_calls: List
|
||||||
|
topk: Optional[int] = 10
|
||||||
|
|
||||||
|
@app.post("/")
|
||||||
|
async def search_text_batch(request:MessageRequest):
|
||||||
|
tool_calls = request.tool_calls
|
||||||
|
topk = request.topk
|
||||||
|
results = []
|
||||||
|
finalize_indices = []
|
||||||
|
search_engine_indices = []
|
||||||
|
for i in range(len(tool_calls)):
|
||||||
|
try:
|
||||||
|
tool_calls[i]["name"]
|
||||||
|
except KeyError:
|
||||||
|
finalize_indices.append(i)
|
||||||
|
continue
|
||||||
|
if tool_calls[i]["name"] == "search_engine":
|
||||||
|
search_engine_indices.append(i)
|
||||||
|
elif tool_calls[i]["name"] == "finalize":
|
||||||
|
finalize_indices.append(i)
|
||||||
|
else:
|
||||||
|
finalize_indices.append(i)
|
||||||
|
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
for i in range(len(tool_calls)):
|
||||||
|
if i in search_engine_indices:
|
||||||
|
tasks.append(call_search_engine(tool_calls[i], topk))
|
||||||
|
search_task_results = await asyncio.gather(*tasks)
|
||||||
|
num_search = 0
|
||||||
|
num_finalize = 0
|
||||||
|
for i in range(len(tool_calls)):
|
||||||
|
if i in finalize_indices:
|
||||||
|
search_keywords, bibkeys,abstracts, done, score = "",[], [], True, 0.0
|
||||||
|
num_finalize += 1
|
||||||
|
elif i in search_engine_indices:
|
||||||
|
search_keywords, bibkeys, abstracts, done, score = search_task_results[num_search]
|
||||||
|
num_search += 1
|
||||||
|
|
||||||
|
titles = []
|
||||||
|
for abstract in abstracts:
|
||||||
|
try:
|
||||||
|
title = abstract.split("\n")[1]
|
||||||
|
title = title.split(":")[1].strip()
|
||||||
|
titles.append(title)
|
||||||
|
except:
|
||||||
|
titles.append("")
|
||||||
|
results.append({ "search_keywords":search_keywords, "summarys":abstracts, "done":done, "score":score, "titles":titles, "bibkeys":bibkeys})
|
||||||
|
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def extract_tool_call(text: str):
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
tool_text = match.group(1)
|
||||||
|
try:
|
||||||
|
tool_call = json.loads(tool_text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return tool_call if isinstance(tool_call, dict) else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_response(queries,ref):
|
||||||
|
text_raw = paper_dict[str(ref)]
|
||||||
|
text_raw = tokenizer(text_raw, max_length=8192, truncation=True)
|
||||||
|
text_raw = tokenizer.decode(text_raw["input_ids"])
|
||||||
|
|
||||||
|
response = text_raw
|
||||||
|
response = f"bibkey: {str(ref)}\n"+response
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def call_search_engine(tool_call, topk=10):
|
||||||
|
try:
|
||||||
|
queries = tool_call["arguments"]["query"]
|
||||||
|
if isinstance(queries, str):
|
||||||
|
queries = [queries]
|
||||||
|
else:
|
||||||
|
queries = list(queries)
|
||||||
|
|
||||||
|
if len(queries) == 0:
|
||||||
|
return "", [], [], False, 0.0
|
||||||
|
results = defaultdict(dict)
|
||||||
|
query_embedding_to_text,_ = model.encode_query(queries, max_length=512, show_progress_bar=False)
|
||||||
|
_,results = faiss_index.search(query_embedding_to_text, topk)
|
||||||
|
result2query = {}
|
||||||
|
merge_rrf = defaultdict(float)
|
||||||
|
for i in range(len(results)):
|
||||||
|
for j in range(len(results[i])):
|
||||||
|
merge_rrf[results[i][j]] += 1/(j+1)
|
||||||
|
result2query[results[i][j]] = queries[i]
|
||||||
|
results = sorted(merge_rrf.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
results = [x[0] for x in results][:topk]
|
||||||
|
|
||||||
|
# new_queries = [result2query[result] for result in results]
|
||||||
|
queries = ",".join(queries)
|
||||||
|
|
||||||
|
# bibkeys = [str(results[i]) for i in range(len(results))]
|
||||||
|
bibkeys = [str(index_dict[results[i]]) for i in range(len(results))]
|
||||||
|
response = [f"bibkey: {bibkey}\n{paper_dict[bibkey]}" for bibkey in bibkeys]
|
||||||
|
return queries,bibkeys , response, False, 0.0
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error in call_search_engine: {e}")
|
||||||
|
return "",[], [], False, 0.0
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Run the FastAPI application.")
|
||||||
|
parser.add_argument("--port", type=int, default=8400, help="Port to run the FastAPI application on.")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||||
42
config.json
Normal file
42
config.json
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
{
|
||||||
|
"_name_or_path": "openbmb/MiniCPM4-8B",
|
||||||
|
"architectures": [
|
||||||
|
"MiniCPMForCausalLM"
|
||||||
|
],
|
||||||
|
"auto_map": {
|
||||||
|
"AutoConfig": "configuration_minicpm.MiniCPMConfig",
|
||||||
|
"AutoModel": "modeling_minicpm.MiniCPMModel",
|
||||||
|
"AutoModelForCausalLM": "modeling_minicpm.MiniCPMForCausalLM",
|
||||||
|
"AutoModelForSeq2SeqLM": "modeling_minicpm.MiniCPMForCausalLM",
|
||||||
|
"AutoModelForSequenceClassification": "modeling_minicpm.MiniCPMForSequenceClassification"
|
||||||
|
},
|
||||||
|
"bos_token_id": 1,
|
||||||
|
"eos_token_id": [2, 73440],
|
||||||
|
"pad_token_id": 2,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"initializer_range": 0.1,
|
||||||
|
"intermediate_size": 16384,
|
||||||
|
"max_position_embeddings": 32768,
|
||||||
|
"model_type": "minicpm",
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"num_key_value_heads": 2,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": {
|
||||||
|
"rope_type": "longrope",
|
||||||
|
"long_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.615569542115128, 5.2684819496549835, 6.014438591970396, 6.858830049237097, 7.804668263503327, 8.851768731513417, 9.99600492938444, 11.228766118181639, 12.536757560834843, 13.902257701387796, 15.303885189125953, 16.717837610115794, 18.119465097853947, 19.484965238406907, 20.792956681060105, 22.02571786985731, 23.16995406772833, 24.217054535738416, 25.16289275000465, 26.007284207271347, 26.753240849586767, 27.40615325712662, 27.973003419175363, 28.461674954469114, 28.880393889607006, 29.237306864684626, 29.540186419591297, 29.79624387177199, 30.01202719065413, 30.193382037992453, 30.34545697551969, 30.47273746338473, 30.579096895249787, 30.66785612408345, 30.741845563814174, 30.80346599254902, 30.85474569563567, 30.897392663720595, 30.932841297560394, 30.962293553185553, 30.986754758742034, 31.007064503249293, 31.02392307921529],
|
||||||
|
"short_factor": [0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, 4.048719380066383, 4.615569542115128, 5.2684819496549835, 6.014438591970396, 6.858830049237097, 7.804668263503327, 8.851768731513417, 9.99600492938444, 11.228766118181639, 12.536757560834843, 13.902257701387796, 15.303885189125953, 16.717837610115794, 18.119465097853947, 19.484965238406907, 20.792956681060105, 22.02571786985731, 23.16995406772833, 24.217054535738416, 25.16289275000465, 26.007284207271347, 26.753240849586767, 27.40615325712662, 27.973003419175363, 28.461674954469114, 28.880393889607006, 29.237306864684626, 29.540186419591297, 29.79624387177199, 30.01202719065413, 30.193382037992453, 30.34545697551969, 30.47273746338473, 30.579096895249787, 30.66785612408345, 30.741845563814174, 30.80346599254902, 30.85474569563567, 30.897392663720595, 30.932841297560394, 30.962293553185553, 30.986754758742034, 31.007064503249293, 31.02392307921529],
|
||||||
|
"original_max_position_embeddings": 32768
|
||||||
|
},
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.46.3",
|
||||||
|
"use_cache": true,
|
||||||
|
"vocab_size": 73448,
|
||||||
|
"rope_theta": 10000.0,
|
||||||
|
"scale_emb": 12,
|
||||||
|
"scale_depth": 1.4,
|
||||||
|
"mup_denominator": 32,
|
||||||
|
"dim_model_base": 256,
|
||||||
|
"tie_word_embeddings": false
|
||||||
|
}
|
||||||
1
configuration.json
Normal file
1
configuration.json
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{"framework":"Pytorch","task":"text-generation"}
|
||||||
203
configuration_minicpm.py
Normal file
203
configuration_minicpm.py
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2025 The OpenBMB Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
""" MiniCPM model configuration"""
|
||||||
|
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
MINICPM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
||||||
|
|
||||||
|
|
||||||
|
class MiniCPMConfig(PretrainedConfig):
|
||||||
|
r"""
|
||||||
|
This is the configuration class to store the configuration of a [`MiniCPMModel`]. It is used to instantiate an MiniCPM
|
||||||
|
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||||
|
defaults will yield a similar configuration to that of the MiniCPM-7B.
|
||||||
|
|
||||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vocab_size (`int`, *optional*, defaults to 32000):
|
||||||
|
Vocabulary size of the MiniCPM model. Defines the number of different tokens that can be represented by the
|
||||||
|
`inputs_ids` passed when calling [`MiniCPMModel`]
|
||||||
|
hidden_size (`int`, *optional*, defaults to 4096):
|
||||||
|
Dimension of the hidden representations.
|
||||||
|
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||||
|
Dimension of the MLP representations.
|
||||||
|
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||||
|
Number of hidden layers in the Transformer decoder.
|
||||||
|
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||||
|
Number of attention heads for each attention layer in the Transformer decoder.
|
||||||
|
num_key_value_heads (`int`, *optional*):
|
||||||
|
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||||
|
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||||
|
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||||
|
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||||
|
by meanpooling all the original heads within that group. For more details checkout [this
|
||||||
|
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||||
|
`num_attention_heads`.
|
||||||
|
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||||
|
The non-linear activation function (function or string) in the decoder.
|
||||||
|
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||||
|
The maximum sequence length that this model might ever be used with. MiniCPM 1 supports up to 2048 tokens,
|
||||||
|
MiniCPM 2 up to 4096, CodeMiniCPM up to 16384.
|
||||||
|
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||||
|
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||||
|
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||||
|
The epsilon used by the rms normalization layers.
|
||||||
|
use_cache (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||||
|
relevant if `config.is_decoder=True`.
|
||||||
|
pad_token_id (`int`, *optional*):
|
||||||
|
Padding token id.
|
||||||
|
bos_token_id (`int`, *optional*, defaults to 1):
|
||||||
|
Beginning of stream token id.
|
||||||
|
eos_token_id (`int`, *optional*, defaults to 2):
|
||||||
|
End of stream token id.
|
||||||
|
pretraining_tp (`int`, *optional*, defaults to 1):
|
||||||
|
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
|
||||||
|
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
|
||||||
|
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
|
||||||
|
issue](https://github.com/pytorch/pytorch/issues/76232).
|
||||||
|
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to tie weight embeddings
|
||||||
|
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||||
|
The base period of the RoPE embeddings.
|
||||||
|
rope_scaling (`Dict`, *optional*):
|
||||||
|
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||||
|
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||||
|
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||||
|
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
|
||||||
|
these scaling strategies behave:
|
||||||
|
https://www.reddit.com/r/LocalMiniCPM/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
|
||||||
|
experimental feature, subject to breaking API changes in future versions.
|
||||||
|
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||||
|
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||||
|
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||||
|
The dropout ratio for the attention probabilities.
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import MiniCPMModel, MiniCPMConfig
|
||||||
|
|
||||||
|
>>> # Initializing a MiniCPM minicpm-7b style configuration
|
||||||
|
>>> configuration = MiniCPMConfig()
|
||||||
|
|
||||||
|
>>> # Initializing a model from the minicpm-7b style configuration
|
||||||
|
>>> model = MiniCPMModel(configuration)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> configuration = model.config
|
||||||
|
```"""
|
||||||
|
|
||||||
|
model_type = 'minicpm'
|
||||||
|
keys_to_ignore_at_inference = ['past_key_values']
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=32000,
|
||||||
|
hidden_size=4096,
|
||||||
|
intermediate_size=11008,
|
||||||
|
num_hidden_layers=32,
|
||||||
|
num_attention_heads=32,
|
||||||
|
num_key_value_heads=None,
|
||||||
|
hidden_act='silu',
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
initializer_range=0.02,
|
||||||
|
rms_norm_eps=1e-6,
|
||||||
|
use_cache=True,
|
||||||
|
pad_token_id=None,
|
||||||
|
bos_token_id=1,
|
||||||
|
eos_token_id=2,
|
||||||
|
pretraining_tp=1,
|
||||||
|
tie_word_embeddings=True,
|
||||||
|
rope_theta=10000.0,
|
||||||
|
rope_scaling=None,
|
||||||
|
attention_bias=False,
|
||||||
|
attention_dropout=0.0,
|
||||||
|
scale_emb=1,
|
||||||
|
dim_model_base=1,
|
||||||
|
scale_depth=1,
|
||||||
|
mup_denominator=32,
|
||||||
|
sparse_config=None,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.intermediate_size = intermediate_size
|
||||||
|
self.num_hidden_layers = num_hidden_layers
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
|
||||||
|
# for backward compatibility
|
||||||
|
if num_key_value_heads is None:
|
||||||
|
num_key_value_heads = num_attention_heads
|
||||||
|
|
||||||
|
self.num_key_value_heads = num_key_value_heads
|
||||||
|
self.hidden_act = hidden_act
|
||||||
|
self.initializer_range = initializer_range
|
||||||
|
self.rms_norm_eps = rms_norm_eps
|
||||||
|
self.pretraining_tp = pretraining_tp
|
||||||
|
self.use_cache = use_cache
|
||||||
|
self.rope_theta = rope_theta
|
||||||
|
self.rope_scaling = rope_scaling
|
||||||
|
# self._rope_scaling_validation()
|
||||||
|
self.attention_bias = attention_bias
|
||||||
|
self.attention_dropout = attention_dropout
|
||||||
|
self.scale_emb = scale_emb
|
||||||
|
self.dim_model_base = dim_model_base
|
||||||
|
self.scale_depth = scale_depth
|
||||||
|
# only used for Eagle Head
|
||||||
|
self.mup_denominator = mup_denominator
|
||||||
|
|
||||||
|
# sparse config
|
||||||
|
self.sparse_config = sparse_config
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
import flash_attn
|
||||||
|
self._attn_implementation = 'flash_attention_2'
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _rope_scaling_validation(self):
|
||||||
|
"""
|
||||||
|
Validate the `rope_scaling` configuration.
|
||||||
|
"""
|
||||||
|
if self.rope_scaling is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
'`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, '
|
||||||
|
f'got {self.rope_scaling}'
|
||||||
|
)
|
||||||
|
rope_scaling_type = self.rope_scaling.get('type', None)
|
||||||
|
rope_scaling_factor = self.rope_scaling.get('factor', None)
|
||||||
|
if rope_scaling_type is None or rope_scaling_type not in ['linear', 'dynamic']:
|
||||||
|
raise ValueError(
|
||||||
|
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||||
|
)
|
||||||
|
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
|
||||||
|
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
|
||||||
12
generation_config.json
Normal file
12
generation_config.json
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
{
|
||||||
|
"bos_token_id": 1,
|
||||||
|
"do_sample": true,
|
||||||
|
"eos_token_id": [
|
||||||
|
2,
|
||||||
|
73440
|
||||||
|
],
|
||||||
|
"pad_token_id": 2,
|
||||||
|
"temperature": 0.8,
|
||||||
|
"top_p": 0.8,
|
||||||
|
"transformers_version": "4.46.1"
|
||||||
|
}
|
||||||
3
model-00001-of-00004.safetensors
Normal file
3
model-00001-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:2272d2c328bc65ef4e1a7710f1d29f3584c84fe7f5a8460a44e41da41561ac72
|
||||||
|
size 4964116352
|
||||||
3
model-00002-of-00004.safetensors
Normal file
3
model-00002-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:927aa21158b149c9bd8de07238d2060fba7525f822fa465b6a74b2ccb25e1cfb
|
||||||
|
size 4974396952
|
||||||
3
model-00003-of-00004.safetensors
Normal file
3
model-00003-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:74a22956fac386f8f1fdff60c090aa9a14d2d241b666261ad61422ce5def2d8f
|
||||||
|
size 4877986360
|
||||||
3
model-00004-of-00004.safetensors
Normal file
3
model-00004-of-00004.safetensors
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:b3d1610f229731974c82b387dfaa04e74920ae45f6f892aadfaef0f65c160066
|
||||||
|
size 1554041896
|
||||||
298
model.safetensors.index.json
Normal file
298
model.safetensors.index.json
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
{
|
||||||
|
"metadata": {
|
||||||
|
"total_size": 16370507776
|
||||||
|
},
|
||||||
|
"weight_map": {
|
||||||
|
"lm_head.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.embed_tokens.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.0.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.0.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.0.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.0.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.0.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.0.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.0.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.1.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.1.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.1.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.1.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.1.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.10.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.10.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.10.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.10.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.10.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.10.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.10.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.10.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.11.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.11.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.11.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.11.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.11.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.11.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.11.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.11.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.12.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.12.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.12.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.12.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.12.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.12.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.12.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.12.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.12.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.13.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.13.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.13.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.13.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.13.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.13.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.13.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.14.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.14.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.14.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.14.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.14.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.15.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.15.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.15.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.15.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.15.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.15.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.15.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.15.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.15.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.16.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.16.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.16.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.16.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.16.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.16.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.16.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.17.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.17.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.17.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.17.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.17.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.17.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.17.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.17.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.18.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.18.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.18.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.18.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.18.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.18.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.18.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.19.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.19.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.19.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.19.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.19.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.19.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.2.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.2.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.2.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.2.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.2.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.2.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.2.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.2.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.20.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.20.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.20.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.20.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.20.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.20.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.20.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.20.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.21.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.21.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.21.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.21.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.21.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.21.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.21.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.22.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.22.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.22.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.22.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.22.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.22.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.22.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.23.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.23.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.23.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.23.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.23.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.24.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.24.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.24.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.24.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.24.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.24.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.24.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.25.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.25.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.25.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.25.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.25.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.25.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.26.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.26.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.26.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.26.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.26.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.26.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.26.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.26.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.27.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.27.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.27.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.27.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.27.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.27.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.27.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.28.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.28.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.29.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.29.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.29.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.29.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.29.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.29.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.29.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.3.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.3.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.3.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.3.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.3.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.3.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.3.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.3.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.30.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.30.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.30.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.30.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.30.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.30.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.30.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.30.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.31.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.31.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.31.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.31.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.31.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.31.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.31.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.4.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.4.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.4.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.4.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.4.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.4.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.4.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.5.input_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.5.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.5.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.5.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.5.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.5.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.5.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.6.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.6.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.6.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.6.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.6.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.6.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.6.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.7.input_layernorm.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.7.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.7.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.7.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.7.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.7.self_attn.q_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.7.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.8.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.8.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.8.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.8.self_attn.k_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.8.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.8.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.8.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.9.input_layernorm.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.9.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.9.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.layers.9.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.9.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
|
||||||
|
"model.layers.9.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
|
||||||
|
"model.layers.9.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
|
||||||
|
"model.layers.9.self_attn.v_proj.weight": "model-00001-of-00004.safetensors",
|
||||||
|
"model.norm.weight": "model-00002-of-00004.safetensors"
|
||||||
|
}
|
||||||
|
}
|
||||||
2509
modeling_minicpm.py
Normal file
2509
modeling_minicpm.py
Normal file
File diff suppressed because it is too large
Load Diff
33
special_tokens_map.json
Normal file
33
special_tokens_map.json
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
{
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|tool_call|>",
|
||||||
|
"<|execute_start|>",
|
||||||
|
"<|execute_end|>",
|
||||||
|
"<|fim_prefix|>",
|
||||||
|
"<|fim_middle|>",
|
||||||
|
"<|fim_suffix|>"
|
||||||
|
],
|
||||||
|
"bos_token": {
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"eos_token": {
|
||||||
|
"content": "<|im_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
},
|
||||||
|
"unk_token": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false
|
||||||
|
}
|
||||||
|
}
|
||||||
3
tokenizer.json
Normal file
3
tokenizer.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:adf7208af154a5ca065d2eda4e5419e02aac58c2c00627874748b75ec6769094
|
||||||
|
size 6701371
|
||||||
3
tokenizer.model
Normal file
3
tokenizer.model
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
version https://git-lfs.github.com/spec/v1
|
||||||
|
oid sha256:bb74d51116831c3bf65db812c553f94ab0c88dcf97a5bbb37e3504f6d359c530
|
||||||
|
size 1181204
|
||||||
117
tokenizer_config.json
Normal file
117
tokenizer_config.json
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
{
|
||||||
|
"add_bos_token": true,
|
||||||
|
"add_eos_token": false,
|
||||||
|
"add_prefix_space": null,
|
||||||
|
"added_tokens_decoder": {
|
||||||
|
"0": {
|
||||||
|
"content": "<unk>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"1": {
|
||||||
|
"content": "<s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"2": {
|
||||||
|
"content": "</s>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73440": {
|
||||||
|
"content": "<|im_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73441": {
|
||||||
|
"content": "<|im_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73442": {
|
||||||
|
"content": "<|tool_call|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73443": {
|
||||||
|
"content": "<|execute_start|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73444": {
|
||||||
|
"content": "<|execute_end|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73445": {
|
||||||
|
"content": "<|fim_prefix|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73446": {
|
||||||
|
"content": "<|fim_middle|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
},
|
||||||
|
"73447": {
|
||||||
|
"content": "<|fim_suffix|>",
|
||||||
|
"lstrip": false,
|
||||||
|
"normalized": false,
|
||||||
|
"rstrip": false,
|
||||||
|
"single_word": false,
|
||||||
|
"special": true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additional_special_tokens": [
|
||||||
|
"<|im_end|>",
|
||||||
|
"<|im_start|>",
|
||||||
|
"<|tool_call|>",
|
||||||
|
"<|execute_start|>",
|
||||||
|
"<|execute_end|>",
|
||||||
|
"<|fim_prefix|>",
|
||||||
|
"<|fim_middle|>",
|
||||||
|
"<|fim_suffix|>"
|
||||||
|
],
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
|
"clean_up_tokenization_spaces": false,
|
||||||
|
"eos_token": "<|im_end|>",
|
||||||
|
"legacy": true,
|
||||||
|
"model_max_length": 1000000000000000019884624838656,
|
||||||
|
"pad_token": null,
|
||||||
|
"sp_model_kwargs": {},
|
||||||
|
"spaces_between_special_tokens": false,
|
||||||
|
"tokenizer_class": "LlamaTokenizer",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"use_default_system_prompt": false
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user