初始化项目,由ModelHub XC社区提供模型

Model: OpenBMB/MiniCPM4-Survey
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-07 14:21:15 +08:00
commit adbc52087b
37 changed files with 8558 additions and 0 deletions

View File

@@ -0,0 +1,24 @@
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
pnpm-debug.log*
lerna-debug.log*
node_modules
dist
dist-ssr
*.local
# Editor directories and files
.vscode/*
!.vscode/extensions.json
.idea
.DS_Store
*.suo
*.ntvs*
*.njsproj
*.sln
*.sw?

View File

@@ -0,0 +1,33 @@
import js from '@eslint/js'
import globals from 'globals'
import reactHooks from 'eslint-plugin-react-hooks'
import reactRefresh from 'eslint-plugin-react-refresh'
// ESLint flat configuration: one ignore entry followed by one rule set for JS/JSX sources.
export default [
// Do not lint the build output directory.
{ ignores: ['dist'] },
{
files: ['**/*.{js,jsx}'],
languageOptions: {
ecmaVersion: 2020,
// Browser globals are treated as defined.
globals: globals.browser,
parserOptions: {
ecmaVersion: 'latest',
ecmaFeatures: { jsx: true },
sourceType: 'module',
},
},
plugins: {
'react-hooks': reactHooks,
'react-refresh': reactRefresh,
},
rules: {
// Start from the recommended core and react-hooks rule sets, then override below.
...js.configs.recommended.rules,
...reactHooks.configs.recommended.rules,
// Unused variables are errors unless the name starts with a capital letter or an underscore.
'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
'react-refresh/only-export-components': [
'warn',
{ allowConstantExport: true },
],
},
},
]

View File

@@ -0,0 +1,13 @@
<!doctype html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <!-- The icon is an SVG file, so it is declared with the SVG MIME type (was image/png). -->
    <link rel="icon" type="image/svg+xml" href="/openbmb.svg" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" />
    <title>MiniCPM4-Survey</title>
  </head>
  <body>
    <!-- React mounts the application into this element (see src/main.jsx). -->
    <div id="root"></div>
    <script type="module" src="/src/main.jsx"></script>
  </body>
</html>

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,29 @@
{
"name": "minicpm4-survey",
"private": true,
"version": "0.0.0",
"type": "module",
"scripts": {
"dev": "vite",
"build": "vite build",
"lint": "eslint .",
"preview": "vite preview"
},
"dependencies": {
"dompurify": "^3.2.6",
"marked": "^15.0.12",
"react": "^19.1.0",
"react-dom": "^19.1.0"
},
"devDependencies": {
"@eslint/js": "^9.25.0",
"@types/react": "^19.1.2",
"@types/react-dom": "^19.1.2",
"@vitejs/plugin-react": "^4.4.1",
"eslint": "^9.25.0",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-react-refresh": "^0.4.19",
"globals": "^16.0.0",
"vite": "^6.3.5"
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 629 B

View File

@@ -0,0 +1,318 @@
:root {
--neon-blue: #00f3ff;
--neon-purple: #ffffff;
--dark-panel: rgba(10, 10, 30, 0.95);
--glass-color: rgba(255, 255, 255, 0.05);
}
* {
box-sizing: border-box;
margin: 0;
padding: 0;
font-family: 'Orbitron', monospace;
}
body {
overflow: hidden;
}
.cyber-container {
position: relative;
display: flex;
min-height: 100vh;
background: radial-gradient(circle at center, #0a0a1f 0%, #000000 100%);
padding: 2rem;
gap: 2rem;
overflow: hidden;
}
/* 全息背景 */
.hologram-bg {
position: fixed;
top: -50%;
left: -50%;
width: 200%;
height: 200%;
background: repeating-linear-gradient(
45deg,
transparent,
transparent 5px,
rgba(0, 255, 255, 0.05) 5px,
rgba(0, 255, 255, 0.05) 10px
);
animation: scan 20s linear infinite;
z-index: 0;
}
@keyframes scan {
0% { transform: translateY(-50%) rotate(0deg); }
100% { transform: translateY(50%) rotate(360deg); }
}
/* 面板样式 */
.tech-panel {
position: relative;
width: 220px;
padding: 1.5rem;
background: var(--glass-color);
backdrop-filter: blur(10px);
border: 1px solid rgba(255, 255, 255, 0.1);
border-radius: 12px;
box-shadow: 0 0 20px rgba(0, 255, 255, 0.2);
z-index: 1;
}
.tech-panel::before {
content: '';
position: absolute;
top: 0;
left: 0;
right: 0;
height: 2px;
background: linear-gradient(90deg, var(--neon-blue), var(--neon-purple), var(--neon-blue));
animation: pulse 3s infinite;
}
@keyframes pulse {
0%, 100% { opacity: 0.5; }
50% { opacity: 1; }
}
/* 输入框 */
.input-wrapper {
position: relative;
margin-bottom: 1.2rem;
}
.neon-input {
width: 100%;
padding: 0.8rem 1rem;
background: rgba(255, 255, 255, 0.03);
border: 1px solid rgba(0, 255, 255, 0.3);
border-radius: 6px;
color: #fff;
font-size: 1rem;
outline: none;
transition: all 0.3s ease;
}
.neon-input::placeholder {
color: rgba(255, 255, 255, 0.3);
}
.neon-input:focus {
border-color: var(--neon-blue);
box-shadow: 0 0 10px var(--neon-blue);
}
.input-glow {
position: absolute;
bottom: -5px;
left: 0;
width: 100%;
height: 2px;
background: linear-gradient(90deg, transparent, var(--neon-purple), transparent);
animation: glowPulse 2s infinite;
}
@keyframes glowPulse {
0%, 100% { opacity: 0; }
50% { opacity: 1; }
}
/* 核心区域 */
.core-module {
position: relative;
flex: 1;
z-index: 1;
display: flex;
flex-direction: column;
}
.quantum-textarea {
flex: 1;
padding: 2rem;
font-size: 1.2rem;
background: rgba(0, 0, 20, 0.7);
border: 2px solid var(--neon-blue);
border-radius: 12px;
color: #fff;
outline: none;
resize: both;
min-height: 300px;
backdrop-filter: blur(5px);
font-family: 'Orbitron', monospace;
transition: all 0.3s ease;
}
.quantum-textarea::placeholder {
color: rgba(255, 255, 255, 0.2);
}
.quantum-textarea:focus {
border-color: var(--neon-purple);
box-shadow: 0 0 20px var(--neon-purple);
}
.core-glow {
position: absolute;
top: 50%;
left: 50%;
width: 300%;
height: 300%;
background: radial-gradient(circle, var(--neon-blue) 0%, transparent 70%);
opacity: 0.1;
transform: translate(-50%, -50%);
z-index: 0;
pointer-events: none;
animation: pulseGlow 5s infinite alternate;
}
@keyframes pulseGlow {
0% { transform: translate(-50%, -50%) scale(1); }
100% { transform: translate(-50%, -50%) scale(1.2); }
}
/* 数据流动画 */
.data-stream {
position: absolute;
bottom: 1rem;
right: 2rem;
display: flex;
gap: 0.5rem;
z-index: 2;
}
.stream-pulse {
width: 6px;
height: 6px;
background: var(--neon-blue);
border-radius: 50%;
animation: pulseDot 1.5s infinite;
}
.delay-1 {
animation-delay: 0.3s;
}
.delay-2 {
animation-delay: 0.6s;
}
@keyframes pulseDot {
0% { transform: scale(1); opacity: 1; }
70% { transform: scale(1.5); opacity: 0.3; }
100% { transform: scale(1); opacity: 1; }
}
/* 响应式设计 */
@media (max-width: 1024px) {
.cyber-container {
flex-direction: column;
padding: 1rem;
}
.tech-panel {
width: 100%;
margin-bottom: 1.5rem;
}
.core-module {
min-height: 400px;
}
}
.markdown-editor {
display: flex;
gap: 20px;
height: 100%;
z-index: 1;
}
.markdown-input,
.markdown-preview {
flex: 1;
padding: 15px;
font-size: 16px;
border: 2px solid var(--neon-blue);
border-radius: 8px;
background: rgba(0, 0, 20, 0.7);
color: #fff;
font-family: 'Orbitron', monospace;
resize: both;
}
.markdown-input {
min-height: 300px;
outline: none;
}
.markdown-input:focus {
box-shadow: 0 0 10px var(--neon-blue);
}
.markdown-preview {
max-height: 800px; /* 设置最大高度,超出后滚动 */
width: 100%;
overflow-y: auto; /* 垂直方向溢出时显示滚动条 */
padding: 10px;
border: 1px solid #333;
background-color: #111;
color: #eee;
font-family: monospace;
}
/* Markdown 内容增强样式 */
/*
 * Heading colour for rendered markdown.
 * Each selector in the list must carry the `.markdown-preview` prefix;
 * the previous `.markdown-preview h1, h2, h3` matched every h2/h3 on the page.
 * `.input-title` (the side-panel <h3> titles) is listed explicitly so those
 * titles keep the colour they previously received through that leak.
 */
.markdown-preview h1,
.markdown-preview h2,
.markdown-preview h3,
.input-title {
  color: var(--neon-purple);
}
.markdown-preview pre {
background: #111;
padding: 10px;
border-radius: 6px;
color: #00ffcc;
overflow-x: auto;
}
.markdown-preview code {
background: rgba(255, 255, 255, 0.1);
padding: 2px 4px;
border-radius: 4px;
color: var(--neon-blue);
}
.core-module {
flex: 1;
display: flex;
flex-direction: column;
gap: 1rem;
padding: 1rem;
}
.markdown-toolbar {
display: flex;
gap: 10px;
margin-bottom: 10px;
}
.neon-button {
background-color: #000;
color: #0f0;
border: 2px solid #0f0;
padding: 6px 12px;
font-size: 14px;
cursor: pointer;
transition: all 0.2s ease-in-out;
}
.neon-button:hover {
background-color: #0f0;
color: #000;
}
.markdown-editor {
display: flex;
flex-direction: column;
gap: 10px;
padding: 1rem;
}

View File

@@ -0,0 +1,259 @@
import React, { useState, useEffect, useMemo, useRef } from 'react';
import './App.css';
import DOMPurify from 'dompurify';
import { marked } from 'marked';
// Custom debounce hook: returns `value` only after it has stayed unchanged for `delay` ms.
function useDebounce(value, delay) {
  const [debouncedValue, setDebouncedValue] = useState(value);

  useEffect(() => {
    // Restart the timer whenever the value or the delay changes.
    const timerId = setTimeout(() => setDebouncedValue(value), delay);
    return () => {
      clearTimeout(timerId);
    };
  }, [value, delay]);

  return debouncedValue;
}
function MarkdownEditor({ value }) {
const containerRef = useRef(null);
const htmlContent = marked(value || '');
const sanitizedHtml = DOMPurify.sanitize(htmlContent);
const [userScrolled, setUserScrolled] = useState(false);
useEffect(() => {
const container = containerRef.current;
if (container && !userScrolled) {
requestAnimationFrame(() => {
container.scrollTop = container.scrollHeight;
});
}
}, [value, userScrolled]);
useEffect(() => {
const container = containerRef.current;
if (container) {
const handleScroll = () => {
const atBottom = container.scrollTop + container.clientHeight >= container.scrollHeight - 10;
setUserScrolled(!atBottom);
};
container.addEventListener('scroll', handleScroll);
return () => container.removeEventListener('scroll', handleScroll);
}
}, []);
// 复制 Markdown 内容
const handleCopy = () => {
navigator.clipboard.writeText(value || '')
.then(() => alert('Markdown 已复制到剪贴板'))
.catch(err => console.error('复制失败:', err));
};
// 下载 Markdown 文件
const handleDownload = () => {
const blob = new Blob([value || ''], { type: 'text/markdown;charset=utf-8' });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = 'document.md';
document.body.appendChild(a);
a.click();
document.body.removeChild(a);
URL.revokeObjectURL(url);
};
return (
<div className="markdown-editor">
{/* <div className="markdown-toolbar">
<button className="neon-button" onClick={handleCopy}>复制 Markdown</button>
<button className="neon-button" onClick={handleDownload}>下载 Markdown</button>
</div> */}
<div
ref={containerRef}
className="markdown-preview"
dangerouslySetInnerHTML={{ __html: sanitizedHtml }}
/>
</div>
);
}
function SendRequestToBackend() {
const [inputValue, setInputValue] = useState('');
const handleSendRequest = async () => {
try {
const response = await fetch('http://localhost:8001/generate_survey', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ query: inputValue }),
});
if (!response.ok) {
throw new Error('Failed to send request');
}
const data = await response.json();
console.log('Response from backend:', data);
} catch (error) {
console.error('Error sending request:', error);
}
};
return (
<div className="request-panel" style={{ flexDirection: 'column', alignItems: 'center' }}>
<input
type="text"
value={inputValue}
onChange={(e) => setInputValue(e.target.value)}
className="neon-input"
placeholder="Enter text to send"
rows={3}
/>
<button onClick={handleSendRequest} className="neon-button">
Go!
</button>
</div>
);
}
function App() {
const [inputs, setInputs] = useState({
query: { title: 'Query', displayText: '', targetText: '', isTyping: false },
nowUpdate: { title: 'Now Update', displayText: '', targetText: '', isTyping: false },
nextUpdate: { title: 'Next Update', displayText: '', targetText: '', isTyping: false },
searchKeywords: { title: 'Search Keywords', displayText: '', targetText: '', isTyping: false },
papers: { title: 'Papers', displayText: '', targetText: '', isTyping: false },
});
const [markdownContent, setMarkdownContent] = useState('');
const inputKeyMap = {
query: inputs.query,
nowUpdate: inputs.nowUpdate,
nextUpdate: inputs.nextUpdate,
searchKeywords: inputs.searchKeywords,
papers: inputs.papers,
markdown: markdownContent
};
const updateInputsFromPostData = (postData) => {
let newMarkdownContent = markdownContent;
Object.entries(postData).forEach(([key, value]) => {
if (key in inputKeyMap) {
if (key === 'markdown') {
if (markdownContent !== value) {
newMarkdownContent = value;
setMarkdownContent(newMarkdownContent);
}
} else if (inputKeyMap[key] && inputKeyMap[key].targetText !== value) {
const updatedInput = {
...inputKeyMap[key],
targetText: value,
isTyping: true,
};
setInputs((prevInputs) => ({
...prevInputs,
[key]: updatedInput,
}));
// startTypingAnimationForTextbox(value, (newText) => {
// setInputs((prevInputs) => ({
// ...prevInputs,
// [key]: {
// ...prevInputs[key],
// displayText: newText,
// },
// }));
// });
}
}
});
};
// const startTypingAnimationForTextbox = (text, setText) => {
// setText('');
// let charIndex = 0;
// const timer = setInterval(() => {
// if (charIndex < text.length) {
// setText((prev) => prev + text[charIndex]);
// charIndex++;
// } else {
// clearInterval(timer);
// }
// }, 50); // Reduced interval for faster typing animation
// };
useEffect(() => {
const ws = new WebSocket('ws://localhost:8001/ws');
ws.onmessage = (event) => {
try {
const data = JSON.parse(event.data);
updateInputsFromPostData(data);
console.log('Received data:', data);
} catch (e) {
console.error('Invalid WebSocket message:', e);
}
};
ws.onerror = (err) => {
console.error('WebSocket error:', err);
};
// return () => ws.close();
}, []);
const leftInputs = [inputs.nowUpdate, inputs.nextUpdate,inputs.searchKeywords];
const rightInputs = [inputs.papers];
return (
<div className="cyber-container">
<div className="tech-panel left-panel">
{leftInputs.map((input, index) => (
<div key={`left-${index}`} className="input-wrapper">
<h3 className="input-title" style={{ fontSize: '14px' }}>{input.title}</h3>
<textarea
value={input.targetText}
readOnly
className="neon-input"
rows={Math.max(10, input.targetText.split('\n').length)}
cols={50}
style={{ resize: 'none', fontSize: '12px' }}
/>
</div>
))}
</div>
<div className="core-module">
<SendRequestToBackend />
<MarkdownEditor value={markdownContent} />
</div>
<div className="tech-panel right-panel">
{rightInputs.map((input, index) => (
<div key={`right-${index}`} className="input-wrapper">
<h3 className="input-title" style={{ fontSize: '14px' }}>{input.title}</h3>
<textarea
value={input.targetText}
readOnly
className="neon-input"
rows={100}
cols={50}
style={{ resize: 'none', fontSize: '12px' }}
/>
</div>
))}
</div>
</div>
);
}
export default App;

