初始化项目,由ModelHub XC社区提供模型
Model: OpenBMB/MiniCPM4-Survey Source: Original Platform
This commit is contained in:
24
code/frontend/minicpm4-survey/.gitignore
vendored
Normal file
24
code/frontend/minicpm4-survey/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
33
code/frontend/minicpm4-survey/eslint.config.js
Normal file
33
code/frontend/minicpm4-survey/eslint.config.js
Normal file
@@ -0,0 +1,33 @@
|
||||
import js from '@eslint/js'
|
||||
import globals from 'globals'
|
||||
import reactHooks from 'eslint-plugin-react-hooks'
|
||||
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||
|
||||
export default [
|
||||
{ ignores: ['dist'] },
|
||||
{
|
||||
files: ['**/*.{js,jsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
parserOptions: {
|
||||
ecmaVersion: 'latest',
|
||||
ecmaFeatures: { jsx: true },
|
||||
sourceType: 'module',
|
||||
},
|
||||
},
|
||||
plugins: {
|
||||
'react-hooks': reactHooks,
|
||||
'react-refresh': reactRefresh,
|
||||
},
|
||||
rules: {
|
||||
...js.configs.recommended.rules,
|
||||
...reactHooks.configs.recommended.rules,
|
||||
'no-unused-vars': ['error', { varsIgnorePattern: '^[A-Z_]' }],
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
13
code/frontend/minicpm4-survey/index.html
Normal file
13
code/frontend/minicpm4-survey/index.html
Normal file
@@ -0,0 +1,13 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/png" href="/openbmb.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>MiniCPM4-Survey</title>
|
||||
</head>
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
2822
code/frontend/minicpm4-survey/package-lock.json
generated
Normal file
2822
code/frontend/minicpm4-survey/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
29
code/frontend/minicpm4-survey/package.json
Normal file
29
code/frontend/minicpm4-survey/package.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
"name": "minicpm4-survey",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"dompurify": "^3.2.6",
|
||||
"marked": "^15.0.12",
|
||||
"react": "^19.1.0",
|
||||
"react-dom": "^19.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.25.0",
|
||||
"@types/react": "^19.1.2",
|
||||
"@types/react-dom": "^19.1.2",
|
||||
"@vitejs/plugin-react": "^4.4.1",
|
||||
"eslint": "^9.25.0",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.19",
|
||||
"globals": "^16.0.0",
|
||||
"vite": "^6.3.5"
|
||||
}
|
||||
}
|
||||
BIN
code/frontend/minicpm4-survey/public/openbmb.png
Normal file
BIN
code/frontend/minicpm4-survey/public/openbmb.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 629 B |
318
code/frontend/minicpm4-survey/src/App.css
Normal file
318
code/frontend/minicpm4-survey/src/App.css
Normal file
@@ -0,0 +1,318 @@
|
||||
:root {
|
||||
--neon-blue: #00f3ff;
|
||||
--neon-purple: #ffffff;
|
||||
--dark-panel: rgba(10, 10, 30, 0.95);
|
||||
--glass-color: rgba(255, 255, 255, 0.05);
|
||||
}
|
||||
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: 'Orbitron', monospace;
|
||||
}
|
||||
|
||||
body {
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.cyber-container {
|
||||
position: relative;
|
||||
display: flex;
|
||||
min-height: 100vh;
|
||||
background: radial-gradient(circle at center, #0a0a1f 0%, #000000 100%);
|
||||
padding: 2rem;
|
||||
gap: 2rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* 全息背景 */
|
||||
.hologram-bg {
|
||||
position: fixed;
|
||||
top: -50%;
|
||||
left: -50%;
|
||||
width: 200%;
|
||||
height: 200%;
|
||||
background: repeating-linear-gradient(
|
||||
45deg,
|
||||
transparent,
|
||||
transparent 5px,
|
||||
rgba(0, 255, 255, 0.05) 5px,
|
||||
rgba(0, 255, 255, 0.05) 10px
|
||||
);
|
||||
animation: scan 20s linear infinite;
|
||||
z-index: 0;
|
||||
}
|
||||
|
||||
@keyframes scan {
|
||||
0% { transform: translateY(-50%) rotate(0deg); }
|
||||
100% { transform: translateY(50%) rotate(360deg); }
|
||||
}
|
||||
|
||||
/* 面板样式 */
|
||||
.tech-panel {
|
||||
position: relative;
|
||||
width: 220px;
|
||||
padding: 1.5rem;
|
||||
background: var(--glass-color);
|
||||
backdrop-filter: blur(10px);
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 0 20px rgba(0, 255, 255, 0.2);
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.tech-panel::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
height: 2px;
|
||||
background: linear-gradient(90deg, var(--neon-blue), var(--neon-purple), var(--neon-blue));
|
||||
animation: pulse 3s infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { opacity: 0.5; }
|
||||
50% { opacity: 1; }
|
||||
}
|
||||
|
||||
/* 输入框 */
|
||||
.input-wrapper {
|
||||
position: relative;
|
||||
margin-bottom: 1.2rem;
|
||||
}
|
||||
|
||||
.neon-input {
|
||||
width: 100%;
|
||||
padding: 0.8rem 1rem;
|
||||
background: rgba(255, 255, 255, 0.03);
|
||||
border: 1px solid rgba(0, 255, 255, 0.3);
|
||||
border-radius: 6px;
|
||||
color: #fff;
|
||||
font-size: 1rem;
|
||||
outline: none;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.neon-input::placeholder {
|
||||
color: rgba(255, 255, 255, 0.3);
|
||||
}
|
||||
|
||||
.neon-input:focus {
|
||||
border-color: var(--neon-blue);
|
||||
box-shadow: 0 0 10px var(--neon-blue);
|
||||
}
|
||||
|
||||
.input-glow {
|
||||
position: absolute;
|
||||
bottom: -5px;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 2px;
|
||||
background: linear-gradient(90deg, transparent, var(--neon-purple), transparent);
|
||||
animation: glowPulse 2s infinite;
|
||||
}
|
||||
|
||||
@keyframes glowPulse {
|
||||
0%, 100% { opacity: 0; }
|
||||
50% { opacity: 1; }
|
||||
}
|
||||
|
||||
/* 核心区域 */
|
||||
.core-module {
|
||||
position: relative;
|
||||
flex: 1;
|
||||
z-index: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.quantum-textarea {
|
||||
flex: 1;
|
||||
padding: 2rem;
|
||||
font-size: 1.2rem;
|
||||
background: rgba(0, 0, 20, 0.7);
|
||||
border: 2px solid var(--neon-blue);
|
||||
border-radius: 12px;
|
||||
color: #fff;
|
||||
outline: none;
|
||||
resize: both;
|
||||
min-height: 300px;
|
||||
backdrop-filter: blur(5px);
|
||||
font-family: 'Orbitron', monospace;
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
|
||||
.quantum-textarea::placeholder {
|
||||
color: rgba(255, 255, 255, 0.2);
|
||||
}
|
||||
|
||||
.quantum-textarea:focus {
|
||||
border-color: var(--neon-purple);
|
||||
box-shadow: 0 0 20px var(--neon-purple);
|
||||
}
|
||||
|
||||
.core-glow {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
width: 300%;
|
||||
height: 300%;
|
||||
background: radial-gradient(circle, var(--neon-blue) 0%, transparent 70%);
|
||||
opacity: 0.1;
|
||||
transform: translate(-50%, -50%);
|
||||
z-index: 0;
|
||||
pointer-events: none;
|
||||
animation: pulseGlow 5s infinite alternate;
|
||||
}
|
||||
|
||||
@keyframes pulseGlow {
|
||||
0% { transform: translate(-50%, -50%) scale(1); }
|
||||
100% { transform: translate(-50%, -50%) scale(1.2); }
|
||||
}
|
||||
|
||||
/* 数据流动画 */
|
||||
.data-stream {
|
||||
position: absolute;
|
||||
bottom: 1rem;
|
||||
right: 2rem;
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
z-index: 2;
|
||||
}
|
||||
|
||||
.stream-pulse {
|
||||
width: 6px;
|
||||
height: 6px;
|
||||
background: var(--neon-blue);
|
||||
border-radius: 50%;
|
||||
animation: pulseDot 1.5s infinite;
|
||||
}
|
||||
|
||||
.delay-1 {
|
||||
animation-delay: 0.3s;
|
||||
}
|
||||
|
||||
.delay-2 {
|
||||
animation-delay: 0.6s;
|
||||
}
|
||||
|
||||
@keyframes pulseDot {
|
||||
0% { transform: scale(1); opacity: 1; }
|
||||
70% { transform: scale(1.5); opacity: 0.3; }
|
||||
100% { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
/* 响应式设计 */
|
||||
@media (max-width: 1024px) {
|
||||
.cyber-container {
|
||||
flex-direction: column;
|
||||
padding: 1rem;
|
||||
}
|
||||
|
||||
.tech-panel {
|
||||
width: 100%;
|
||||
margin-bottom: 1.5rem;
|
||||
}
|
||||
|
||||
.core-module {
|
||||
min-height: 400px;
|
||||
}
|
||||
}
|
||||
.markdown-editor {
|
||||
display: flex;
|
||||
gap: 20px;
|
||||
height: 100%;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.markdown-input,
|
||||
.markdown-preview {
|
||||
flex: 1;
|
||||
padding: 15px;
|
||||
font-size: 16px;
|
||||
border: 2px solid var(--neon-blue);
|
||||
border-radius: 8px;
|
||||
background: rgba(0, 0, 20, 0.7);
|
||||
color: #fff;
|
||||
font-family: 'Orbitron', monospace;
|
||||
resize: both;
|
||||
}
|
||||
|
||||
.markdown-input {
|
||||
min-height: 300px;
|
||||
outline: none;
|
||||
}
|
||||
|
||||
.markdown-input:focus {
|
||||
box-shadow: 0 0 10px var(--neon-blue);
|
||||
}
|
||||
|
||||
.markdown-preview {
|
||||
max-height: 800px; /* 设置最大高度,超出后滚动 */
|
||||
width: 100%;
|
||||
overflow-y: auto; /* 垂直方向溢出时显示滚动条 */
|
||||
padding: 10px;
|
||||
border: 1px solid #333;
|
||||
background-color: #111;
|
||||
color: #eee;
|
||||
font-family: monospace;
|
||||
}
|
||||
|
||||
/* Markdown 内容增强样式 */
|
||||
.markdown-preview h1, h2, h3 {
|
||||
color: var(--neon-purple);
|
||||
}
|
||||
|
||||
.markdown-preview pre {
|
||||
background: #111;
|
||||
padding: 10px;
|
||||
border-radius: 6px;
|
||||
color: #00ffcc;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-preview code {
|
||||
background: rgba(255, 255, 255, 0.1);
|
||||
padding: 2px 4px;
|
||||
border-radius: 4px;
|
||||
color: var(--neon-blue);
|
||||
}
|
||||
|
||||
.core-module {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1rem;
|
||||
padding: 1rem;
|
||||
}
|
||||
|
||||
.markdown-toolbar {
|
||||
display: flex;
|
||||
gap: 10px;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.neon-button {
|
||||
background-color: #000;
|
||||
color: #0f0;
|
||||
border: 2px solid #0f0;
|
||||
padding: 6px 12px;
|
||||
font-size: 14px;
|
||||
cursor: pointer;
|
||||
transition: all 0.2s ease-in-out;
|
||||
}
|
||||
|
||||
.neon-button:hover {
|
||||
background-color: #0f0;
|
||||
color: #000;
|
||||
}
|
||||
|
||||
.markdown-editor {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 10px;
|
||||
padding: 1rem;
|
||||
}
|
||||
259
code/frontend/minicpm4-survey/src/App.jsx
Normal file
259
code/frontend/minicpm4-survey/src/App.jsx
Normal file
@@ -0,0 +1,259 @@
|
||||
import React, { useState, useEffect, useMemo, useRef } from 'react';
|
||||
import './App.css';
|
||||
import DOMPurify from 'dompurify';
|
||||
import { marked } from 'marked';
|
||||
|
||||
// 自定义 hook:防抖
|
||||
function useDebounce(value, delay) {
|
||||
const [debouncedValue, setDebouncedValue] = useState(value);
|
||||
|
||||
React.useEffect(() => {
|
||||
const handler = setTimeout(() => {
|
||||
setDebouncedValue(value);
|
||||
}, delay);
|
||||
|
||||
return () => clearTimeout(handler);
|
||||
}, [value, delay]);
|
||||
|
||||
return debouncedValue;
|
||||
}
|
||||
|
||||
function MarkdownEditor({ value }) {
|
||||
const containerRef = useRef(null);
|
||||
|
||||
const htmlContent = marked(value || '');
|
||||
const sanitizedHtml = DOMPurify.sanitize(htmlContent);
|
||||
|
||||
const [userScrolled, setUserScrolled] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (container && !userScrolled) {
|
||||
requestAnimationFrame(() => {
|
||||
container.scrollTop = container.scrollHeight;
|
||||
});
|
||||
}
|
||||
}, [value, userScrolled]);
|
||||
|
||||
useEffect(() => {
|
||||
const container = containerRef.current;
|
||||
if (container) {
|
||||
const handleScroll = () => {
|
||||
const atBottom = container.scrollTop + container.clientHeight >= container.scrollHeight - 10;
|
||||
setUserScrolled(!atBottom);
|
||||
};
|
||||
container.addEventListener('scroll', handleScroll);
|
||||
return () => container.removeEventListener('scroll', handleScroll);
|
||||
}
|
||||
}, []);
|
||||
|
||||
// 复制 Markdown 内容
|
||||
const handleCopy = () => {
|
||||
navigator.clipboard.writeText(value || '')
|
||||
.then(() => alert('Markdown 已复制到剪贴板'))
|
||||
.catch(err => console.error('复制失败:', err));
|
||||
};
|
||||
|
||||
// 下载 Markdown 文件
|
||||
const handleDownload = () => {
|
||||
const blob = new Blob([value || ''], { type: 'text/markdown;charset=utf-8' });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = 'document.md';
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
document.body.removeChild(a);
|
||||
URL.revokeObjectURL(url);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="markdown-editor">
|
||||
{/* <div className="markdown-toolbar">
|
||||
<button className="neon-button" onClick={handleCopy}>复制 Markdown</button>
|
||||
<button className="neon-button" onClick={handleDownload}>下载 Markdown</button>
|
||||
</div> */}
|
||||
<div
|
||||
ref={containerRef}
|
||||
className="markdown-preview"
|
||||
dangerouslySetInnerHTML={{ __html: sanitizedHtml }}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
function SendRequestToBackend() {
|
||||
const [inputValue, setInputValue] = useState('');
|
||||
|
||||
const handleSendRequest = async () => {
|
||||
try {
|
||||
const response = await fetch('http://localhost:8001/generate_survey', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({ query: inputValue }),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to send request');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log('Response from backend:', data);
|
||||
} catch (error) {
|
||||
console.error('Error sending request:', error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="request-panel" style={{ flexDirection: 'column', alignItems: 'center' }}>
|
||||
<input
|
||||
type="text"
|
||||
value={inputValue}
|
||||
onChange={(e) => setInputValue(e.target.value)}
|
||||
className="neon-input"
|
||||
placeholder="Enter text to send"
|
||||
rows={3}
|
||||
/>
|
||||
<button onClick={handleSendRequest} className="neon-button">
|
||||
Go!
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
function App() {
|
||||
const [inputs, setInputs] = useState({
|
||||
query: { title: 'Query', displayText: '', targetText: '', isTyping: false },
|
||||
nowUpdate: { title: 'Now Update', displayText: '', targetText: '', isTyping: false },
|
||||
nextUpdate: { title: 'Next Update', displayText: '', targetText: '', isTyping: false },
|
||||
searchKeywords: { title: 'Search Keywords', displayText: '', targetText: '', isTyping: false },
|
||||
papers: { title: 'Papers', displayText: '', targetText: '', isTyping: false },
|
||||
});
|
||||
|
||||
const [markdownContent, setMarkdownContent] = useState('');
|
||||
|
||||
const inputKeyMap = {
|
||||
query: inputs.query,
|
||||
nowUpdate: inputs.nowUpdate,
|
||||
nextUpdate: inputs.nextUpdate,
|
||||
searchKeywords: inputs.searchKeywords,
|
||||
papers: inputs.papers,
|
||||
markdown: markdownContent
|
||||
};
|
||||
|
||||
const updateInputsFromPostData = (postData) => {
|
||||
let newMarkdownContent = markdownContent;
|
||||
|
||||
Object.entries(postData).forEach(([key, value]) => {
|
||||
if (key in inputKeyMap) {
|
||||
if (key === 'markdown') {
|
||||
if (markdownContent !== value) {
|
||||
newMarkdownContent = value;
|
||||
setMarkdownContent(newMarkdownContent);
|
||||
}
|
||||
} else if (inputKeyMap[key] && inputKeyMap[key].targetText !== value) {
|
||||
const updatedInput = {
|
||||
...inputKeyMap[key],
|
||||
targetText: value,
|
||||
isTyping: true,
|
||||
};
|
||||
setInputs((prevInputs) => ({
|
||||
...prevInputs,
|
||||
[key]: updatedInput,
|
||||
}));
|
||||
|
||||
// startTypingAnimationForTextbox(value, (newText) => {
|
||||
// setInputs((prevInputs) => ({
|
||||
// ...prevInputs,
|
||||
// [key]: {
|
||||
// ...prevInputs[key],
|
||||
// displayText: newText,
|
||||
// },
|
||||
// }));
|
||||
// });
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
};
|
||||
|
||||
// const startTypingAnimationForTextbox = (text, setText) => {
|
||||
// setText('');
|
||||
// let charIndex = 0;
|
||||
// const timer = setInterval(() => {
|
||||
// if (charIndex < text.length) {
|
||||
// setText((prev) => prev + text[charIndex]);
|
||||
// charIndex++;
|
||||
// } else {
|
||||
// clearInterval(timer);
|
||||
// }
|
||||
// }, 50); // Reduced interval for faster typing animation
|
||||
// };
|
||||
|
||||
|
||||
useEffect(() => {
|
||||
const ws = new WebSocket('ws://localhost:8001/ws');
|
||||
ws.onmessage = (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data);
|
||||
updateInputsFromPostData(data);
|
||||
console.log('Received data:', data);
|
||||
} catch (e) {
|
||||
console.error('Invalid WebSocket message:', e);
|
||||
}
|
||||
};
|
||||
ws.onerror = (err) => {
|
||||
console.error('WebSocket error:', err);
|
||||
};
|
||||
// return () => ws.close();
|
||||
}, []);
|
||||
|
||||
const leftInputs = [inputs.nowUpdate, inputs.nextUpdate,inputs.searchKeywords];
|
||||
const rightInputs = [inputs.papers];
|
||||
|
||||
return (
|
||||
<div className="cyber-container">
|
||||
<div className="tech-panel left-panel">
|
||||
{leftInputs.map((input, index) => (
|
||||
<div key={`left-${index}`} className="input-wrapper">
|
||||
<h3 className="input-title" style={{ fontSize: '14px' }}>{input.title}</h3>
|
||||
<textarea
|
||||
value={input.targetText}
|
||||
readOnly
|
||||
className="neon-input"
|
||||
rows={Math.max(10, input.targetText.split('\n').length)}
|
||||
cols={50}
|
||||
style={{ resize: 'none', fontSize: '12px' }}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="core-module">
|
||||
<SendRequestToBackend />
|
||||
<MarkdownEditor value={markdownContent} />
|
||||
</div>
|
||||
|
||||
<div className="tech-panel right-panel">
|
||||
{rightInputs.map((input, index) => (
|
||||
<div key={`right-${index}`} className="input-wrapper">
|
||||
<h3 className="input-title" style={{ fontSize: '14px' }}>{input.title}</h3>
|
||||
<textarea
|
||||
value={input.targetText}
|
||||
readOnly
|
||||
className="neon-input"
|
||||
rows={100}
|
||||
cols={50}
|
||||
style={{ resize: 'none', fontSize: '12px' }}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export default App;
|
||||
68
code/frontend/minicpm4-survey/src/index.css
Normal file
68
code/frontend/minicpm4-survey/src/index.css
Normal file
@@ -0,0 +1,68 @@
|
||||
:root {
|
||||
font-family: system-ui, Avenir, Helvetica, Arial, sans-serif;
|
||||
line-height: 1.5;
|
||||
font-weight: 400;
|
||||
|
||||
color-scheme: light dark;
|
||||
color: rgba(255, 255, 255, 0.87);
|
||||
background-color: #242424;
|
||||
|
||||
font-synthesis: none;
|
||||
text-rendering: optimizeLegibility;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
}
|
||||
|
||||
a {
|
||||
font-weight: 500;
|
||||
color: #646cff;
|
||||
text-decoration: inherit;
|
||||
}
|
||||
a:hover {
|
||||
color: #535bf2;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
display: flex;
|
||||
place-items: center;
|
||||
min-width: 320px;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 3.2em;
|
||||
line-height: 1.1;
|
||||
}
|
||||
|
||||
button {
|
||||
border-radius: 8px;
|
||||
border: 1px solid transparent;
|
||||
padding: 0.6em 1.2em;
|
||||
font-size: 1em;
|
||||
font-weight: 500;
|
||||
font-family: inherit;
|
||||
background-color: #1a1a1a;
|
||||
cursor: pointer;
|
||||
transition: border-color 0.25s;
|
||||
}
|
||||
button:hover {
|
||||
border-color: #646cff;
|
||||
}
|
||||
button:focus,
|
||||
button:focus-visible {
|
||||
outline: 4px auto -webkit-focus-ring-color;
|
||||
}
|
||||
|
||||
@media (prefers-color-scheme: light) {
|
||||
:root {
|
||||
color: #213547;
|
||||
background-color: #ffffff;
|
||||
}
|
||||
a:hover {
|
||||
color: #747bff;
|
||||
}
|
||||
button {
|
||||
background-color: #f9f9f9;
|
||||
}
|
||||
}
|
||||
10
code/frontend/minicpm4-survey/src/main.jsx
Normal file
10
code/frontend/minicpm4-survey/src/main.jsx
Normal file
@@ -0,0 +1,10 @@
|
||||
import { StrictMode } from 'react'
|
||||
import { createRoot } from 'react-dom/client'
|
||||
import './index.css'
|
||||
import App from './App.jsx'
|
||||
|
||||
createRoot(document.getElementById('root')).render(
|
||||
<StrictMode>
|
||||
<App />
|
||||
</StrictMode>,
|
||||
)
|
||||
7
code/frontend/minicpm4-survey/vite.config.js
Normal file
7
code/frontend/minicpm4-survey/vite.config.js
Normal file
@@ -0,0 +1,7 @@
|
||||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
|
||||
// https://vite.dev/config/
|
||||
export default defineConfig({
|
||||
plugins: [react()],
|
||||
})
|
||||
8
code/requirements.txt
Normal file
8
code/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
openai
|
||||
vllm
|
||||
jsonlines
|
||||
faiss-cpu
|
||||
# faiss-gpu
|
||||
fastapi
|
||||
uvicorn
|
||||
yarl
|
||||
4
code/scripts/run.sh
Normal file
4
code/scripts/run.sh
Normal file
@@ -0,0 +1,4 @@
|
||||
python ./src/generation/run.py \
|
||||
--model_path "openbmb/MiniCPM4-Survey" \
|
||||
--query "Please design a survey that assesses the performance of language processing systems on unseen data to measure their robustness in natural language understanding tasks." \
|
||||
--output_file "test.md" \
|
||||
5
code/scripts/run_with_frontend.sh
Normal file
5
code/scripts/run_with_frontend.sh
Normal file
@@ -0,0 +1,5 @@
|
||||
python ./src/generation/run.py \
|
||||
--model_path "openbmb/MiniCPM4-Survey" \
|
||||
--query "Please design a survey that assesses the performance of language processing systems on unseen data to measure their robustness in natural language understanding tasks." \
|
||||
--output_file "test.md" \
|
||||
--port 8001 \
|
||||
810
code/src/generation/buffer.py
Normal file
810
code/src/generation/buffer.py
Normal file
@@ -0,0 +1,810 @@
|
||||
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
|
||||
# BASE_SURVEY_STRUCTURE = """
|
||||
# # Title: A survey of ...
|
||||
# # Introduction: None.
|
||||
# # Section 1: None.
|
||||
# ## Subsection 1 (if needed): None.
|
||||
# ## Subsection 2 (if needed): None.
|
||||
# ### Subsubsection 1 (if needed): None.
|
||||
# ### Subsubsection 2 (if needed): None.
|
||||
# ### ...
|
||||
# # Section 2: None.
|
||||
# # ...
|
||||
# # Conclusion: None.
|
||||
# """
|
||||
|
||||
|
||||
class SurveyManager:
|
||||
BASE_SURVEY_STRUCTURE = {
|
||||
"title": "",
|
||||
"abstract": "",
|
||||
"introduction": {
|
||||
"content": ""
|
||||
},
|
||||
"sections": [],
|
||||
"conclusion": ""
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parse_update_pos(update_pos):
|
||||
"""
|
||||
(1) "title", "abstract", "introduction", or "conclusion"
|
||||
(2) "section-i/subsection-j/..."
|
||||
|
||||
"""
|
||||
if update_pos in ["title", "abstract", "introduction", "conclusion","plan"]:
|
||||
return update_pos
|
||||
else:
|
||||
keys = update_pos.split("/")
|
||||
if len(keys) == 1: # Section-?
|
||||
i = int(keys[0].lower().split("section-")[-1])
|
||||
return f"section-{i}"
|
||||
elif len(keys) == 2: # Section-?/Subsection-?
|
||||
i = int(keys[0].lower().split("section-")[-1])
|
||||
j = int(keys[1].lower().split("subsection-")[-1])
|
||||
return f"section-{i}/subsection-{j}"
|
||||
elif len(keys) == 3: # Section-?/Subsection-?/Subsubsection-?
|
||||
i = int(keys[0].lower().split("section-")[-1])
|
||||
j = int(keys[1].lower().split("subsection-")[-1])
|
||||
k = int(keys[2].lower().split("subsubsection-")[-1])
|
||||
return f"section-{i}/subsection-{j}/subsubsection-{k}"
|
||||
else:
|
||||
raise ValueError("unsupported update_pos keys")
|
||||
|
||||
@staticmethod
|
||||
def _to_one_line(string):
|
||||
if isinstance(string, dict):
|
||||
if "content" in string and string["content"]:
|
||||
return SurveyManager._to_one_line(string["content"])
|
||||
# return SurveyManager._to_one_line(string["content"])
|
||||
else:
|
||||
return "[PLAN] " + string.get("plan", "").replace("\n", " ").strip()
|
||||
if not string:
|
||||
return ""
|
||||
else:
|
||||
return string#.replace("\n", " ")
|
||||
|
||||
@staticmethod
|
||||
def convert_survey_dict_to_str(current_survey):
|
||||
string = ""
|
||||
if current_survey == {}:
|
||||
return "There is no survey."
|
||||
# title
|
||||
try:
|
||||
content = SurveyManager._to_one_line(current_survey["title"])
|
||||
string += f"# {content}\n"
|
||||
except:
|
||||
string += f"# Title: None\n"
|
||||
|
||||
# abstract
|
||||
try:
|
||||
content = SurveyManager._to_one_line(current_survey["abstract"])
|
||||
string += f"## Abstract\n{content}\n"
|
||||
except:
|
||||
string += f"## Abstract\nNone\n"
|
||||
|
||||
# introduction
|
||||
try:
|
||||
content = SurveyManager._to_one_line(current_survey["introduction"])
|
||||
string += f"## Introduction\n{content}\n"
|
||||
except:
|
||||
string += f"## Introduction\nNone\n"
|
||||
|
||||
# sections
|
||||
if "sections" in current_survey:
|
||||
for i, section in enumerate(current_survey["sections"]):
|
||||
title_key = "name" if "name" in section else "title"
|
||||
name, content = section[title_key], SurveyManager._to_one_line(section)
|
||||
# string += f"# Section-{i+1} [{name}]: {content}\n"
|
||||
string += f"## {name}\n{content}\n"
|
||||
|
||||
if "subsections" in section:
|
||||
for j, subsection in enumerate(section["subsections"]):
|
||||
name, content = subsection[title_key], SurveyManager._to_one_line(subsection)
|
||||
# string += f" ## Subsection-{j+1} [{name}]: {content}\n"
|
||||
string += f"### {name}\n{content}\n"
|
||||
|
||||
if "subsubsections" in subsection:
|
||||
for k, subsubsection in enumerate(subsection["subsubsections"]):
|
||||
name, content = subsubsection[title_key], SurveyManager._to_one_line(subsubsection)
|
||||
# string += f" ### Subsubsection-{k+1} [{name}]: {content}\n"
|
||||
string += f"#### {name}\n{content}\n"
|
||||
|
||||
|
||||
# conclusion
|
||||
try:
|
||||
content = SurveyManager._to_one_line(current_survey["conclusion"])
|
||||
string += f"## Conclusion\n{content}\n"
|
||||
except:
|
||||
string += f"## Conclusion:\nNone\n"
|
||||
|
||||
return string
|
||||
|
||||
@staticmethod
|
||||
def _abbr_one_line(string, abbr=True):
|
||||
if isinstance(string, dict):
|
||||
if "content" in string and string["content"]:
|
||||
return SurveyManager._abbr_one_line(string["content"], abbr=abbr)
|
||||
elif "plan" in string:
|
||||
return "[PLAN] " + string["plan"].replace("\n", " ").strip()
|
||||
else:
|
||||
return ""
|
||||
else:
|
||||
if not string:
|
||||
return ""
|
||||
else:
|
||||
if abbr and len(string) > 50:
|
||||
return "[OK] " + string.replace("\n", " ").strip()[:50] + "..."
|
||||
else:
|
||||
return "[OK] " + string.replace("\n", " ").strip()
|
||||
|
||||
@staticmethod
|
||||
def convert_survey_dict_to_abbr_str(current_survey):
|
||||
string = ""
|
||||
if current_survey == {}:
|
||||
return "There is no survey."
|
||||
# title
|
||||
try:
|
||||
content = SurveyManager._abbr_one_line(current_survey["title"], abbr=False)
|
||||
string += f"# Title: {content}\n"
|
||||
except:
|
||||
string += f"# Title: None\n"
|
||||
# abstract
|
||||
try:
|
||||
content = SurveyManager._abbr_one_line(current_survey["abstract"], abbr=False)
|
||||
string += f"# Abstract: {content}\n"
|
||||
except:
|
||||
string += f"# Abstract: None\n"
|
||||
|
||||
# introduction
|
||||
try:
|
||||
content = SurveyManager._abbr_one_line(current_survey["introduction"])
|
||||
string += f"# Introduction: {content}\n"
|
||||
except:
|
||||
string += f"# Introduction: None\n"
|
||||
|
||||
# sections
|
||||
if "sections" in current_survey:
|
||||
for i, section in enumerate(current_survey["sections"]):
|
||||
title_key = "name" if "name" in section else "title"
|
||||
name, content = section[title_key], SurveyManager._abbr_one_line(section)
|
||||
string += f"# Section-{i+1} [{name}]: {content}\n"
|
||||
|
||||
if "subsections" in section:
|
||||
for j, subsection in enumerate(section["subsections"]):
|
||||
name, content = subsection[title_key], SurveyManager._abbr_one_line(subsection)
|
||||
string += f" ## Subsection-{j+1} [{name}]: {content}\n"
|
||||
|
||||
if "subsubsections" in subsection:
|
||||
for k, subsubsection in enumerate(subsection["subsubsections"]):
|
||||
name, content = subsubsection[title_key], SurveyManager._abbr_one_line(subsubsection)
|
||||
string += f" ### Subsubsection-{k+1} [{name}]: {content}\n"
|
||||
|
||||
# conclusion
|
||||
try:
|
||||
content = SurveyManager._abbr_one_line(current_survey["conclusion"])
|
||||
string += f"# Conclusion: {content}\n"
|
||||
except:
|
||||
string += f"# Conclusion: None\n"
|
||||
|
||||
return string
|
||||
|
||||
@staticmethod
|
||||
def update_one_section(sections, i, content):
|
||||
# i -= 1
|
||||
if i >= 0 and i <= (len(sections)-1):
|
||||
sections[i]["content"] = content
|
||||
return True
|
||||
else:
|
||||
# print("update fail!")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def update_current_survey(current_survey, answer) -> bool:
|
||||
"""
|
||||
update_pos: "section-i/subsection-j/subsubsection-k"
|
||||
"""
|
||||
# if answer == {}:
|
||||
# return True
|
||||
try:
|
||||
update_pos, content = answer["update"], answer["content"]
|
||||
|
||||
if update_pos == "plan":
|
||||
# current_survey = content
|
||||
if current_survey == {}:
|
||||
for k,v in content.items():
|
||||
current_survey[k] = copy.deepcopy(v)
|
||||
else:
|
||||
return False
|
||||
elif update_pos in ["conclusion", "abstract"]:
|
||||
if update_pos not in current_survey:
|
||||
# print("update fail!")
|
||||
return False
|
||||
current_survey[update_pos] = content
|
||||
|
||||
elif update_pos == "introduction":
|
||||
if update_pos not in current_survey:
|
||||
# print("update fail!")
|
||||
return False
|
||||
current_survey[update_pos] = {"content": content}
|
||||
|
||||
else:
|
||||
keys = update_pos.split("/")
|
||||
if len(keys) == 1: # Section-?
|
||||
i = int(keys[0].lower().split("section-")[-1])-1
|
||||
return SurveyManager.update_one_section(current_survey["sections"], i, content)
|
||||
|
||||
elif len(keys) == 2: # Section-?/Subsection-?
|
||||
i = int(keys[0].lower().split("section-")[-1])-1
|
||||
j = int(keys[1].lower().split("subsection-")[-1])-1
|
||||
try:
|
||||
return SurveyManager.update_one_section(current_survey["sections"][i]["subsections"], j, content)
|
||||
except:
|
||||
# print("update fail!")
|
||||
return False
|
||||
|
||||
elif len(keys) == 3: # Section-?/Subsection-?/Subsubsection-?
|
||||
i = int(keys[0].lower().split("section-")[-1])-1
|
||||
j = int(keys[1].lower().split("subsection-")[-1])-1
|
||||
k = int(keys[2].lower().split("subsubsection-")[-1])-1
|
||||
try:
|
||||
return SurveyManager.update_one_section(current_survey["sections"][i]["subsections"][j]["subsubsections"], k, content) # 禁用第四级
|
||||
except:
|
||||
# print("update fail!")
|
||||
return False
|
||||
else:
|
||||
# print("update fail!")
|
||||
# print("unsupported update_pos keys")
|
||||
return False
|
||||
# raise ValueError("unsupported update_pos keys")
|
||||
except:
|
||||
# print("update fail!")
|
||||
return False
|
||||
# print("answer is not a valid json object.")
|
||||
# print(answer)
|
||||
# raise ValueError("answer is not a valid json object.")
|
||||
return True
|
||||
|
||||
|
||||
from prompts import *
|
||||
class PromptManger:
|
||||
system_prompt = SYSTEM_PROMPT_0415_BUFFER
|
||||
user_prompt_v0 = USER_PROMPT_v0_0424_BUFFER
|
||||
user_prompt = USER_PROMPT_0415_BUFFER
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class BufferManager:
|
||||
"""
|
||||
Used to manage prompts/responses generated during the Rollout phase, providing data support for subsequent training.
|
||||
batch_rollout_data = [
|
||||
{
|
||||
query (or env_id): # Uniquely identifies a query or environment, [input parameter].
|
||||
*running_id: # Uniquely identifies a single rollout. For cases where a query or environment is repeated multiple times, the query can be the same, but running_id will not repeat.
|
||||
state: { # Indicates whether the process is finished.
|
||||
"score": 0.0,
|
||||
"done": True / False
|
||||
"current_survey": dict # Structured data.
|
||||
}
|
||||
trajectory: [ # Organizes all data into a multi-turn interaction format.
|
||||
{
|
||||
step: int, 0~?, # The first step, usually includes some init_info or plan.
|
||||
original_response: str, The raw output from the model, which may have various formatting issues.
|
||||
answer_thought: str, # Encapsulated using the <think>...</think> block.
|
||||
answer: {
|
||||
"original_str": str
|
||||
"update": str,
|
||||
"name": str,
|
||||
"content": str,
|
||||
"inclusions": list, # Extracted independently?
|
||||
}
|
||||
tool_call_thought: str, # Encapsulated using the <think>...</think> block.
|
||||
tool_call: {
|
||||
"original_str": str, # Encapsulated using the <tool_call>...</tool_call> block, used for tool invocation. In the survey setting, it is either "done" to end the task or "search".
|
||||
"tool_name": str # done or search.
|
||||
"keywords": list[str], Extracted search keywords from tool_call, otherwise none.
|
||||
}
|
||||
*papers: list[str], # Top-n papers retrieved via the search engine. Required if using the Agent-Summary-1 for collaborative optimization; otherwise, not needed.
|
||||
cites: list[str], # References cited by the model, which may include multiple citations.
|
||||
summarys: list[str], # Summaries of papers generated using Agent-Summary-1. Must include BIBKEY.
|
||||
*prompt_for_generator: str, # The prompt input to the generator at the current step. Required if using Agent-Summary-2 for generation and collaborative optimization; otherwise, not needed.
|
||||
},
|
||||
...
|
||||
|
||||
]
|
||||
|
||||
},
|
||||
...
|
||||
]
|
||||
|
||||
"""
|
||||
def __init__(self, prompts, repeat_n: int=1):
|
||||
# self.config = config
|
||||
self.step = 0
|
||||
self.batch_rollout_data = []
|
||||
self.running_ids = [] # active envs
|
||||
batch_size = prompts.batch['input_ids'].size(0)
|
||||
uids = prompts.non_tensor_batch['uid']
|
||||
querys = prompts.non_tensor_batch['raw_prompt'].copy()
|
||||
ground_truths = prompts.non_tensor_batch['ground_truth']
|
||||
# print(querys)
|
||||
new_querys = []
|
||||
for i_batch in range(batch_size):
|
||||
raw_prompt_i_batch = querys[i_batch][-1]["content"]
|
||||
new_querys.append(raw_prompt_i_batch)
|
||||
querys = new_querys
|
||||
|
||||
assert len(querys) == len(uids)
|
||||
for query, uid, ground_truth in zip(querys, uids, ground_truths):
|
||||
|
||||
now_survey = {}
|
||||
|
||||
for _ in range(repeat_n):
|
||||
self.batch_rollout_data.append({
|
||||
"query": query,
|
||||
"uid": uid,
|
||||
"state": {
|
||||
# "score": 0.0, # only for debug
|
||||
# "format_score": None, # will update at last step
|
||||
"done": False,
|
||||
"current_survey": {}
|
||||
},
|
||||
"trajectory": [],
|
||||
"history_messages": [],
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def _build_system_prompt():
|
||||
prompt = PromptManger.system_prompt
|
||||
return prompt
|
||||
@staticmethod
|
||||
def _build_user_prompt_v0(query, current_survey):
|
||||
# query
|
||||
prompt = PromptManger.user_prompt_v0.replace("<user_query>", query)
|
||||
|
||||
# add template
|
||||
prompt = prompt.replace("<init_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _build_user_prompt(query, current_survey, trajs):
|
||||
last_traj = trajs[-1]
|
||||
# query
|
||||
prompt = PromptManger.user_prompt.replace("<user_query>", query)
|
||||
|
||||
# add current survey
|
||||
prompt = prompt.replace("<current_survey>", SurveyManager.convert_survey_dict_to_abbr_str(current_survey))
|
||||
|
||||
# current plan
|
||||
if last_traj["tool_call_thought"] == "":
|
||||
prompt = prompt.replace("<last_step_thought>", "Your last thought is not available, please give new plan")
|
||||
else:
|
||||
prompt = prompt.replace("<last_step_thought>", last_traj["tool_call_thought"])
|
||||
prompt = prompt.replace("<last_step_tool_call>", json.dumps(last_traj["tool_call"]))
|
||||
|
||||
# summarys
|
||||
for traj in reversed(trajs):
|
||||
if len(traj["summarys"]) > 0:
|
||||
break
|
||||
summary_num = len(traj["summarys"])
|
||||
|
||||
if summary_num == 0:
|
||||
prompt = prompt.replace("<summarys>", "There is no result.")
|
||||
else:
|
||||
prompt = prompt.replace("<summarys>", f"There are {summary_num} results:\n\n" + "\n\n".join(traj["summarys"]))
|
||||
|
||||
return prompt
|
||||
|
||||
@staticmethod
|
||||
def _build_user_prompt_force_correct(query, current_survey, trajs):
|
||||
if current_survey == {}:
|
||||
# gen plan
|
||||
now_section = "plan"
|
||||
# trajs[-1]["tool_call_thought"] = "Next I will provide the plan. "
|
||||
else:
|
||||
now_section = ""
|
||||
if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]:
|
||||
now_section = "abstract"
|
||||
elif "content" not in current_survey["introduction"]:
|
||||
now_section = "introduction"
|
||||
elif "sections" in current_survey:
|
||||
for section in current_survey["sections"]:
|
||||
if "content" not in section:
|
||||
now_section = "section-{}".format(current_survey["sections"].index(section) + 1)
|
||||
break
|
||||
elif "subsections" in section:
|
||||
for subsection in section["subsections"]:
|
||||
if "content" not in subsection:
|
||||
now_section = "section-{}/subsection-{}".format(
|
||||
current_survey["sections"].index(section) + 1,
|
||||
section["subsections"].index(subsection) + 1
|
||||
)
|
||||
break
|
||||
elif "subsubsections" in subsection:
|
||||
for subsubsection in subsection["subsubsections"]:
|
||||
if "content" not in subsubsection:
|
||||
now_section = "section-{}/subsection-{}/subsubsection-{}".format(
|
||||
current_survey["sections"].index(section) + 1,
|
||||
section["subsections"].index(subsection) + 1,
|
||||
subsection["subsubsections"].index(subsubsection) + 1
|
||||
)
|
||||
break
|
||||
if now_section:
|
||||
break
|
||||
if now_section:
|
||||
break
|
||||
|
||||
elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]:
|
||||
now_section = "conclusion"
|
||||
else:
|
||||
trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."
|
||||
if now_section != "":
|
||||
trajs[-1]["tool_call_thought"] = f"Next I will provide {now_section}"
|
||||
for traj in reversed(trajs):
|
||||
if len(traj["summarys"]) > 0:
|
||||
break
|
||||
summary_num = len(traj["summarys"])
|
||||
if now_section == "plan" and summary_num == 0:
|
||||
trajs[-1]["tool_call_thought"] = "I need to get enough information."
|
||||
|
||||
return BufferManager._build_user_prompt(query, current_survey, trajs)
|
||||
|
||||
@staticmethod
|
||||
def _check_finalize(query, current_survey, trajs):
|
||||
if current_survey == {}:
|
||||
# gen plan
|
||||
return False
|
||||
# trajs[-1]["tool_call_thought"] = "Next I will provide the plan. "
|
||||
else:
|
||||
now_section = ""
|
||||
if isinstance(current_survey["abstract"],dict) and "content" not in current_survey["abstract"]:
|
||||
now_section = "abstract"
|
||||
elif "content" not in current_survey["introduction"]:
|
||||
now_section = "introduction"
|
||||
elif "sections" in current_survey:
|
||||
for section in current_survey["sections"]:
|
||||
if "content" not in section:
|
||||
now_section = "section-{}".format(current_survey["sections"].index(section) + 1)
|
||||
break
|
||||
elif "subsections" in section:
|
||||
for subsection in section["subsections"]:
|
||||
if "content" not in subsection:
|
||||
now_section = "section-{}/subsection-{}".format(
|
||||
current_survey["sections"].index(section) + 1,
|
||||
section["subsections"].index(subsection) + 1
|
||||
)
|
||||
break
|
||||
elif "subsubsections" in subsection:
|
||||
for subsubsection in subsection["subsubsections"]:
|
||||
if "content" not in subsubsection:
|
||||
now_section = "section-{}/subsection-{}/subsubsection-{}".format(
|
||||
current_survey["sections"].index(section) + 1,
|
||||
section["subsections"].index(subsection) + 1,
|
||||
subsection["subsubsections"].index(subsubsection) + 1
|
||||
)
|
||||
break
|
||||
if now_section:
|
||||
break
|
||||
if now_section:
|
||||
break
|
||||
|
||||
elif isinstance(current_survey["conclusion"],dict) and "content" not in current_survey["conclusion"]:
|
||||
now_section = "conclusion"
|
||||
# else:
|
||||
# trajs[-1]["tool_call_thought"] = "Next I will finalize the survey."
|
||||
if now_section != "":
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# rule-based method: query, plan, paragraphs -> prompt -> thought, paragraph, action
|
||||
def build_prompt_for_generator(self):
|
||||
total_messages = []
|
||||
self.running_ids = []
|
||||
for running_id, data in enumerate(self.batch_rollout_data):
|
||||
if data["state"]["done"]:
|
||||
pass
|
||||
else:
|
||||
if len(data["trajectory"]) == 0: # first prompt
|
||||
user_prompt = BufferManager._build_user_prompt_v0(data["query"],
|
||||
data["state"]["current_survey"])
|
||||
else:
|
||||
if data["trajectory"][-1]["update_success"]:
|
||||
user_prompt = BufferManager._build_user_prompt(data["query"],
|
||||
data["state"]["current_survey"],
|
||||
data["trajectory"])
|
||||
else:
|
||||
# user_prompt = data["history_messages"][-1][1]["content"]
|
||||
user_prompt = BufferManager._build_user_prompt_force_correct(data["query"],
|
||||
data["state"]["current_survey"],
|
||||
data["trajectory"])
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": BufferManager._build_system_prompt(),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
]
|
||||
data["history_messages"].append(messages)
|
||||
total_messages.append(messages)
|
||||
self.running_ids.append(running_id) # update running ids
|
||||
return total_messages
|
||||
|
||||
def update_all_scores(self, scores):
|
||||
assert len(scores) == len(self.batch_rollout_data)
|
||||
for score, log in zip(scores, self.batch_rollout_data):
|
||||
log["state"]["score"] = score
|
||||
|
||||
def update_all_format_scores(self, scores):
|
||||
assert len(scores) == len(self.batch_rollout_data)
|
||||
for score, log in zip(scores, self.batch_rollout_data):
|
||||
log["state"]["format_score"] = score
|
||||
|
||||
|
||||
def update_trajectory(self, model_responses, env_feedbacks):
|
||||
"""
|
||||
model_response: original_response, thought, paragraph, tool_call, format_reward
|
||||
env_feedback: done, search_keywards, abstracts, outcome_reward
|
||||
"""
|
||||
assert len(self.running_ids) == len(model_responses)
|
||||
assert len(self.running_ids) == len(env_feedbacks)
|
||||
|
||||
for running_id, response, feedback in zip(self.running_ids, model_responses, env_feedbacks):
|
||||
# update state
|
||||
self.batch_rollout_data[running_id]["state"]["done"] = feedback["done"] # if True, finalize the task
|
||||
|
||||
update_success = False
|
||||
if response["true"]:
|
||||
if self.batch_rollout_data[running_id]["state"]["current_survey"] != {}:
|
||||
if len(response["answer"]) != 0: # no empty dict or start
|
||||
update_success = SurveyManager.update_current_survey(
|
||||
self.batch_rollout_data[running_id]["state"]["current_survey"],
|
||||
response["answer"])
|
||||
else:
|
||||
# Search Then Write
|
||||
if len(response["answer"]) != 0 and "There is no result" not in self.batch_rollout_data[running_id]["history_messages"][-1][1]["content"]:
|
||||
update_success = SurveyManager.update_current_survey(
|
||||
self.batch_rollout_data[running_id]["state"]["current_survey"],
|
||||
response["answer"])
|
||||
elif "There is no result" in self.batch_rollout_data[running_id]["history_messages"][-1][1]["content"] and len(response["answer"]) == 0:
|
||||
update_success = True
|
||||
|
||||
|
||||
self.batch_rollout_data[running_id]["trajectory"].append({
|
||||
"step": self.step,
|
||||
"original_response": response["original_response"],
|
||||
"answer_thought": response["answer_thought"],
|
||||
"answer": response["answer"],
|
||||
"tool_call_thought": response["tool_call_thought"],
|
||||
"tool_call": response["tool_call"],
|
||||
"search_keywords": feedback["search_keywords"],
|
||||
"summarys": feedback["summarys"],
|
||||
"update_success": update_success and response["true"],
|
||||
})
|
||||
|
||||
|
||||
self.batch_rollout_data[running_id]["history_messages"][-1].append({
|
||||
"role": "assistant",
|
||||
"content": response["original_response"],
|
||||
})
|
||||
|
||||
if self.batch_rollout_data[running_id]["state"]["done"]:
|
||||
real_done = BufferManager._check_finalize(self.batch_rollout_data[running_id]["query"],
|
||||
self.batch_rollout_data[running_id]["state"]["current_survey"],
|
||||
self.batch_rollout_data[running_id]["trajectory"])
|
||||
if not real_done:
|
||||
self.batch_rollout_data[running_id]["state"]["done"] = False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def match_reference(text:str):
|
||||
reg = r"\\\w*cite(?!style)\w*\{(.+?)\}"
|
||||
placeholder_reg = re.compile(r"^#\d+$")
|
||||
reg_bibkeys = re.findall(reg, text)
|
||||
bibkeys = set()
|
||||
for bibkey in reg_bibkeys:
|
||||
single_bib = bibkey.split(",")
|
||||
for bib in single_bib:
|
||||
if not placeholder_reg.match(bib):
|
||||
bib = bib.strip()
|
||||
if bib and bib != "*":
|
||||
bibkeys.add(bib)
|
||||
|
||||
reg = r"\\nocite{(.+?)\}"
|
||||
reg_bibkeys = re.findall(reg, text)
|
||||
for bibkey in reg_bibkeys:
|
||||
single_bib = bibkey.split(",")
|
||||
for bib in single_bib:
|
||||
if not placeholder_reg.match(bib):
|
||||
bib = bib.strip()
|
||||
if bib and bib != "*":
|
||||
bibkeys.remove(bib)
|
||||
|
||||
ref_key_list = list(bibkeys)
|
||||
return ref_key_list
|
||||
|
||||
@staticmethod
|
||||
def parse_generator_response(response):
|
||||
"""
|
||||
1. 解析失败: step + 1, 重新生成, 给出提示
|
||||
2. 解析成功:
|
||||
2.1 tool_call == search(keywords) 发送post请求
|
||||
2.2 tool_call == done 结束任务
|
||||
|
||||
**standard format**
|
||||
|
||||
Current Update:
|
||||
<think> [Your Thoughts]: str </think>
|
||||
<answer> {"update": str, "content": str}: dict </answer>
|
||||
|
||||
Next Plan:
|
||||
<think> [Your Thoughts]: str </think>
|
||||
<tool_call> {"tool": "search", "arguments": {}}: dict</tool_call>
|
||||
"""
|
||||
extracted_result = {
|
||||
"original_response": response
|
||||
}
|
||||
|
||||
try:
|
||||
current_update = response.split("Current Update:")[-1].split("Next Plan:")[0]
|
||||
except:
|
||||
current_update = response
|
||||
|
||||
# pattern
|
||||
think_pattern = r"<think>(.*?)</think>"
|
||||
answer_pattern = r"<answer>(.*?)</answer>"
|
||||
tool_pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
|
||||
# extract information from current_update
|
||||
|
||||
think_match = re.search(think_pattern, current_update, re.DOTALL) # 多行提取
|
||||
if think_match:
|
||||
think = think_match.group(1)
|
||||
think = think.strip()
|
||||
else:
|
||||
think = ""
|
||||
extracted_result["answer_thought"] = think
|
||||
|
||||
answer_match = re.search(answer_pattern, current_update, re.DOTALL) # 多行提取
|
||||
has_answer = False
|
||||
if answer_match:
|
||||
answer = answer_match.group(1)
|
||||
answer = answer.strip()
|
||||
try:
|
||||
answer = json.loads(answer)
|
||||
if not answer == {}:
|
||||
assert isinstance(answer["update"], str)
|
||||
answer["update"] = SurveyManager.parse_update_pos(answer["update"])
|
||||
if answer["update"] == "plan":
|
||||
|
||||
assert isinstance(answer["content"], dict)
|
||||
plan = answer["content"]
|
||||
assert isinstance(plan, dict)
|
||||
plan.pop("instruction",None)
|
||||
keys = ["abstract", "introduction", "conclusion","sections","title"]
|
||||
for key in keys:
|
||||
assert key in plan
|
||||
for key in plan:
|
||||
assert key in keys
|
||||
if key == "sections":
|
||||
assert isinstance(plan[key], list)
|
||||
for section in plan[key]:
|
||||
assert isinstance(section, dict)
|
||||
assert "plan" in section
|
||||
assert "title" in section
|
||||
assert isinstance(section["plan"], str)
|
||||
assert isinstance(section["title"], str)
|
||||
assert section["title"] != "Methodology" # 不能是Methodology,WIP
|
||||
if "subsections" in section:
|
||||
assert isinstance(section["subsections"], list)
|
||||
for subsection in section["subsections"]:
|
||||
assert isinstance(subsection, dict)
|
||||
assert "plan" in subsection
|
||||
assert "title" in subsection
|
||||
assert isinstance(subsection["plan"], str)
|
||||
assert isinstance(subsection["title"], str)
|
||||
if "subsubsections" in section:
|
||||
assert isinstance(subsection["subsubsections"], list)
|
||||
for subsubsection in subsection["subsubsections"]:
|
||||
assert isinstance(subsubsection, dict)
|
||||
assert "plan" in subsubsection
|
||||
assert "title" in subsubsection
|
||||
assert isinstance(subsubsection["plan"], str)
|
||||
assert isinstance(subsubsection["title"], str)
|
||||
elif key == "title":
|
||||
assert isinstance(plan[key], str)
|
||||
else:
|
||||
assert isinstance(plan[key], dict)
|
||||
assert "plan" in plan[key]
|
||||
if key not in ["abstract", "conclusion", "introduction"]:
|
||||
assert "title" in plan[key]
|
||||
else:
|
||||
assert isinstance(answer["content"], str)
|
||||
has_answer = True
|
||||
except:
|
||||
answer = {}
|
||||
else:
|
||||
answer = {}
|
||||
extracted_result["answer"] = answer
|
||||
|
||||
# extract information from next_plan
|
||||
|
||||
try:
|
||||
next_plan = response.split("Next Plan:")[1]
|
||||
except:
|
||||
try:
|
||||
next_plan = response.split("</answer>")[1]
|
||||
except:
|
||||
next_plan = response
|
||||
|
||||
think_match = re.search(think_pattern, next_plan, re.DOTALL) # 多行提取
|
||||
if think_match:
|
||||
think = think_match.group(1)
|
||||
think = think.strip()
|
||||
else:
|
||||
think = ""
|
||||
extracted_result["tool_call_thought"] = think
|
||||
|
||||
tool_match = re.search(tool_pattern, next_plan, re.DOTALL) # 多行提取
|
||||
has_tool_call = False
|
||||
if tool_match:
|
||||
tool_text = tool_match.group(1)
|
||||
tool_text = tool_text.strip()
|
||||
try:
|
||||
tool_call = json.loads(tool_text)
|
||||
assert tool_call["name"] in ["search_engine", "finalize"]
|
||||
if tool_call["name"] == "search_engine":
|
||||
assert isinstance(tool_call["arguments"]["query"], list)
|
||||
has_tool_call = True
|
||||
except:
|
||||
tool_call = {}
|
||||
else:
|
||||
|
||||
tool_call = {}
|
||||
|
||||
extracted_result["tool_call"] = tool_call
|
||||
|
||||
extracted_result["true"] = has_answer and has_tool_call
|
||||
reg = r"[\u4e00-\u9fa5]"
|
||||
has_chinese = re.search(reg, response) is not None
|
||||
extracted_result["true"] = extracted_result["true"] and not has_chinese
|
||||
|
||||
return extracted_result
|
||||
|
||||
|
||||
class BufferManager_V2(BufferManager):
|
||||
|
||||
def __init__(self, querys, repeat_n=1):
|
||||
# self.config = config
|
||||
self.step = 0
|
||||
self.batch_rollout_data = []
|
||||
self.running_ids = [] # active envs
|
||||
|
||||
for uid, query in enumerate(querys):
|
||||
print("CURRENT QUERY: ", query)
|
||||
for _ in range(repeat_n):
|
||||
self.batch_rollout_data.append({
|
||||
"query": query,
|
||||
"uid": f"query_{uid}",
|
||||
"state": {
|
||||
# "score": 0.0, # only for debug
|
||||
# "format_score": None, # will update at last step
|
||||
"done": False,
|
||||
"current_survey": {}
|
||||
},
|
||||
"trajectory": [],
|
||||
"history_messages": []
|
||||
})
|
||||
|
||||
81
code/src/generation/prompts.py
Normal file
81
code/src/generation/prompts.py
Normal file
@@ -0,0 +1,81 @@
|
||||
|
||||
|
||||
SYSTEM_PROMPT_0415_BUFFER = """You are a survey writer. You are asked to write a survey follow the instruction, refered as "Query" or "User's Query". You will finish the survey by multi-step updating.
|
||||
|
||||
Usually, you need to do two things:
|
||||
(1) First, you need to update the survey using the retrieved information according to the current plan, refered as "Current Update". You MUST think inside <think>...</think> before you give your <answer>...</answer> action, mainly about "How to write paragraphs with citations based on retrieved information to complete the current plan?". If the current plan is None, or you think the current plan is not good, or you think the retrieved information is not enough for you to finish the plan, you can jump the "Answer" action by giving "{}" as answer. Please give the citation in \\cite{}.
|
||||
|
||||
(2) Then, you need decide what part of the survey needs to be updated, refered as "Next Plan". You MUST think inside <think>...</think> before you give your <tool_call>...</tool_call> action. If you think the current retrieved information is enough to finish your next plan, you can jump the "Tool Call" action by giving "{}" as tool call.
|
||||
|
||||
## Answer
|
||||
You can give one answer to update the survey.
|
||||
<answer>
|
||||
{"update": <section-pos>, "content": paragraph }
|
||||
</answer>
|
||||
|
||||
There are two parameters in <answer> action.
|
||||
* update: string, which position you want to update, such as "title", "abstract", "introduction", "section-1", "section-1/subsction-1", "section-1/subsction-1/subsection-1", and "conclusion".
|
||||
* content: string, the update content for the position of the survey, please give the faithful citation in \\cite{}. . Or dict, only when you give the plan of the paper, the values including the section title and a simple plan of it.
|
||||
|
||||
## Tool Call
|
||||
You can call one function to assist the survey writing.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type": "function", "function": {"name": "search_engine", "description": "Search reasearch papers.", "parameters": {"type": "object", "properties": {"query": {"type": "array", "items": {"type": "string"}, "description": "The words to search for in quotes."}, "required": ["query"]}}}}
|
||||
{"type": "function", "function": {"name": "finalize", "description": "Finalize the survey.", "parameters": {}, "required": []}}
|
||||
</tools>
|
||||
|
||||
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
||||
<tool_call> {"name": <function-name>, "arguments": <args-json-object>} </tool_call>
|
||||
|
||||
For example, You can call the search engine by using:
|
||||
<tool_call> {"name": "search_engine", "arguments": {"query": ["keyword-1", "keyword-2", ...]} </tool_call>
|
||||
If you think the survey is finished, please call:
|
||||
<tool_call> {"name": "finalize", "arguments": {} </tool_call>
|
||||
|
||||
** Attention **
|
||||
You must use correct JSON format inside <answer>...</answer> and <tool_call>...</tool_call>, otherwise we can't extract the corrent content.
|
||||
|
||||
**Output format**
|
||||
(1) Current Update:
|
||||
<think> How to write paragraphs with citations based on retrieved information to complete the plan? </think>
|
||||
<answer> Please provide your answer here. (JSON format) </answer>
|
||||
(2) Next Plan:
|
||||
<think> Which part of the survey needs to be updated? What information needs to be queried? </think>
|
||||
<tool_call> Please call a tool here. (JSON format) </tool_call>
|
||||
"""
|
||||
|
||||
USER_PROMPT_v0_0424_BUFFER = """Please update the survey depending on the insturctions.
|
||||
**User's Query**
|
||||
<user_query>
|
||||
|
||||
**Current Survey**
|
||||
<init_survey>
|
||||
|
||||
**Current Plan**
|
||||
<think>I need to get enough information.</think>
|
||||
<tool_call>{}</tool_call>
|
||||
|
||||
**Retrieved Information**
|
||||
There is no results.
|
||||
|
||||
Please give your response following the output format.
|
||||
"""
|
||||
|
||||
USER_PROMPT_0415_BUFFER = """Please update the survey depending on the insturctions.
|
||||
**User's Query**
|
||||
<user_query>
|
||||
|
||||
**Current Survey**
|
||||
<current_survey>
|
||||
|
||||
**Current Plan**
|
||||
<think> <last_step_thought> </think>
|
||||
<tool_call> <last_step_tool_call> </tool_call>
|
||||
|
||||
**Retrieved Information**
|
||||
<summarys>
|
||||
|
||||
Please give your response following the output format.
|
||||
"""
|
||||
338
code/src/generation/run.py
Normal file
338
code/src/generation/run.py
Normal file
@@ -0,0 +1,338 @@
|
||||
from contextlib import contextmanager
|
||||
from codetiming import Timer
|
||||
@contextmanager
|
||||
def _timer(name: str, timing_raw):
|
||||
with Timer(name=name, logger=None) as timer:
|
||||
yield
|
||||
timing_raw[name] = timer.last
|
||||
|
||||
from buffer import SurveyManager
|
||||
from buffer import BufferManager_V2 as BufferManager
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
import re
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import asyncio
|
||||
import argparse
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
import aiohttp
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 允许跨域(如果前端和后端端口不同需要加上)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
active_connections = set()
|
||||
|
||||
@app.websocket("/ws")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
try:
|
||||
while True:
|
||||
await websocket.receive_text()
|
||||
except WebSocketDisconnect:
|
||||
active_connections.remove(websocket)
|
||||
|
||||
async def post_to_frontend(payload):
|
||||
print(f"Sending payload to frontend: {payload}") # Log the payload being sent
|
||||
for ws in list(active_connections):
|
||||
try:
|
||||
await ws.send_text(payload)
|
||||
except Exception as e:
|
||||
print(f"Error sending to WebSocket: {e}")
|
||||
active_connections.remove(ws)
|
||||
|
||||
|
||||
def write_to_json(data, path):
|
||||
with open(path, 'w', encoding='utf8') as f:
|
||||
f.write(json.dumps(data, ensure_ascii=False, indent=4))
|
||||
|
||||
class OriginalvLLMRollout:
|
||||
def __init__(self, model_name_or_path):
|
||||
# init vLLM
|
||||
self.rollout_model = LLM(
|
||||
model=model_name_or_path,
|
||||
tokenizer=model_name_or_path,
|
||||
gpu_memory_utilization=0.95,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
self.sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
top_p=0.8,
|
||||
repetition_penalty=1.05,
|
||||
top_k=20,
|
||||
max_tokens=2748,
|
||||
)
|
||||
|
||||
def generate(self, input_texts):
|
||||
generated_texts = []
|
||||
completions = self.rollout_model.generate(input_texts, self.sampling_params, use_tqdm=False)
|
||||
for output in completions:
|
||||
generated_text = output.outputs[0].text
|
||||
generated_texts.append(generated_text)
|
||||
return generated_texts
|
||||
|
||||
def chat(self, input_messages):
|
||||
generated_texts = []
|
||||
completions = self.rollout_model.chat(input_messages, self.sampling_params, use_tqdm=False)
|
||||
for output in completions:
|
||||
generated_text = output.outputs[0].text
|
||||
generated_texts.append(generated_text)
|
||||
return generated_texts
|
||||
|
||||
async def rollout_with_env(querys, batch_size, max_turns, model_path, url,
|
||||
deploy_port=None):
|
||||
"""
|
||||
Args:
|
||||
querys: [string]
|
||||
"""
|
||||
###############################
|
||||
#### splited by batch size ####
|
||||
###############################
|
||||
n = len(querys) // batch_size
|
||||
batch_querys = []
|
||||
for i in range(n+1):
|
||||
temp_data = querys[i*batch_size: (i+1)*batch_size]
|
||||
if len(temp_data) > 0:
|
||||
batch_querys.append(temp_data)
|
||||
print("QUERY NUMBER with BATCH: ", [len(x) for x in batch_querys])
|
||||
|
||||
###################
|
||||
#### init vllm ####
|
||||
###################
|
||||
vllm_manager = OriginalvLLMRollout(model_path)
|
||||
|
||||
############################
|
||||
#### init Format Reward ####
|
||||
############################
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
|
||||
total_rollout_data = []
|
||||
for querys in batch_querys:
|
||||
###########################################
|
||||
#### acquire env configs and init envs ####
|
||||
###########################################
|
||||
buffer_manager = BufferManager(querys)
|
||||
|
||||
while True:
|
||||
# Break at max-turns
|
||||
if buffer_manager.step >= max_turns:
|
||||
break
|
||||
|
||||
###############################
|
||||
#### prepare input prompts ####
|
||||
###############################
|
||||
messagess_todo = buffer_manager.build_prompt_for_generator()
|
||||
# breakpoint()
|
||||
|
||||
# Break when no tasks
|
||||
if len(messagess_todo) == 0:
|
||||
break
|
||||
|
||||
##########################
|
||||
#### generate by vLLM ####
|
||||
##########################
|
||||
timing_raw = {}
|
||||
with _timer('vllm sampling', timing_raw):
|
||||
# response_texts = vllm_manager.chat(messagess_todo)
|
||||
response_texts = await asyncio.to_thread(vllm_manager.chat, messagess_todo)
|
||||
|
||||
##################################
|
||||
#### preprocess the responses ####
|
||||
##################################
|
||||
# 对response的详细处理可以集成到环境类中,因环境而异, 先对Response进行预处理
|
||||
extracted_results = []
|
||||
for response_text in response_texts:
|
||||
result = BufferManager.parse_generator_response(response_text)
|
||||
extracted_results.append(result)
|
||||
|
||||
#################################################
|
||||
#### execute in environment and get feedback ####
|
||||
#################################################
|
||||
payload = {
|
||||
"tool_calls": [x["tool_call"] for x in extracted_results]
|
||||
}
|
||||
if buffer_manager.step <=2:
|
||||
payload["topk"] = 20
|
||||
with _timer('get env feedback', timing_raw):
|
||||
# env_response_batched = requests.post(url, json=payload).json()
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=payload) as resp:
|
||||
env_response_batched = await resp.json()
|
||||
|
||||
###################################
|
||||
#### postprocess the feedbacks ####
|
||||
###################################
|
||||
with _timer('postprocessing', timing_raw):
|
||||
buffer_manager.update_trajectory(extracted_results, env_response_batched)
|
||||
buffer_manager.step += 1
|
||||
|
||||
print(timing_raw)
|
||||
|
||||
if deploy_port is not None:
|
||||
now_text = json_to_markdown(buffer_manager.batch_rollout_data[-1])
|
||||
now_search_keywords= buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["search_keywords"]
|
||||
now_update = buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["answer_thought"]
|
||||
next_update = buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["tool_call_thought"]
|
||||
now_query = buffer_manager.batch_rollout_data[-1]["query"]
|
||||
trajs = buffer_manager.batch_rollout_data[-1]["trajectory"]
|
||||
updated_success = buffer_manager.batch_rollout_data[-1]["trajectory"][-1]["update_success"]
|
||||
if updated_success:
|
||||
for traj in reversed(trajs):
|
||||
if len(traj["summarys"]) > 0:
|
||||
break
|
||||
summary_num = len(traj["summarys"])
|
||||
if summary_num == 0:
|
||||
summary_text = "No summaries yet."
|
||||
else:
|
||||
summary_text = "\n".join(traj["summarys"])
|
||||
frontend_payload = {
|
||||
"markdown": now_text,
|
||||
"searchKeywords": now_search_keywords,
|
||||
"nowUpdate": now_update,
|
||||
"nextUpdate": next_update,
|
||||
"query": now_query,
|
||||
"papers": summary_text
|
||||
}
|
||||
frontend_payload = json.dumps(frontend_payload, ensure_ascii=False)
|
||||
try:
|
||||
await post_to_frontend(frontend_payload)
|
||||
except Exception as e:
|
||||
print(f"Error posting to frontend: {e}")
|
||||
|
||||
|
||||
|
||||
for item in buffer_manager.batch_rollout_data:
|
||||
item["survey_text"] = SurveyManager.convert_survey_dict_to_str(item["state"]["current_survey"])
|
||||
|
||||
total_rollout_data.extend(buffer_manager.batch_rollout_data)
|
||||
#####################################
|
||||
#### clear all envs and shutdown ####
|
||||
#####################################
|
||||
del buffer_manager
|
||||
|
||||
return total_rollout_data
|
||||
|
||||
|
||||
def json_to_markdown(json_data):
|
||||
text = SurveyManager.convert_survey_dict_to_str(json_data["state"]["current_survey"])
|
||||
all_summarys = {}
|
||||
for traj in json_data["trajectory"]:
|
||||
for item in traj["summarys"]:
|
||||
split_text = item.split("\n")
|
||||
bibkey = split_text[0].split(":")[1].strip()
|
||||
title_begin_index = item.find("Title:") + len("Title:")
|
||||
title_end_index = item.find("Abstract:")
|
||||
title = item[title_begin_index:title_end_index].strip()
|
||||
arxivid = bibkey.split("arxivid")[-1].strip()
|
||||
html = f"arxiv.org/abs/{arxivid}"
|
||||
all_summarys[bibkey] = f"[{title}](https://{html})"
|
||||
|
||||
reg = r"\\cite\{(.+?)\}"
|
||||
placeholder_reg = re.compile(r"^#\d+$")
|
||||
reg_bibkeys = re.findall(reg, text)
|
||||
bibkeys = []
|
||||
for bibkey in reg_bibkeys:
|
||||
single_bib = bibkey.split(",")
|
||||
for bib in single_bib:
|
||||
if not placeholder_reg.match(bib):
|
||||
bib = bib.strip()
|
||||
if bib and bib != "*" and bib not in bibkeys:
|
||||
bibkeys.append(bib)
|
||||
|
||||
bibkeys_index = {bibkey: i+1 for i, bibkey in enumerate(bibkeys)}
|
||||
|
||||
def replace_bibkey(bibkey):
|
||||
bibkey = bibkey.group(1)
|
||||
single_bib = bibkey.split(",")
|
||||
new_bibs = []
|
||||
for bib in single_bib:
|
||||
if not placeholder_reg.match(bib):
|
||||
bib = bib.strip()
|
||||
if bib and bib != "*":
|
||||
if bib in bibkeys_index:
|
||||
new_bibs.append(f"{bibkeys_index[bib]}")
|
||||
else:
|
||||
print(f"Warning: {bib} not found in bibkeys")
|
||||
if len(new_bibs) > 0:
|
||||
return "[" + ",".join(new_bibs) + "]"
|
||||
else:
|
||||
return ""
|
||||
text = re.sub(reg, replace_bibkey, text)
|
||||
reference_text = "\n\n".join([f"[{i}] {all_summarys[bibkey]}" for bibkey, i in bibkeys_index.items()])
|
||||
text += "\n## References\n" + reference_text
|
||||
return text
|
||||
|
||||
async def test_surveyGen(model_path, out_path,querys, url, deploy_port=None):
|
||||
|
||||
total_rollout_data = await rollout_with_env(querys, 1, 1000, model_path, url, deploy_port)
|
||||
all_md_texts = []
|
||||
for json_data in total_rollout_data:
|
||||
md_text = json_to_markdown(json_data)
|
||||
all_md_texts.append(md_text)
|
||||
|
||||
all_md_texts = "\n\n".join(all_md_texts)
|
||||
with open(out_path, 'w', encoding='utf8') as f:
|
||||
f.write(all_md_texts)
|
||||
|
||||
# with jsonlines.open(out_path, 'w') as writer:
|
||||
# for item in total_rollout_data:
|
||||
# writer.write(item)
|
||||
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
@app.post("/generate_survey")
|
||||
async def generate_survey(request: QueryRequest):
|
||||
global args # Ensure args is accessible
|
||||
# 这里可以根据需要处理查询
|
||||
model_path = args.model_path
|
||||
out_path = args.output_file
|
||||
query = request.query
|
||||
querys = [query] # 将查询转换为列表
|
||||
url = args.retriver_url
|
||||
deploy_port = args.port if args.port is not None else None
|
||||
try:
|
||||
await test_surveyGen(model_path, out_path, querys, url, deploy_port)
|
||||
return {"status": "success", "message": "Survey generated successfully."}
|
||||
except Exception as e:
|
||||
print(f"Error generating survey: {e}")
|
||||
return {"status": "error", "message": str(e)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run survey generation with vLLM.")
|
||||
parser.add_argument("--model_path", type=str, required=True, help="Path to the model.")
|
||||
parser.add_argument("--query", type=str, required=True, help="Query to generate survey.")
|
||||
parser.add_argument("--output_file", type=str, required=True, help="Path to the output Markdown file.")
|
||||
parser.add_argument("--retriver_url", type=str, default="http://localhost:8400", help="URL of the retriever service.")
|
||||
parser.add_argument("--port", type=str, default=None, help="Deploy port, default is None, which means not deploy.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.port is not None:
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="localhost", port=int(args.port))# log_level="debug")
|
||||
|
||||
# Run the survey generation
|
||||
else:
|
||||
asyncio.run(
|
||||
test_surveyGen(
|
||||
model_path=args.model_path,
|
||||
out_path=args.output_file,
|
||||
querys=[args.query],
|
||||
url=args.retriver_url
|
||||
)
|
||||
)
|
||||
|
||||
55
code/src/preprocess/build_index.py
Normal file
55
code/src/preprocess/build_index.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch.distributed
|
||||
import faiss
|
||||
import pandas as pd
|
||||
import faiss
|
||||
import numpy as np
|
||||
import jsonlines, json
|
||||
from transformers import AutoModel
|
||||
import os
|
||||
import torch
|
||||
'''
|
||||
data format:
|
||||
{
|
||||
"bibkey": "some_bibkey",
|
||||
"text": "The abstract or text of the paper."
|
||||
}
|
||||
example:
|
||||
{
|
||||
"bibkey": "arxivid1234.5678",
|
||||
"text": "Title: A Study on Something\nAbstract: This paper discusses the findings of a study on something important in the field of research.\nAuthors: John Doe"
|
||||
}
|
||||
'''
|
||||
|
||||
model_name = "openbmb/MiniCPM-Embedding-Light"
|
||||
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
|
||||
|
||||
input_path = "./data/arxiv.jsonl"
|
||||
|
||||
with jsonlines.open(input_path) as f:
|
||||
survey_data = list(f)
|
||||
|
||||
|
||||
xids = [item["bibkey"] for item in survey_data]
|
||||
passages = [item["text"] for item in survey_data]
|
||||
|
||||
embeddings_doc_dense, _ = model.encode_corpus(passages, max_length=1024)
|
||||
|
||||
|
||||
# faiss save index
|
||||
index = faiss.IndexFlatIP(embeddings_doc_dense.shape[1])
|
||||
id_map_index = faiss.IndexIDMap(index)
|
||||
index = faiss.index_cpu_to_all_gpus(id_map_index)
|
||||
|
||||
x_ids_int = np.array(np.arange(len(xids)))
|
||||
|
||||
str_int_ids = {}
|
||||
for i in range(len(xids)):
|
||||
str_int_ids[xids[i]] = x_ids_int[i]
|
||||
str_int_ids_df = pd.DataFrame(str_int_ids, index=[0]).T.reset_index()
|
||||
str_int_ids_df.columns = ["str_id", "int_id"]
|
||||
str_int_ids_df.to_csv("./index/str_int_ids_abstract.csv", index=False)
|
||||
|
||||
index.add_with_ids(embeddings_doc_dense, x_ids_int)
|
||||
|
||||
index = faiss.index_gpu_to_cpu(index)
|
||||
faiss.write_index(index, "./index/index_abstract.faiss")
|
||||
21
code/src/preprocess/data_process.py
Normal file
21
code/src/preprocess/data_process.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# curl -L -o ~/Downloads/arxiv.zip\
|
||||
# https://www.kaggle.com/api/v1/datasets/download/Cornell-University/arxiv
|
||||
|
||||
|
||||
import jsonlines
|
||||
|
||||
input_path = './data/arxiv-metadata-oai-snapshot.json'
|
||||
output_path = './data/arxiv.jsonl'
|
||||
|
||||
new_data = []
|
||||
with jsonlines.open(input_path, 'r') as reader:
|
||||
for item in reader:
|
||||
new_item = {
|
||||
'bibkey': f"arxivid{item['id']}",
|
||||
'text': f"Title: {item['title']}\nAbstract: {item['abstract']}\nAuthors: {item['authors']}",
|
||||
}
|
||||
new_data.append(new_item)
|
||||
|
||||
with jsonlines.open(output_path, 'w') as writer:
|
||||
for item in new_data:
|
||||
writer.write(item)
|
||||
175
code/src/retriever/retriever.py
Normal file
175
code/src/retriever/retriever.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import faiss
|
||||
from fastapi import FastAPI
|
||||
import torch
|
||||
import pandas as pd
|
||||
from collections import defaultdict
|
||||
import pandas as pd
|
||||
import jsonlines
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
import uvicorn
|
||||
import asyncio
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
import re
|
||||
import json
|
||||
import asyncio
|
||||
import argparse
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
|
||||
model_name = "openbmb/MiniCPM-Embedding-Light"
|
||||
model = AutoModel.from_pretrained(model_name, trust_remote_code=True, attn_implementation="flash_attention_2", torch_dtype=torch.float16).to("cuda")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
|
||||
co = faiss.GpuMultipleClonerOptions()
|
||||
co.shard = True
|
||||
co.useFloat16 = True
|
||||
|
||||
faiss_index_path = "./index/index_abstract.faiss" # Replace with your FAISS index path"
|
||||
faiss_index = faiss.read_index(faiss_index_path)
|
||||
faiss_index = faiss.index_cpu_to_all_gpus(faiss_index,co=co)
|
||||
|
||||
corpus_path = "./data/arxiv.jsonl"
|
||||
with jsonlines.open(corpus_path) as f:
|
||||
paper_data = list(f)
|
||||
paper_dict = {}
|
||||
item_key = "text"
|
||||
|
||||
|
||||
index_path = "./index/str_int_ids_abstract.csv"
|
||||
index_df = pd.read_csv(index_path,converters={0: lambda x: str(x),1: lambda x: int(x)})
|
||||
index_df.columns = ["str_id", "int_id"]
|
||||
index_dict = index_df.set_index("int_id")["str_id"].to_dict()
|
||||
|
||||
|
||||
for item in paper_data:
|
||||
paper_dict[item["bibkey"]] = item[item_key]
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
queries: List[str]
|
||||
topk: Optional[int] = None
|
||||
return_scores: bool = False
|
||||
|
||||
class MessageRequest(BaseModel):
|
||||
tool_calls: List
|
||||
topk: Optional[int] = 10
|
||||
|
||||
@app.post("/")
|
||||
async def search_text_batch(request:MessageRequest):
|
||||
tool_calls = request.tool_calls
|
||||
topk = request.topk
|
||||
results = []
|
||||
finalize_indices = []
|
||||
search_engine_indices = []
|
||||
for i in range(len(tool_calls)):
|
||||
try:
|
||||
tool_calls[i]["name"]
|
||||
except KeyError:
|
||||
finalize_indices.append(i)
|
||||
continue
|
||||
if tool_calls[i]["name"] == "search_engine":
|
||||
search_engine_indices.append(i)
|
||||
elif tool_calls[i]["name"] == "finalize":
|
||||
finalize_indices.append(i)
|
||||
else:
|
||||
finalize_indices.append(i)
|
||||
|
||||
|
||||
tasks = []
|
||||
for i in range(len(tool_calls)):
|
||||
if i in search_engine_indices:
|
||||
tasks.append(call_search_engine(tool_calls[i], topk))
|
||||
search_task_results = await asyncio.gather(*tasks)
|
||||
num_search = 0
|
||||
num_finalize = 0
|
||||
for i in range(len(tool_calls)):
|
||||
if i in finalize_indices:
|
||||
search_keywords, bibkeys,abstracts, done, score = "",[], [], True, 0.0
|
||||
num_finalize += 1
|
||||
elif i in search_engine_indices:
|
||||
search_keywords, bibkeys, abstracts, done, score = search_task_results[num_search]
|
||||
num_search += 1
|
||||
|
||||
titles = []
|
||||
for abstract in abstracts:
|
||||
try:
|
||||
title = abstract.split("\n")[1]
|
||||
title = title.split(":")[1].strip()
|
||||
titles.append(title)
|
||||
except:
|
||||
titles.append("")
|
||||
results.append({ "search_keywords":search_keywords, "summarys":abstracts, "done":done, "score":score, "titles":titles, "bibkeys":bibkeys})
|
||||
|
||||
|
||||
return results
|
||||
|
||||
def extract_tool_call(text: str):
|
||||
text = text.strip()
|
||||
|
||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
if not match:
|
||||
return None
|
||||
tool_text = match.group(1)
|
||||
try:
|
||||
tool_call = json.loads(tool_text)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return tool_call if isinstance(tool_call, dict) else None
|
||||
|
||||
|
||||
def get_response(queries,ref):
|
||||
text_raw = paper_dict[str(ref)]
|
||||
text_raw = tokenizer(text_raw, max_length=8192, truncation=True)
|
||||
text_raw = tokenizer.decode(text_raw["input_ids"])
|
||||
|
||||
response = text_raw
|
||||
response = f"bibkey: {str(ref)}\n"+response
|
||||
return response
|
||||
|
||||
async def call_search_engine(tool_call, topk=10):
|
||||
try:
|
||||
queries = tool_call["arguments"]["query"]
|
||||
if isinstance(queries, str):
|
||||
queries = [queries]
|
||||
else:
|
||||
queries = list(queries)
|
||||
|
||||
if len(queries) == 0:
|
||||
return "", [], [], False, 0.0
|
||||
results = defaultdict(dict)
|
||||
query_embedding_to_text,_ = model.encode_query(queries, max_length=512, show_progress_bar=False)
|
||||
_,results = faiss_index.search(query_embedding_to_text, topk)
|
||||
result2query = {}
|
||||
merge_rrf = defaultdict(float)
|
||||
for i in range(len(results)):
|
||||
for j in range(len(results[i])):
|
||||
merge_rrf[results[i][j]] += 1/(j+1)
|
||||
result2query[results[i][j]] = queries[i]
|
||||
results = sorted(merge_rrf.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
results = [x[0] for x in results][:topk]
|
||||
|
||||
# new_queries = [result2query[result] for result in results]
|
||||
queries = ",".join(queries)
|
||||
|
||||
# bibkeys = [str(results[i]) for i in range(len(results))]
|
||||
bibkeys = [str(index_dict[results[i]]) for i in range(len(results))]
|
||||
response = [f"bibkey: {bibkey}\n{paper_dict[bibkey]}" for bibkey in bibkeys]
|
||||
return queries,bibkeys , response, False, 0.0
|
||||
except Exception as e:
|
||||
print(f"Error in call_search_engine: {e}")
|
||||
return "",[], [], False, 0.0
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the FastAPI application.")
|
||||
parser.add_argument("--port", type=int, default=8400, help="Port to run the FastAPI application on.")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
||||
Reference in New Issue
Block a user