View File

@@ -0,0 +1,68 @@
:root {
font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
line-height: 1.5;
font-weight: 400;
color-scheme: light dark;
color: rgba(255, 255, 255, 0.87);
background-color: #242424;
font-synthesis: none;
text-rendering: optimizeLegibility;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
a {
font-weight: 500;
color: #646cff;
text-decoration: inherit;
}
a:hover {
color: #535bf2;
}
body {
margin: 0;
display: flex;
place-items: center;
min-width: 320px;
min-height: 100vh;
}
h1 {
font-size: 3.2em;
line-height: 1.1;
}
button {
border-radius: 8px;
border: 1px solid transparent;
padding: 0.6em 1.2em;
font-size: 1em;
font-weight: 500;
font-family: inherit;
background-color: #1a1a1a;
cursor: pointer;
transition: border-color 0.25s;
}
button:hover {
border-color: #646cff;
}
button:focus,
button:focus-visible {
outline: 4px auto -webkit-focus-ring-color;
}
@media (prefers-color-scheme: light) {
:root {
color: #213547;
background-color: #ffffff;
}
a:hover {
color: #747bff;
}
button {
background-color: #f9f9f9;
}
}

View File

@@ -0,0 +1,10 @@
import { StrictMode } from 'react'
import { createRoot } from 'react-dom/client'
import './index.css'
import App from './App.jsx'
// Mount <App /> into the element with id "root", wrapped in React's StrictMode.
createRoot(document.getElementById('root')).render(
<StrictMode>
<App />
</StrictMode>,
)

View File

@@ -0,0 +1,7 @@
import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react'
// https://vite.dev/config/
// Vite configuration: the React plugin is the only plugin enabled.
export default defineConfig({
plugins: [react()],
})

8
code/requirements.txt Normal file
View File

@@ -0,0 +1,8 @@
openai
vllm
jsonlines
faiss-cpu
# faiss-gpu
fastapi
uvicorn
yarl

4
code/scripts/run.sh Normal file
View File

@@ -0,0 +1,4 @@
# Generate a survey for a fixed example query with the MiniCPM4-Survey model
# and write the result to test.md.
# The script path is relative, so run this from the directory that contains ./src.
# The last argument carries no trailing backslash, so a line appended after the
# command is not swallowed as an extra argument.
python ./src/generation/run.py \
  --model_path "openbmb/MiniCPM4-Survey" \
  --query "Please design a survey that assesses the performance of language processing systems on unseen data to measure their robustness in natural language understanding tasks." \
  --output_file "test.md"

View File

@@ -0,0 +1,5 @@
# Same generation run as the plain script, with an extra --port argument.
# NOTE(review): 8001 matches the port the web UI talks to (localhost:8001), so
# this presumably enables the backend the frontend connects to; confirm in run.py.
# The script path is relative, so run this from the directory that contains ./src.
# The last argument carries no trailing backslash, so a line appended after the
# command is not swallowed as an extra argument.
python ./src/generation/run.py \
  --model_path "openbmb/MiniCPM4-Survey" \
  --query "Please design a survey that assesses the performance of language processing systems on unseen data to measure their robustness in natural language understanding tasks." \
  --output_file "test.md" \
  --port 8001

View File

@@ -0,0 +1,810 @@
import re
import json
import copy
# BASE_SURVEY_STRUCTURE = """
# # Title: A survey of ...
# # Introduction: None.
# # Section 1: None.
# ## Subsection 1 (if needed): None.
# ## Subsection 2 (if needed): None.
# ### Subsubsection 1 (if needed): None.
# ### Subsubsection 2 (if needed): None.
# ### ...
# # Section 2: None.
# # ...
# # Conclusion: None.
# """
class SurveyManager:
    """Stateless helpers for rendering and updating a survey held as a nested dict.

    A survey dict follows the shape of ``BASE_SURVEY_STRUCTURE``. Entries of
    ``sections`` (and of their optional ``subsections`` / ``subsubsections``
    lists) are dicts that carry a heading under ``"name"`` or ``"title"`` and
    may carry ``"plan"`` and ``"content"`` strings.
    """

    BASE_SURVEY_STRUCTURE = {
        "title": "",
        "abstract": "",
        "introduction": {
            "content": ""
        },
        "sections": [],
        "conclusion": ""
    }

    # Position-key prefixes of the three addressable nesting levels, outermost first.
    _LEVEL_PREFIXES = ("section-", "subsection-", "subsubsection-")
    # Dict keys holding the child list of a section and of a subsection.
    _CHILD_KEYS = ("subsections", "subsubsections")

    def __init__(self):
        pass

    @staticmethod
    def parse_update_pos(update_pos):
        """Normalize an update-position string.

        Accepted forms:
        (1) "title", "abstract", "introduction", "conclusion" or "plan",
            returned unchanged;
        (2) "section-i", "section-i/subsection-j" or
            "section-i/subsection-j/subsubsection-k" (case-insensitive),
            returned lower-cased with every index rewritten as a plain integer.

        Raises:
            ValueError: if there are more than three "/"-separated parts or an
                index is not an integer.
        """
        if update_pos in ["title", "abstract", "introduction", "conclusion", "plan"]:
            return update_pos
        keys = update_pos.split("/")
        if len(keys) > len(SurveyManager._LEVEL_PREFIXES):
            raise ValueError("unsupported update_pos keys")
        parts = []
        for key, prefix in zip(keys, SurveyManager._LEVEL_PREFIXES):
            index = int(key.lower().split(prefix)[-1])
            parts.append(f"{prefix}{index}")
        return "/".join(parts)

    @staticmethod
    def _heading(node, preferred_key=None):
        """Return the heading of a section-like node.

        ``preferred_key`` (the key used by the enclosing section) wins when the
        node has it; otherwise "name" is used if present, else "title". A node
        with neither key raises KeyError.
        """
        if preferred_key is not None and preferred_key in node:
            return node[preferred_key]
        return node["name" if "name" in node else "title"]

    @staticmethod
    def _walk_sections(current_survey):
        """Yield ``(depth, position, heading, node)`` for every section-like node.

        Nodes come in document order; ``depth`` is 0 for sections, 1 for
        subsections and 2 for subsubsections; ``position`` is the 0-based
        index among siblings. Yields nothing when there is no "sections" key.
        """
        if "sections" not in current_survey:
            return
        for i, section in enumerate(current_survey["sections"]):
            title_key = "name" if "name" in section else "title"
            yield 0, i, section[title_key], section
            if "subsections" not in section:
                continue
            for j, subsection in enumerate(section["subsections"]):
                # Children may use a different heading key than their section.
                yield 1, j, SurveyManager._heading(subsection, title_key), subsection
                if "subsubsections" not in subsection:
                    continue
                for k, subsubsection in enumerate(subsection["subsubsections"]):
                    yield 2, k, SurveyManager._heading(subsubsection, title_key), subsubsection

    @staticmethod
    def _render_field(current_survey, key, render, prefix, fallback):
        """Render one top-level field as ``prefix + render(value) + newline``.

        Returns ``fallback`` when the key is missing or rendering fails, so a
        partially filled survey still produces output.
        """
        try:
            return f"{prefix}{render(current_survey[key])}\n"
        except Exception:
            return fallback

    @staticmethod
    def _to_one_line(string):
        """Return the displayable text of a field or node.

        A dict yields its non-empty "content" (resolved recursively), else
        "[PLAN] " followed by its "plan" flattened to one line. A falsy value
        yields "". Any other value is returned unchanged (newlines are kept).
        """
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._to_one_line(string["content"])
            return "[PLAN] " + string.get("plan", "").replace("\n", " ").strip()
        if not string:
            return ""
        return string

    @staticmethod
    def convert_survey_dict_to_str(current_survey):
        """Render the full survey as markdown text.

        Returns "There is no survey." for an empty dict. Missing or
        unrenderable top-level fields fall back to a "None" placeholder.
        """
        if current_survey == {}:
            return "There is no survey."
        render = SurveyManager._to_one_line
        field = SurveyManager._render_field
        parts = [
            field(current_survey, "title", render, "# ", "# Title: None\n"),
            field(current_survey, "abstract", render, "## Abstract\n", "## Abstract\nNone\n"),
            field(current_survey, "introduction", render, "## Introduction\n", "## Introduction\nNone\n"),
        ]
        # Sections render as "##", subsections as "###", subsubsections as "####".
        for depth, _, name, node in SurveyManager._walk_sections(current_survey):
            marker = "#" * (depth + 2)
            parts.append(f"{marker} {name}\n{render(node)}\n")
        parts.append(
            field(current_survey, "conclusion", render, "## Conclusion\n", "## Conclusion:\nNone\n")
        )
        return "".join(parts)

    @staticmethod
    def _abbr_one_line(string, abbr=True):
        """Return a one-line status summary of a field or node.

        A dict yields the summary of its non-empty "content", else "[PLAN] "
        plus its flattened "plan", else "". A non-empty string yields "[OK] "
        plus the flattened text, cut to 50 characters with "..." appended when
        ``abbr`` is true and the original is longer than 50 characters.
        """
        if isinstance(string, dict):
            if "content" in string and string["content"]:
                return SurveyManager._abbr_one_line(string["content"], abbr=abbr)
            if "plan" in string:
                return "[PLAN] " + string["plan"].replace("\n", " ").strip()
            return ""
        if not string:
            return ""
        flattened = string.replace("\n", " ").strip()
        if abbr and len(string) > 50:
            return "[OK] " + flattened[:50] + "..."
        return "[OK] " + flattened

    @staticmethod
    def convert_survey_dict_to_abbr_str(current_survey):
        """Render a compact outline of the survey with one status line per node.

        Returns "There is no survey." for an empty dict. Title and abstract
        are shown in full; every other entry is abbreviated.
        """
        if current_survey == {}:
            return "There is no survey."

        def full(value):
            return SurveyManager._abbr_one_line(value, abbr=False)

        short = SurveyManager._abbr_one_line
        field = SurveyManager._render_field
        parts = [
            field(current_survey, "title", full, "# Title: ", "# Title: None\n"),
            field(current_survey, "abstract", full, "# Abstract: ", "# Abstract: None\n"),
            field(current_survey, "introduction", short, "# Introduction: ", "# Introduction: None\n"),
        ]
        labels = ("# Section-", " ## Subsection-", " ### Subsubsection-")
        for depth, position, name, node in SurveyManager._walk_sections(current_survey):
            parts.append(f"{labels[depth]}{position + 1} [{name}]: {short(node)}\n")
        parts.append(
            field(current_survey, "conclusion", short, "# Conclusion: ", "# Conclusion: None\n")
        )
        return "".join(parts)

    @staticmethod
    def update_one_section(sections, i, content):
        """Set ``sections[i]["content"]``; return False when ``i`` is out of range.

        ``i`` is 0-based; negative values are rejected rather than wrapped.
        """
        if 0 <= i < len(sections):
            sections[i]["content"] = content
            return True
        return False

    @staticmethod
    def update_current_survey(current_survey, answer) -> bool:
        """Apply one model answer to ``current_survey`` in place.

        ``answer["update"]`` selects the target and ``answer["content"]`` is
        the new value:

        * "plan": copies every key of ``content`` into the survey, allowed
          only while the survey is still empty;
        * "abstract" / "conclusion": stored as given; "introduction": stored
          as ``{"content": content}``; the key must already exist;
        * "section-i[/subsection-j[/subsubsection-k]]" (1-based,
          case-insensitive): sets the "content" of that node.

        Returns:
            True if the survey was updated, False for any malformed, unknown
            or out-of-range target. This method does not raise.
        """
        try:
            update_pos, content = answer["update"], answer["content"]
            if update_pos == "plan":
                if current_survey != {}:
                    return False
                for key, value in content.items():
                    current_survey[key] = copy.deepcopy(value)
                return True
            if update_pos in ["conclusion", "abstract"]:
                if update_pos not in current_survey:
                    return False
                current_survey[update_pos] = content
                return True
            if update_pos == "introduction":
                if update_pos not in current_survey:
                    return False
                current_survey[update_pos] = {"content": content}
                return True

            keys = update_pos.split("/")
            if len(keys) > len(SurveyManager._LEVEL_PREFIXES):
                return False
            indices = [
                int(key.lower().split(prefix)[-1]) - 1
                for key, prefix in zip(keys, SurveyManager._LEVEL_PREFIXES)
            ]
            # Reject zero/negative positions up front: a negative list index
            # would silently address a node counted from the end.
            if any(index < 0 for index in indices):
                return False
            siblings = current_survey["sections"]
            for index, child_key in zip(indices[:-1], SurveyManager._CHILD_KEYS):
                siblings = siblings[index][child_key]
            return SurveyManager.update_one_section(siblings, indices[-1], content)
        except Exception:
            # Answers come from model output; any malformed shape counts as a failed update.
            return False
from prompts import *
# Namespace for the prompt templates read by BufferManager's prompt builders.
# The three constants come from the star import of the `prompts` module.
class PromptManger:
# System message used for every generator call.
system_prompt = SYSTEM_PROMPT_0415_BUFFER
# User template for the first turn (placeholders <user_query>, <init_survey>).
user_prompt_v0 = USER_PROMPT_v0_0424_BUFFER
# User template for later turns (placeholders <user_query>, <current_survey>, ...).
user_prompt = USER_PROMPT_0415_BUFFER
class BufferManager:
"""
Used to manage prompts/responses generated during the Rollout phase, providing data support for subsequent training.
batch_rollout_data = [
{
query (or env_id): # Uniquely identifies a query or environment, [input parameter].
*running_id: # Uniquely identifies a single rollout. For cases where a query or environment is repeated multiple times, the query can be the same, but running_id will not repeat.
state: { # Indicates whether the process is finished.
"score": 0.0,
"done": True / False
"current_survey": dict # Structured data.
}
trajectory: [ # Organizes all data into a multi-turn interaction format.
{
step: int, 0~?, # The first step, usually includes some init_info or plan.
original_response: str, The raw output from the model, which may have various formatting issues.
answer_thought: str, # Encapsulated using the <think>...</think> block.
answer: {
"original_str": str
"update": str,
"name": str,
"content": str,
"inclusions": list, # Extracted independently?
}
tool_call_thought: str, # Encapsulated using the <think>...</think> block.
tool_call: {
"original_str": str, # Encapsulated using the <tool_call>...</tool_call> block, used for tool invocation. In the survey setting, it is either "done" to end the task or "search".
"tool_name": str # done or search.
"keywords": list[str], Extracted search keywords from tool_call, otherwise none.
}
*papers: list[str], # Top-n papers retrieved via the search engine. Required if using the Agent-Summary-1 for collaborative optimization; otherwise, not needed.
cites: list[str], # References cited by the model, which may include multiple citations.
summarys: list[str], # Summaries of papers generated using Agent-Summary-1. Must include BIBKEY.
*prompt_for_generator: str, # The prompt input to the generator at the current step. Required if using Agent-Summary-2 for generation and collaborative optimization; otherwise, not needed.
},
...
]
},
...
]
"""
def __init__(self, prompts, repeat_n: int=1):
# self.config = config
self.step = 0
self.batch_rollout_data = []
self.running_ids = [] # active envs
batch_size = prompts.batch['input_ids'].size(0)
uids = prompts.non_tensor_batch['uid']
querys = prompts.non_tensor_batch['raw_prompt'].copy()
ground_truths = prompts.non_tensor_batch['ground_truth']
# print(querys)
new_querys = []
for i_batch in range(batch_size):
raw_prompt_i_batch = querys[i_batch][-1]["content"]
new_querys.append(raw_prompt_i_batch)
querys = new_querys
assert len(querys) == len(uids)
for query, uid, ground_truth in zip(querys, uids, ground_truths):
now_survey = {}
for _ in range(repeat_n):
self.batch_rollout_data.append({
"query": query,
"uid": uid,
"state": {
# "score": 0.0, # only for debug
# "format_score": None, # will update at last step
"done": False,
"current_survey": {}
},
"trajectory": [],
"history_messages": [],
})
@staticmethod
def _build_system_prompt():
    """Return the system prompt template held by PromptManger."""
    return PromptManger.system_prompt
@staticmethod
def _build_user_prompt_v0(query, current_survey):
    """Build the first-turn user prompt.

    Substitutes ``<user_query>`` with the query and ``<init_survey>`` with
    the abbreviated outline of ``current_survey``.
    """
    with_query = PromptManger.user_prompt_v0.replace("<user_query>", query)
    outline = SurveyManager.convert_survey_dict_to_abbr_str(current_survey)
    return with_query.replace("<init_survey>", outline)
@staticmethod
def _build_user_prompt(query, current_survey, trajs):
# Build the follow-up user prompt from PromptManger.user_prompt.
# Placeholders filled: <user_query>, <current_survey>, <last_step_thought>,
# <last_step_tool_call> and <summarys>.
# trajs must be non-empty: trajs[-1] is read for the last thought / tool call.
last_traj = trajs[-1]
# query
prompt = PromptManger.user_prompt.replace("<user_query>", query)
# add current survey
prompt = prompt.replace("<current_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))
# current plan
# An empty last thought is replaced by a fixed request for a new plan.
if last_traj["tool_call_thought"] == "":
prompt = prompt.replace("<last_step_thought>", "Your last thought is not available, please give new plan")
else:
prompt = prompt.replace("<last_step_thought>", last_traj["tool_call_thought"])
# NOTE(review): indentation is lost in this view, so it is unclear whether the
# next replace runs for both branches or only in the else branch; confirm.
prompt = prompt.replace("<last_step_tool_call>", json.dumps(last_traj["tool_call"]))
# summarys
# Walk backwards to the most recent step that has summaries. If no step has
# any, `traj` is left pointing at the earliest step (with zero summaries).
for traj in reversed(trajs):
if len(traj["summarys"]) > 0:
break
summary_num = len(traj["summarys"])
if summary_num == 0:
prompt = prompt.replace("<summarys>", "There is no result.")
else:
prompt = prompt.replace("<summarys>", f"There are {summary_num} results:\n\n" + "\n\n".join(traj["summarys"]))
return prompt
@staticmethod
def _build_user_prompt_force_correct(query, current_survey, trajs):
# Build the follow-up prompt after a failed update: overwrite
# trajs[-1]["tool_call_thought"] (in place) with a hint naming the next unit
# to write, then delegate to _build_user_prompt.
# The next unit is "plan" for an empty survey, otherwise the first of
# abstract / introduction / section / subsection / subsubsection found
# without a "content" key.
if current_survey == {}:
# gen plan
now_section = "plan"
# trajs[-1]["tool_call_thought"] = "Next I will provide the plan. "
else:
now_section = ""
# The abstract only counts as missing when it is a dict without "content".
if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]:
now_section = "abstract"
elif "content" not in current_survey["introduction"]:
now_section = "introduction"
elif "sections" in current_survey:
for section in current_survey["sections"]:
if "content" not in section:
# Positions in the hint are 1-based.
now_section = "section-{}".format(current_survey["sections"].index(section) + 1)
break
elif "subsections" in section:
for subsection in section["subsections"]:
if "content" not in subsection:
now_section = "section-{}/subsection-{}".format(
current_survey["sections"].index(section) + 1,
section["subsections"].index(subsection) + 1
)
break
elif "subsubsections" in subsection:
for subsubsection in subsection["subsubsections"]:
if "content" not in subsubsection:
now_section = "section-{}/subsection-{}/subsubsection-{}".format(
current_survey["sections"].index(section) + 1,
section["subsections"].index(subsection) + 1,
subsection["subsubsections"].index(subsubsection) + 1
)
break
# Propagate the inner break outwards once a target has been found.
if now_section:
break
if now_section:
break
# NOTE(review): indentation is lost in this view. If this elif/else continues
# the top-level chain above, neither branch is reached whenever "sections"
# is present; confirm the intended nesting.
elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]:
now_section = "conclusion"
else:
trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."
if now_section != "":
trajs[-1]["tool_call_thought"] = f"Next I will provide {now_section}"
# Same backwards scan as in _build_user_prompt: most recent step with summaries.
for traj in reversed(trajs):
if len(traj["summarys"]) > 0:
break
summary_num = len(traj["summarys"])
# With no plan and no search results yet, ask for more information instead.
if now_section == "plan" and summary_num == 0:
trajs[-1]["tool_call_thought"] = "I need to get enough information."
return BufferManager._build_user_prompt(query, current_survey, trajs)
@staticmethod
def _check_finalize(query, current_survey, trajs):
# Return True when no unit of the survey is found lacking a "content" key,
# False for an empty survey or as soon as one such unit is found.
# The scan matches _build_user_prompt_force_correct; `query` and `trajs`
# are not read here.
if current_survey == {}:
# gen plan
return False
# trajs[-1]["tool_call_thought"] = "Next I will provide the plan. "
else:
now_section = ""
# The abstract only counts as missing when it is a dict without "content".
if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]:
now_section = "abstract"
elif "content" not in current_survey["introduction"]:
now_section = "introduction"
elif "sections" in current_survey:
for section in current_survey["sections"]:
if "content" not in section:
now_section = "section-{}".format(current_survey["sections"].index(section) + 1)
break
elif "subsections" in section:
for subsection in section["subsections"]:
if "content" not in subsection:
now_section = "section-{}/subsection-{}".format(
current_survey["sections"].index(section) + 1,
section["subsections"].index(subsection) + 1
)
break
elif "subsubsections" in subsection:
for subsubsection in subsection["subsubsections"]:
if "content" not in subsubsection:
now_section = "section-{}/subsection-{}/subsubsection-{}".format(
current_survey["sections"].index(section) + 1,
section["subsections"].index(subsection) + 1,
subsection["subsubsections"].index(subsubsection) + 1
)
break
# Propagate the inner break outwards once a target has been found.
if now_section:
break
if now_section:
break
# NOTE(review): indentation is lost in this view. If this elif continues the
# top-level chain above, the conclusion is never checked whenever "sections"
# is present; confirm the intended nesting.
elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]:
now_section = "conclusion"
# else:
# trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."
if now_section != "":
return False
return True
# rule-based method: query, plan, paragraphs -> prompt -> thought, paragraph, action
def build_prompt_for_generator(self):
    """Build the next [system, user] message pair for every unfinished rollout.

    Resets ``self.running_ids`` to the indices of the rollouts that received
    a prompt, appends each pair to that rollout's "history_messages", and
    returns the pairs in the same order as ``self.running_ids``.
    """
    total_messages = []
    self.running_ids = []
    for running_id, data in enumerate(self.batch_rollout_data):
        if data["state"]["done"]:
            continue

        if len(data["trajectory"]) == 0:
            # First prompt of this rollout.
            builder_args = (data["query"], data["state"]["current_survey"])
            user_prompt = BufferManager._build_user_prompt_v0(*builder_args)
        else:
            # After a failed update, steer the model towards the next missing unit.
            if data["trajectory"][-1]["update_success"]:
                builder = BufferManager._build_user_prompt
            else:
                builder = BufferManager._build_user_prompt_force_correct
            user_prompt = builder(
                data["query"],
                data["state"]["current_survey"],
                data["trajectory"],
            )

        system_message = {
            "role": "system",
            "content": BufferManager._build_system_prompt(),
        }
        user_message = {
            "role": "user",
            "content": user_prompt,
        }
        messages = [system_message, user_message]
        data["history_messages"].append(messages)
        total_messages.append(messages)
        self.running_ids.append(running_id)  # update running ids
    return total_messages
def update_all_scores(self, scores):
    """Store one score per rollout under ``state["score"]``.

    ``scores`` must have exactly one entry per item in
    ``self.batch_rollout_data`` (checked with an assertion).
    """
    rollouts = self.batch_rollout_data
    assert len(scores) == len(rollouts)
    for position, value in enumerate(scores):
        rollouts[position]["state"]["score"] = value
def update_all_format_scores(self, scores):
    """Store one format score per rollout under ``state["format_score"]``.

    ``scores`` must have exactly one entry per item in
    ``self.batch_rollout_data`` (checked with an assertion).
    """
    rollouts = self.batch_rollout_data
    assert len(scores) == len(rollouts)
    for position, value in enumerate(scores):
        rollouts[position]["state"]["format_score"] = value
def update_trajectory(self, model_responses, env_feedbacks):
    """
    Record one generation step for every active rollout.

    model_responses: parsed responses (original_response, answer_thought, answer,
        tool_call_thought, tool_call, true), aligned with ``self.running_ids``
    env_feedbacks: environment replies (done, search_keywords, summarys), same order

    For each rollout this applies the answer to the current survey when
    allowed, appends a trajectory entry and the assistant message, and
    un-sets ``done`` again when the survey is not really complete.
    """
    assert len(self.running_ids) == len(model_responses)
    assert len(self.running_ids) == len(env_feedbacks)
    no_result_marker = "There is no result"
    for running_id, response, feedback in zip(self.running_ids, model_responses, env_feedbacks):
        rollout = self.batch_rollout_data[running_id]
        state = rollout["state"]
        # update state
        state["done"] = feedback["done"]  # if True, finalize the task
        update_success = False
        if response["true"]:
            if state["current_survey"] != {}:
                # A plan exists: only a non-empty answer can update the survey.
                if len(response["answer"]) != 0:
                    update_success = SurveyManager.update_current_survey(
                        state["current_survey"], response["answer"])
            else:
                # Search Then Write
                answer_given = len(response["answer"]) != 0
                nothing_retrieved = no_result_marker in rollout["history_messages"][-1][1]["content"]
                if answer_given and not nothing_retrieved:
                    update_success = SurveyManager.update_current_survey(
                        state["current_survey"], response["answer"])
                elif nothing_retrieved and not answer_given:
                    # Skipping the answer is the accepted move when nothing was retrieved.
                    update_success = True
        step_record = {
            "step": self.step,
            "original_response": response["original_response"],
            "answer_thought": response["answer_thought"],
            "answer": response["answer"],
            "tool_call_thought": response["tool_call_thought"],
            "tool_call": response["tool_call"],
            "search_keywords": feedback["search_keywords"],
            "summarys": feedback["summarys"],
            "update_success": update_success and response["true"],
        }
        rollout["trajectory"].append(step_record)
        rollout["history_messages"][-1].append({
            "role": "assistant",
            "content": response["original_response"],
        })
        if state["done"]:
            # The environment reported "done"; accept it only if the survey is complete.
            real_done = BufferManager._check_finalize(rollout["query"],
                                                      state["current_survey"],
                                                      rollout["trajectory"])
            if not real_done:
                state["done"] = False
@staticmethod
def match_reference(text:str):
    """Return the bibliography keys that ``text`` cites.

    Keys are collected from every ``\\...cite...{k1,k2}`` command (but not
    ``\\citestyle``) and then every key listed in a ``\\nocite{...}`` command
    is dropped. Empty keys, ``*`` and macro placeholders such as ``#1`` are
    ignored.

    Args:
        text: LaTeX-style text to scan.

    Returns:
        A list of unique keys; the order is unspecified.
    """
    placeholder_reg = re.compile(r"^#\d+$")

    def iter_keys(pattern):
        # Yield every usable key named by the commands matching ``pattern``.
        for group in re.findall(pattern, text):
            for key in group.split(","):
                # Strip before testing so " #2" is recognised as a placeholder.
                key = key.strip()
                if key and key != "*" and not placeholder_reg.match(key):
                    yield key

    bibkeys = set(iter_keys(r"\\\w*cite(?!style)\w*\{(.+?)\}"))
    for key in iter_keys(r"\\nocite{(.+?)\}"):
        # discard, not remove: the same key may appear in several \nocite
        # commands, and a second remove() would raise KeyError.
        bibkeys.discard(key)
    return list(bibkeys)
@staticmethod
def parse_generator_response(response):
    """Parse one raw model response into its answer and tool-call parts.

    Expected layout::

        Current Update:
        <think> thoughts </think>
        <answer> {"update": str, "content": str or dict} </answer>
        Next Plan:
        <think> thoughts </think>
        <tool_call> {"name": "search_engine" | "finalize", "arguments": {...}} </tool_call>

    An answer of ``{}`` is valid (the model skips the update). When
    ``update`` resolves to ``"plan"`` the content must be a full outline:
    exactly the keys title/abstract/introduction/sections/conclusion
    (an optional "instruction" key is removed), every section, subsection
    and subsubsection carrying string "plan" and "title" fields.

    Returns:
        dict with ``original_response``, ``answer_thought``, ``answer``
        (``{}`` when missing or invalid), ``tool_call_thought``,
        ``tool_call`` (``{}`` when missing or invalid) and ``true``, which
        is True only when both parts parsed and the response contains no
        Chinese characters.
    """
    think_pattern = r"<think>(.*?)</think>"
    answer_pattern = r"<answer>(.*?)</answer>"
    tool_pattern = r"<tool_call>(.*?)</tool_call>"

    def first_group(pattern, text):
        # Stripped content of the first tag match (multi-line), or None.
        match = re.search(pattern, text, re.DOTALL)
        return match.group(1).strip() if match else None

    def require(condition):
        # Explicit raise instead of assert so validation survives ``python -O``.
        if not condition:
            raise ValueError("malformed model response")

    def check_outline_node(node):
        # Every outline node needs string "plan" and "title" fields.
        require(isinstance(node, dict))
        require("plan" in node)
        require("title" in node)
        require(isinstance(node["plan"], str))
        require(isinstance(node["title"], str))

    def check_section(section):
        check_outline_node(section)
        require(section["title"] != "Methodology")  # a "Methodology" section is not allowed
        if "subsections" in section:
            require(isinstance(section["subsections"], list))
            for subsection in section["subsections"]:
                check_outline_node(subsection)
                # Look at the subsection itself: the original tested ``section``
                # here, so nested subsubsections were never validated.
                if "subsubsections" in subsection:
                    require(isinstance(subsection["subsubsections"], list))
                    for subsubsection in subsection["subsubsections"]:
                        check_outline_node(subsubsection)

    def check_plan(plan):
        require(isinstance(plan, dict))
        plan.pop("instruction", None)
        keys = ["abstract", "introduction", "conclusion", "sections", "title"]
        for key in keys:
            require(key in plan)
        for key in plan:
            require(key in keys)
            if key == "sections":
                require(isinstance(plan[key], list))
                for section in plan[key]:
                    check_section(section)
            elif key == "title":
                require(isinstance(plan[key], str))
            else:
                # abstract / introduction / conclusion
                require(isinstance(plan[key], dict))
                require("plan" in plan[key])

    extracted_result = {
        "original_response": response
    }

    # extract information from current_update
    current_update = response.split("Current Update:")[-1].split("Next Plan:")[0]
    extracted_result["answer_thought"] = first_group(think_pattern, current_update) or ""

    answer = {}
    has_answer = False
    answer_text = first_group(answer_pattern, current_update)
    if answer_text is not None:
        try:
            answer = json.loads(answer_text)
            if answer != {}:
                require(isinstance(answer["update"], str))
                answer["update"] = SurveyManager.parse_update_pos(answer["update"])
                if answer["update"] == "plan":
                    check_plan(answer["content"])
                else:
                    require(isinstance(answer["content"], str))
            has_answer = True
        except Exception:
            # Any parse or validation failure invalidates the whole answer.
            answer = {}
    extracted_result["answer"] = answer

    # extract information from next_plan
    if "Next Plan:" in response:
        next_plan = response.split("Next Plan:")[1]
    elif "</answer>" in response:
        next_plan = response.split("</answer>")[1]
    else:
        next_plan = response
    extracted_result["tool_call_thought"] = first_group(think_pattern, next_plan) or ""

    tool_call = {}
    has_tool_call = False
    tool_text = first_group(tool_pattern, next_plan)
    if tool_text is not None:
        try:
            tool_call = json.loads(tool_text)
            require(tool_call["name"] in ["search_engine", "finalize"])
            if tool_call["name"] == "search_engine":
                require(isinstance(tool_call["arguments"]["query"], list))
            has_tool_call = True
        except Exception:
            tool_call = {}
    extracted_result["tool_call"] = tool_call

    # Responses containing Chinese characters are rejected outright.
    has_chinese = re.search(r"[\u4e00-\u9fa5]", response) is not None
    extracted_result["true"] = has_answer and has_tool_call and not has_chinese
    return extracted_result
class BufferManager_V2(BufferManager):
    """Buffer manager that is built from a plain list of queries.

    The base-class initializer is not called; this constructor sets up the
    step counter, the active-rollout index list and one rollout record per
    (query, repetition) pair itself.
    """

    def __init__(self, querys, repeat_n=1):
        self.step = 0
        self.running_ids = []  # active envs
        self.batch_rollout_data = []
        for uid, query in enumerate(querys):
            print("CURRENT QUERY: ", query)
            # Every repetition gets its own independent record.
            self.batch_rollout_data.extend(
                {
                    "query": query,
                    "uid": f"query_{uid}",
                    "state": {
                        "done": False,
                        "current_survey": {},
                    },
                    "trajectory": [],
                    "history_messages": [],
                }
                for _ in range(repeat_n)
            )

View File

@@ -0,0 +1,81 @@
# System prompt describing the two-part response protocol: a "Current Update"
# (<think> + <answer>) followed by a "Next Plan" (<think> + <tool_call>).
# The text is sent to the model verbatim, so it is deliberately left untouched.
# NOTE(review): the two <tool_call> examples near the end are missing a closing
# brace and are not valid JSON, and the text contains typos ("refered",
# "subsction", "reasearch", "corrent"). Presumably the model was tuned on this
# exact wording -- confirm before correcting any of it.
SYSTEM_PROMPT_0415_BUFFER = """You are a survey writer. You are asked to write a survey follow the instruction, refered as "Query" or "User's Query". You will finish the survey by multi-step updating.
Usually, you need to do two things:
(1) First, you need to update the survey using the retrieved information according to the current plan, refered as "Current Update". You MUST think inside <think>...</think> before you give your <answer>...</answer> action, mainly about "How to write paragraphs with citations based on retrieved information to complete the current plan?". If the current plan is None, or you think the current plan is not good, or you think the retrieved information is not enough for you to finish the plan, you can jump the "Answer" action by giving "{}" as answer. Please give the citation in \\cite{}.
(2) Then, you need decide what part of the survey needs to be updated, refered as "Next Plan". You MUST think inside <think>...</think> before you give your <tool_call>...</tool_call> action. If you think the current retrieved information is enough to finish your next plan, you can jump the "Tool Call" action by giving "{}" as tool call.
## Answer
You can give one answer to update the survey.
<answer>
{"update": <section-pos>, "content": paragraph }
</answer>
There are two parameters in <answer> action.
* update: string, which position you want to update, such as "title", "abstract", "introduction", "section-1", "section-1/subsction-1", "section-1/subsction-1/subsection-1", and "conclusion".
* content: string, the update content for the position of the survey, please give the faithful citation in \\cite{}. . Or dict, only when you give the plan of the paper, the values including the section title and a simple plan of it.
## Tool Call
You can call one function to assist the survey writing.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "search_engine", "description": "Search reasearch papers.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "description": "The words to search for in quotes."}, "required": ["query"]}}}}
{"type": "function", "function": {"name": "finalize", "description": "Finalize the survey.", "parameters": {}, "required": []}}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call> {"name": <function-name>, "arguments": <args-json-object>} </tool_call>
For example, You can call the search engine by using:
<tool_call> {"name": "search_engine", "arguments": {"query": ["keyword-1", "keyword-2", ...]} </tool_call>
If you think the survey is finished, please call:
<tool_call> {"name": "finalize", "arguments": {} </tool_call>
** Attention **
You must use correct JSON format inside <answer>...</answer> and <tool_call>...</tool_call>, otherwise we can't extract the corrent content.
**Output format**
(1) Current Update:
<think> How to write paragraphs with citations based on retrieved information to complete the plan? </think>
<answer> Please provide your answer here. (JSON format) </answer>
(2) Next Plan:
<think> Which part of the survey needs to be updated? What information needs to be queried? </think>
<tool_call> Please call a tool here. (JSON format) </tool_call>
"""
# First-turn user prompt template: no plan and no retrieved information yet.
# <user_query> and <init_survey> look like placeholders -- presumably filled in
# by the prompt builder, which is not visible here (TODO confirm).
# NOTE(review): the buffer manager's update_trajectory tests the user prompt for
# the substring "There is no result"; the "There is no results." line below
# satisfies that test, so rewording it would change runtime behaviour.
USER_PROMPT_v0_0424_BUFFER = """Please update the survey depending on the insturctions.
**User's Query**
<user_query>
**Current Survey**
<init_survey>
**Current Plan**
<think>I need to get enough information.</think>
<tool_call>{}</tool_call>
**Retrieved Information**
There is no results.
Please give your response following the output format.
"""
# Follow-up user prompt template used after the first turn.
# <user_query>, <current_survey>, <last_step_thought>, <last_step_tool_call> and
# <summarys> look like placeholders -- presumably substituted by the prompt
# builder, which is not visible here (TODO confirm).
USER_PROMPT_0415_BUFFER = """Please update the survey depending on the insturctions.
**User's Query**
<user_query>
**Current Survey**
<current_survey>
**Current Plan**
<think> <last_step_thought> </think>
<tool_call> <last_step_tool_call> </tool_call>
**Retrieved Information**
<summarys>
Please give your response following the output format.
"""

338
code/src/generation/run.py Normal file
View File

@@ -0,0 +1,338 @@
from contextlib import contextmanager
from codetiming import Timer
@contextmanager
def _timer(name: str, timing_raw):
    """Time the enclosed block and store the result in ``timing_raw[name]``.

    The value stored is ``Timer.last`` as reported once the timer has
    stopped. Nothing is stored if the enclosed block raises.
    """
    stopwatch = Timer(name=name, logger=None)
    with stopwatch:
        yield
    # ``last`` is read after the ``with`` block, i.e. after the timer stopped.
    timing_raw[name] = stopwatch.last
from buffer import SurveyManager
from buffer import BufferManager_V2 as BufferManager
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import re
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import argparse
from pydantic import BaseModel
import json
import aiohttp
# FastAPI application serving the survey-generation endpoint and the progress WebSocket.
app = FastAPI()
# Allow cross-origin requests (needed when the frontend and the backend run on different ports).
# NOTE(review): a wildcard origin combined with allow_credentials=True is very permissive --
# confirm this service is only reachable from trusted networks.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# WebSocket connections that currently receive progress broadcasts.
active_connections = set()
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Register a frontend WebSocket for progress broadcasts until it closes.

    Incoming messages are read and ignored; receiving only serves to notice
    when the client disconnects.
    """
    await websocket.accept()
    active_connections.add(websocket)
    try:
        while True:
            await websocket.receive_text()
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens below.
        pass
    finally:
        # Always unregister, whatever ended the loop. discard() tolerates the
        # socket having been dropped already by the broadcaster.
        active_connections.discard(websocket)
async def post_to_frontend(payload):
    """Broadcast ``payload`` (a text message) to every connected frontend socket.

    Sockets that fail to receive the message are dropped from
    ``active_connections``; the remaining sockets are still served.
    """
    print(f"Sending payload to frontend: {payload}")  # Log the payload being sent
    # Iterate over a snapshot: the set may change while we await.
    for ws in list(active_connections):
        try:
            await ws.send_text(payload)
        except Exception as e:
            print(f"Error sending to WebSocket: {e}")
            # discard, not remove: the endpoint handler may already have
            # unregistered this socket while the send was pending, and a
            # KeyError here would abort the broadcast to the other sockets.
            active_connections.discard(ws)
def write_to_json(data, path):
    """Write ``data`` to ``path`` as UTF-8 JSON, indented by four spaces.

    Non-ASCII characters are written as-is rather than escaped.
    """
    with open(path, mode='w', encoding='utf8') as handle:
        serialized = json.dumps(data, ensure_ascii=False, indent=4)
        handle.write(serialized)
class OriginalvLLMRollout:
    """Thin wrapper around a vLLM engine with fixed sampling parameters."""

    def __init__(self, model_name_or_path):
        # init vLLM
        sampling_options = {
            "temperature": 0.7,
            "top_p": 0.8,
            "repetition_penalty": 1.05,
            "top_k": 20,
            "max_tokens": 2748,
        }
        self.rollout_model = LLM(
            model=model_name_or_path,
            tokenizer=model_name_or_path,
            gpu_memory_utilization=0.95,
            trust_remote_code=True,
        )
        self.sampling_params = SamplingParams(**sampling_options)

    def generate(self, input_texts):
        """Return the first completion text for each raw prompt."""
        completions = self.rollout_model.generate(input_texts, self.sampling_params, use_tqdm=False)
        return [completion.outputs[0].text for completion in completions]

    def chat(self, input_messages):
        """Return the first completion text for each chat conversation."""
        completions = self.rollout_model.chat(input_messages, self.sampling_params, use_tqdm=False)
        return [completion.outputs[0].text for completion in completions]
def _build_frontend_payload(rollout):
    """Return the JSON progress message for one rollout, or None.

    None is returned when the latest step did not update the survey, in
    which case nothing is pushed to the frontend.
    """
    trajs = rollout["trajectory"]
    last_step = trajs[-1]
    if not last_step["update_success"]:
        return None
    # Show the most recent non-empty batch of retrieved summaries.
    summarys = next((t["summarys"] for t in reversed(trajs) if len(t["summarys"]) > 0), [])
    summary_text = "\n".join(summarys) if summarys else "No summaries yet."
    frontend_payload = {
        "markdown": json_to_markdown(rollout),
        "searchKeywords": last_step["search_keywords"],
        "nowUpdate": last_step["answer_thought"],
        "nextUpdate": last_step["tool_call_thought"],
        "query": rollout["query"],
        "papers": summary_text
    }
    return json.dumps(frontend_payload, ensure_ascii=False)


async def rollout_with_env(querys, batch_size, max_turns, model_path, url,
                           deploy_port=None):
    """Run the generate / retrieve / update loop for every query.

    Args:
        querys: [string]
        batch_size: number of queries rolled out together.
        max_turns: maximum number of generation steps per batch.
        model_path: model name or path handed to vLLM.
        url: retriever endpoint that receives the tool calls via POST.
        deploy_port: when not None, progress of the last rollout in the
            batch is pushed to the connected frontends after every step.

    Returns:
        The rollout records of all batches, each extended with "survey_text".
    """
    ###############################
    #### splited by batch size ####
    ###############################
    batch_querys = [querys[start:start + batch_size]
                    for start in range(0, len(querys), batch_size)]
    print("QUERY NUMBER with BATCH: ", [len(x) for x in batch_querys])
    ###################
    #### init vllm ####
    ###################
    vllm_manager = OriginalvLLMRollout(model_path)
    total_rollout_data = []
    for batch in batch_querys:
        ###########################################
        #### acquire env configs and init envs ####
        ###########################################
        buffer_manager = BufferManager(batch)
        # Break at max-turns
        while buffer_manager.step < max_turns:
            ###############################
            #### prepare input prompts ####
            ###############################
            messagess_todo = buffer_manager.build_prompt_for_generator()
            # Break when no tasks
            if len(messagess_todo) == 0:
                break
            ##########################
            #### generate by vLLM ####
            ##########################
            timing_raw = {}
            with _timer('vllm sampling', timing_raw):
                # Run the blocking vLLM call in a worker thread so the event loop stays responsive.
                response_texts = await asyncio.to_thread(vllm_manager.chat, messagess_todo)
            ##################################
            #### preprocess the responses ####
            ##################################
            extracted_results = [BufferManager.parse_generator_response(text)
                                 for text in response_texts]
            #################################################
            #### execute in environment and get feedback ####
            #################################################
            payload = {
                "tool_calls": [x["tool_call"] for x in extracted_results]
            }
            if buffer_manager.step <= 2:
                # Ask for more results during the first steps.
                payload["topk"] = 20
            with _timer('get env feedback', timing_raw):
                async with aiohttp.ClientSession() as session:
                    async with session.post(url, json=payload) as resp:
                        # Fail loudly on an HTTP error instead of parsing an error body.
                        resp.raise_for_status()
                        env_response_batched = await resp.json()
            ###################################
            #### postprocess the feedbacks ####
            ###################################
            with _timer('postprocessing', timing_raw):
                buffer_manager.update_trajectory(extracted_results, env_response_batched)
            buffer_manager.step += 1
            print(timing_raw)
            if deploy_port is not None:
                # The live preview is best-effort: a rendering or delivery
                # error must not abort the generation itself.
                try:
                    frontend_payload = _build_frontend_payload(buffer_manager.batch_rollout_data[-1])
                    if frontend_payload is not None:
                        await post_to_frontend(frontend_payload)
                except Exception as e:
                    print(f"Error posting to frontend: {e}")
        for item in buffer_manager.batch_rollout_data:
            item["survey_text"] = SurveyManager.convert_survey_dict_to_str(item["state"]["current_survey"])
        total_rollout_data.extend(buffer_manager.batch_rollout_data)
    return total_rollout_data
def json_to_markdown(json_data):
    """Render one rollout record as Markdown with numbered references.

    Every ``\\cite{k1,k2}`` in the survey text is replaced by ``[n,m]``,
    numbering keys in order of first appearance, and a "References" section
    is appended. Keys with a retrieved summary become links to the arXiv
    abstract page; keys without one are listed by their bare key.

    Args:
        json_data: rollout record with "state"/"current_survey" and
            "trajectory" (each step carrying a "summarys" list).

    Returns:
        The Markdown text.
    """
    text = SurveyManager.convert_survey_dict_to_str(json_data["state"]["current_survey"])

    # bibkey -> markdown link, built from every summary retrieved so far.
    all_summarys = {}
    for traj in json_data["trajectory"]:
        for item in traj["summarys"]:
            first_line = item.split("\n", 1)[0]
            _, colon, bibkey = first_line.partition(":")
            bibkey = bibkey.strip()
            if not colon or not bibkey:
                # Not in the expected "bibkey: <key>" form; skip it.
                continue
            title_begin_index = item.find("Title:")
            if title_begin_index == -1:
                title = bibkey
            else:
                title_end_index = item.find("Abstract:", title_begin_index)
                if title_end_index == -1:
                    title_end_index = len(item)
                title = item[title_begin_index + len("Title:"):title_end_index].strip()
            arxivid = bibkey.split("arxivid")[-1].strip()
            html = f"arxiv.org/abs/{arxivid}"
            all_summarys[bibkey] = f"[{title}](https://{html})"

    reg = r"\\cite\{(.+?)\}"
    placeholder_reg = re.compile(r"^#\d+$")

    def split_keys(group):
        # Split one \cite{...} argument into usable keys.
        keys = []
        for bib in group.split(","):
            bib = bib.strip()  # strip first so " #2" is recognised as a placeholder
            if bib and bib != "*" and not placeholder_reg.match(bib):
                keys.append(bib)
        return keys

    # Number the keys in order of first appearance.
    bibkeys_index = {}
    for group in re.findall(reg, text):
        for bib in split_keys(group):
            bibkeys_index.setdefault(bib, len(bibkeys_index) + 1)

    def replace_bibkey(match):
        numbers = [str(bibkeys_index[bib]) for bib in split_keys(match.group(1))]
        if numbers:
            return "[" + ",".join(numbers) + "]"
        return ""

    text = re.sub(reg, replace_bibkey, text)

    reference_lines = []
    for bibkey, i in bibkeys_index.items():
        if bibkey not in all_summarys:
            # The model cited a key that was never retrieved; keep the entry
            # instead of failing with KeyError.
            print(f"Warning: {bibkey} not found in retrieved summaries")
        reference_lines.append(f"[{i}] {all_summarys.get(bibkey, bibkey)}")
    reference_text = "\n\n".join(reference_lines)
    text += "\n## References\n" + reference_text
    return text
async def test_surveyGen(model_path, out_path,querys, url, deploy_port=None):
    """Generate a survey for each query and write them to one Markdown file.

    Queries are rolled out one per batch with a limit of 1000 turns; the
    rendered surveys are joined with a blank line and written to ``out_path``.
    """
    rollouts = await rollout_with_env(querys, 1, 1000, model_path, url, deploy_port)
    document = "\n\n".join(json_to_markdown(rollout) for rollout in rollouts)
    with open(out_path, 'w', encoding='utf8') as f:
        f.write(document)
class QueryRequest(BaseModel):
    """Request body of the survey-generation endpoint."""

    # Topic or instruction the survey should be written for.
    query: str
@app.post("/generate_survey")
async def generate_survey(request: QueryRequest):
    """Generate a survey for the posted query using the command-line settings.

    Returns a status dict; failures are reported in the body with status
    "error" instead of being raised.
    """
    global args  # Ensure args is accessible
    # Read the settings before the try block so a missing ``args`` is not
    # reported as a generation error.
    model_path = args.model_path
    out_path = args.output_file
    url = args.retriver_url
    deploy_port = args.port
    querys = [request.query]  # the rollout expects a list of queries
    try:
        await test_surveyGen(model_path, out_path, querys, url, deploy_port)
    except Exception as e:
        print(f"Error generating survey: {e}")
        return {"status": "error", "message": str(e)}
    return {"status": "success", "message": "Survey generated successfully."}
if __name__ == "__main__":
    # Entry point: with --port the process serves the HTTP/WebSocket API and
    # takes queries from requests; without it a single survey is generated
    # for --query and the process exits.
    parser = argparse.ArgumentParser(description="Run survey generation with vLLM.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the model.")
    # Only needed for a one-off run: in server mode the query comes from the request.
    parser.add_argument("--query", type=str, default=None, help="Query to generate survey (required unless --port is given).")
    parser.add_argument("--output_file", type=str, required=True, help="Path to the output Markdown file.")
    parser.add_argument("--retriver_url", type=str, default="http://localhost:8400", help="URL of the retriever service.")
    parser.add_argument("--port", type=str, default=None, help="Deploy port, default is None, which means not deploy.")
    # ``args`` stays a module-level name: the /generate_survey handler reads it.
    args = parser.parse_args()
    if args.port is not None:
        import uvicorn
        uvicorn.run(app, host="localhost", port=int(args.port))
    else:
        if args.query is None:
            parser.error("--query is required when --port is not given")
        # Run the survey generation
        asyncio.run(
            test_surveyGen(
                model_path=args.model_path,
                out_path=args.output_file,
                querys=[args.query],
                url=args.retriver_url
            )
        )

View File

@@ -0,0 +1,55 @@
import torch.distributed
import faiss
import pandas as pd
import faiss
import numpy as np
import jsonlines, json
from transformers import AutoModel
import os
import torch
'''
data format:
{
"bibkey": "some_bibkey",
"text": "The abstract or text of the paper."
}
example:
{
"bibkey": "arxivid1234.5678",
"text": "Title: A Study on Something\nAbstract: This paper discusses the findings of a study on something important in the field of research.\nAuthors: John Doe"
}
'''
# Build the dense retrieval index: embed every passage, add it to an
# inner-product FAISS index under an integer id, and save both the index and
# the bibkey <-> integer id mapping.
model_name = "openbmb/MiniCPM-Embedding-Light"
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
input_path = "./data/arxiv.jsonl"
with jsonlines.open(input_path) as f:
    survey_data = list(f)
xids = [item["bibkey"] for item in survey_data]
passages = [item["text"] for item in survey_data]
embeddings_doc_dense, _ = model.encode_corpus(passages, max_length=1024)
# faiss save index
index = faiss.IndexFlatIP(embeddings_doc_dense.shape[1])
id_map_index = faiss.IndexIDMap(index)
index = faiss.index_cpu_to_all_gpus(id_map_index)
# FAISS expects 64-bit integer ids; make the dtype explicit.
x_ids_int = np.arange(len(xids), dtype=np.int64)
# Create the output directory up front so neither write fails on a fresh checkout.
os.makedirs("./index", exist_ok=True)
# One row per passage. Building the frame column-wise avoids the one-row,
# N-column transpose and keeps a row for every id even if a bibkey repeats.
str_int_ids_df = pd.DataFrame({"str_id": xids, "int_id": x_ids_int})
str_int_ids_df.to_csv("./index/str_int_ids_abstract.csv", index=False)
index.add_with_ids(embeddings_doc_dense, x_ids_int)
index = faiss.index_gpu_to_cpu(index)
faiss.write_index(index, "./index/index_abstract.faiss")

View File

@@ -0,0 +1,21 @@
# curl -L -o ~/Downloads/arxiv.zip\
# https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
import jsonlines
# Convert the arXiv metadata snapshot into the {"bibkey", "text"} JSONL corpus.
input_path = './data/arxiv-metadata-oai-snapshot.json'
output_path = './data/arxiv.jsonl'
# Stream record by record: the snapshot is large, so nothing is buffered in memory.
with jsonlines.open(input_path, 'r') as reader, jsonlines.open(output_path, 'w') as writer:
    for item in reader:
        writer.write({
            'bibkey': f"arxivid{item['id']}",
            'text': f"Title: {item['title']}\nAbstract: {item['abstract']}\nAuthors: {item['authors']}",
        })

View File

@@ -0,0 +1,175 @@
import faiss
from fastapi import FastAPI
import torch
import pandas as pd
from collections import defaultdict
import pandas as pd
import jsonlines
from transformers import AutoModel, AutoTokenizer
import uvicorn
import asyncio
from pydantic import BaseModel
from typing import List, Optional
import re
import json
import asyncio
import argparse
# Retriever service state, loaded once at import time: embedding model,
# sharded GPU FAISS index, corpus texts and the integer-id -> bibkey map.
app = FastAPI()

model_name = "openbmb/MiniCPM-Embedding-Light"
faiss_index_path = "./index/index_abstract.faiss"  # Replace with your FAISS index path
corpus_path = "./data/arxiv.jsonl"
index_path = "./index/str_int_ids_abstract.csv"
item_key = "text"

model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Shard the index across all GPUs and store vectors in float16.
co = faiss.GpuMultipleClonerOptions()
co.shard = True
co.useFloat16 = True
faiss_index = faiss.read_index(faiss_index_path)
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index, co=co)

with jsonlines.open(corpus_path) as f:
    paper_data = list(f)

# First CSV column is the bibkey (kept as text), second the integer id.
index_df = pd.read_csv(index_path, converters={0: str, 1: int})
index_df.columns = ["str_id", "int_id"]
index_dict = index_df.set_index("int_id")["str_id"].to_dict()

# bibkey -> passage text
paper_dict = {item["bibkey"]: item[item_key] for item in paper_data}
class QueryRequest(BaseModel):
    """Plain batch-query request (not used by any endpoint shown in this module)."""

    # Search strings.
    queries: List[str]
    # Number of hits wanted; None leaves the choice to the server.
    topk: Optional[int] = None
    # Whether scores should be returned with the hits.
    return_scores: bool = False
class MessageRequest(BaseModel):
    """Batch of parsed tool calls, one per active rollout."""

    # Tool-call objects as produced by the generator's response parser.
    tool_calls: List
    # Number of passages to return for each search call.
    topk: Optional[int] = 10
def _extract_title(abstract):
    """Return the title from a "bibkey: ...\\nTitle: ...\\n..." passage, or ""."""
    try:
        # Split on the first colon only, so titles that contain ":" stay whole.
        return abstract.split("\n")[1].split(":", 1)[1].strip()
    except (IndexError, AttributeError):
        return ""


@app.post("/")
async def search_text_batch(request:MessageRequest):
    """Execute a batch of tool calls and return one feedback dict per call.

    A call named "search_engine" is run against the index. Anything else
    (a "finalize" call, an unknown name, a missing name or a malformed
    entry) is answered with an empty result and ``done=True``.

    Returns:
        A list, aligned with ``request.tool_calls``, of dicts with the keys
        search_keywords, summarys, done, score, titles and bibkeys.
    """
    tool_calls = request.tool_calls
    topk = request.topk
    search_positions = [
        i for i, call in enumerate(tool_calls)
        if isinstance(call, dict) and call.get("name") == "search_engine"
    ]
    search_task_results = await asyncio.gather(
        *(call_search_engine(tool_calls[i], topk) for i in search_positions)
    )
    search_results = dict(zip(search_positions, search_task_results))
    results = []
    for i in range(len(tool_calls)):
        if i in search_results:
            search_keywords, bibkeys, abstracts, done, score = search_results[i]
        else:
            # finalize (or anything unrecognised): nothing retrieved, task done
            search_keywords, bibkeys, abstracts, done, score = "", [], [], True, 0.0
        titles = [_extract_title(abstract) for abstract in abstracts]
        results.append({"search_keywords": search_keywords, "summarys": abstracts, "done": done, "score": score, "titles": titles, "bibkeys": bibkeys})
    return results
def extract_tool_call(text: str):
    """Return the JSON object inside the first <tool_call> tag, or None.

    None is returned when there is no tag, when its content is not valid
    JSON, or when the JSON value is not an object.
    """
    found = re.search(r"<tool_call>(.*?)</tool_call>", text.strip(), re.DOTALL)
    if found is None:
        return None
    try:
        parsed = json.loads(found.group(1))
    except json.JSONDecodeError:
        return None
    if isinstance(parsed, dict):
        return parsed
    return None
def get_response(queries,ref):
    """Return the stored passage for ``ref``, prefixed with its bibkey line.

    The passage goes through the tokenizer with truncation at 8192 tokens
    and is decoded back to text. ``queries`` is accepted but not used.
    """
    key = str(ref)
    encoded = tokenizer(paper_dict[key], max_length=8192, truncation=True)
    passage = tokenizer.decode(encoded["input_ids"])
    return f"bibkey: {key}\n" + passage
async def call_search_engine(tool_call, topk=10):
    """Run one search_engine tool call against the FAISS index.

    Each query string is searched separately and the hits are merged by
    summing ``1 / rank`` over the queries that returned them.

    Args:
        tool_call: dict whose ``["arguments"]["query"]`` is a string or a
            list of strings.
        topk: hits per query and size of the merged result; None means 10.

    Returns:
        ``(search_keywords, bibkeys, passages, done, score)``. Empty or
        malformed calls and any internal error yield
        ``("", [], [], False, 0.0)``.
    """
    try:
        queries = tool_call["arguments"]["query"]
        queries = [queries] if isinstance(queries, str) else list(queries)
        if len(queries) == 0:
            return "", [], [], False, 0.0
        if topk is None:
            topk = 10
        query_embeddings, _ = model.encode_query(queries, max_length=512, show_progress_bar=False)
        _, neighbor_ids = faiss_index.search(query_embeddings, topk)
        merge_rrf = defaultdict(float)
        for per_query_ids in neighbor_ids:
            for rank, doc_id in enumerate(per_query_ids):
                doc_id = int(doc_id)
                if doc_id < 0:
                    # FAISS pads with -1 when fewer than topk hits exist; there
                    # is no such document, so skip it instead of failing the search.
                    continue
                merge_rrf[doc_id] += 1 / (rank + 1)
        ranked_ids = sorted(merge_rrf, key=merge_rrf.get, reverse=True)[:topk]
        bibkeys = [str(index_dict[doc_id]) for doc_id in ranked_ids]
        response = [f"bibkey: {bibkey}\n{paper_dict[bibkey]}" for bibkey in bibkeys]
        return ",".join(queries), bibkeys, response, False, 0.0
    except Exception as e:
        # Best effort: a failed search is reported as "no results".
        print(f"Error in call_search_engine: {e}")
        return "", [], [], False, 0.0
if __name__ == "__main__":
    # Serve the retriever API on all interfaces at the requested port.
    cli = argparse.ArgumentParser(description="Run the FastAPI application.")
    cli.add_argument("--port", type=int, default=8400, help="Port to run the FastAPI application on.")
    args = cli.parse_args()
    uvicorn.run(app, host="0.0.0.0", port=args.port)