4039 lines
148 KiB
Python
4039 lines
148 KiB
Python
#!/usr/bin/env python3
|
|
"""chat.py — Interactive CLI for vLLM models with tool calling.
|
|
|
|
Single-file CLI tool for chatting with persona_kappa (or any model) running
|
|
on vLLM. Supports multi-turn conversation, bash tool calling, streaming,
|
|
and persona conditioning.
|
|
|
|
Usage:
|
|
python3 chat.py # auto-detect model
|
|
python3 chat.py --persona lawful_evil # with persona
|
|
python3 chat.py --model kappa_20b_131k # explicit model
|
|
"""
|
|
|
|
import argparse
|
|
import ast
|
|
import concurrent.futures
|
|
import json
|
|
import multiprocessing
|
|
import os
|
|
import re
|
|
import readline
|
|
import sqlite3
|
|
import subprocess
|
|
import sys
|
|
import uuid
|
|
import ipaddress
|
|
import socket
|
|
import tempfile
|
|
import textwrap
|
|
import threading
|
|
import traceback
|
|
from html import unescape as _html_unescape
|
|
from urllib.error import URLError
|
|
from urllib.parse import urlparse
|
|
from urllib.request import Request, urlopen
|
|
|
|
from openai import OpenAI
|
|
|
|
# ─── ANSI helpers ──────────────────────────────────────────────────────────
|
|
|
|
# ANSI SGR escape sequences used to style terminal output. Each constant is
# written verbatim to the terminal; RESET must follow styled text so the
# attribute does not bleed into later output.
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
ITALIC = "\033[3m"
# Foreground colours.
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
GRAY = "\033[90m"
|
|
|
|
|
|
def red(s):
    """Return ``s`` formatted as text, wrapped in RED and a trailing RESET."""
    return RED + format(s) + RESET
|
|
|
|
|
|
def yellow(s):
    """Return ``s`` formatted as text, wrapped in YELLOW and a trailing RESET."""
    return YELLOW + format(s) + RESET
|
|
|
|
|
|
def dim(s):
    """Return ``s`` formatted as text, wrapped in DIM and a trailing RESET."""
    return DIM + format(s) + RESET
|
|
|
|
|
|
def bold(s):
    """Return ``s`` formatted as text, wrapped in BOLD and a trailing RESET."""
    return BOLD + format(s) + RESET
|
|
|
|
|
|
def cyan(s):
    """Return ``s`` formatted as text, wrapped in CYAN and a trailing RESET."""
    return CYAN + format(s) + RESET
|
|
|
|
|
|
def green(s):
    """Return ``s`` formatted as text, wrapped in GREEN and a trailing RESET."""
    return GREEN + format(s) + RESET
|
|
|
|
|
|
# ─── Markdown → ANSI renderer ─────────────────────────────────────────────
|
|
|
|
|
|
class MarkdownRenderer:
    """Line-buffered markdown → ANSI converter for streaming output.

    Text is accumulated until a newline arrives; each complete line is then
    rendered with regex-based markdown → ANSI substitutions. The only state
    carried across lines is whether we are inside a fenced code block.
    """

    # Patterns are compiled once at class level; the ANSI replacement strings
    # are built at call time from the module-level colour constants.
    _HEADER_RE = re.compile(r"^(#{1,6}) (.+)")
    _BOLD_STAR_RE = re.compile(r"\*\*(.+?)\*\*")
    _BOLD_UNDERSCORE_RE = re.compile(r"__(.+?)__")
    # A single '*' only opens/closes emphasis when it hugs non-space text.
    # Without the whitespace guards a bullet marker ("* item with *emph*")
    # or spaced arithmetic ("2 * 3 * 4") was mistaken for an italic span.
    _ITALIC_RE = re.compile(r"(?<!\*)\*(?![*\s])(.+?)(?<![*\s])\*(?!\*)")
    _INLINE_CODE_RE = re.compile(r"`(.+?)`")
    _BULLET_RE = re.compile(r"^(\s*)([-*]) ")
    _NUMBERED_RE = re.compile(r"^(\s*)(\d+)\. ")

    def __init__(self):
        self.in_code_block = False
        self._buf = ""

    def feed(self, text: str) -> str:
        """Feed text, return ANSI-rendered output for complete lines.

        Any trailing partial line stays buffered. The return value is empty
        when no newline has been seen yet; otherwise every rendered line is
        terminated by a newline.
        """
        self._buf += text
        out = []
        while "\n" in self._buf:
            line, self._buf = self._buf.split("\n", 1)
            out.append(self._render_line(line))
        return "\n".join(out) + "\n" if out else ""

    def flush(self) -> str:
        """Render and return whatever partial line is still buffered.

        Intended for end of stream; returns an empty string when the buffer
        is empty. No trailing newline is added.
        """
        if self._buf:
            rendered = self._render_line(self._buf)
            self._buf = ""
            return rendered
        return ""

    def _render_line(self, line: str) -> str:
        """Convert one complete line of markdown to ANSI-styled text."""
        # Code block fence toggle
        if line.strip().startswith("```"):
            self.in_code_block = not self.in_code_block
            return f"{DIM}{line}{RESET}"

        # Inside code block — cyan, no further markdown processing
        if self.in_code_block:
            return f"{CYAN}{line}{RESET}"

        # Headers (# H1 … ###### H6): the marker is dropped, text is styled.
        m = self._HEADER_RE.match(line)
        if m:
            return f"{BOLD}{MAGENTA}{m.group(2)}{RESET}"

        # Inline formatting (order matters: bold before italic)
        line = self._BOLD_STAR_RE.sub(f"{BOLD}\\1{RESET}", line)
        line = self._BOLD_UNDERSCORE_RE.sub(f"{BOLD}\\1{RESET}", line)
        line = self._ITALIC_RE.sub(f"{ITALIC}\\1{RESET}", line)
        line = self._INLINE_CODE_RE.sub(f"{CYAN}\\1{RESET}", line)

        # Bullet lists — cyan bullet
        line = self._BULLET_RE.sub(f"\\1{CYAN}\\2{RESET} ", line)

        # Numbered lists — cyan number
        line = self._NUMBERED_RE.sub(f"\\1{CYAN}\\2.{RESET} ", line)

        return line
|
|
|
|
|
|
# ─── Tool definition ──────────────────────────────────────────────────────
|
|
|
|
# Tool (function-calling) schemas. Each entry is a dict with
# {"type": "function", "function": {name, description, parameters}} where
# "parameters" is a JSON-Schema object. AGENT_TOOLS and TASK_AGENT_TOOLS
# select subsets of this list by function name.
TOOLS = [
    # bash — run a shell command.
    {
        "type": "function",
        "function": {
            "name": "bash",
            "description": (
                "Execute a bash command and return stdout + stderr. "
                "Use this tool freely for any task: checking time, "
                "reading files, running programs, system info, etc."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "command": {
                        "type": "string",
                        "description": "The bash command to execute.",
                    }
                },
                "required": ["command"],
            },
        },
    },
    # write_file — create or overwrite a file.
    {
        "type": "function",
        "function": {
            "name": "write_file",
            "description": "Write content to a file, creating it if needed. Overwrites existing content.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Absolute or relative file path.",
                    },
                    "content": {
                        "type": "string",
                        "description": "The full file content to write.",
                    },
                },
                "required": ["path", "content"],
            },
        },
    },
    # read_file — read a file (optionally a line window).
    {
        "type": "function",
        "function": {
            "name": "read_file",
            "description": "Read the contents of a file. Returns numbered lines. Must be called before edit_file on the same path.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Absolute or relative file path.",
                    },
                    "offset": {
                        "type": "integer",
                        "description": "Line number to start reading from (1-based, default: 1).",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Maximum number of lines to read. Omit to read entire file.",
                    },
                },
                "required": ["path"],
            },
        },
    },
    # search — regex search over file contents.
    {
        "type": "function",
        "function": {
            "name": "search",
            "description": (
                "Search file contents for a regex pattern. "
                "Returns matching lines with file paths and line numbers. "
                "Searches recursively when path is a directory."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Regex pattern to search for (extended regex).",
                    },
                    "path": {
                        "type": "string",
                        "description": "File or directory to search in (default: current directory).",
                    },
                },
                "required": ["query"],
            },
        },
    },
    # edit_file — exact-string replacement inside a file.
    {
        "type": "function",
        "function": {
            "name": "edit_file",
            "description": "Replace an exact string in a file with new content. Fails if old_string is not found or matches multiple locations (use near_line to disambiguate). Requires read_file on the same path first.",
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {
                        "type": "string",
                        "description": "Absolute or relative file path.",
                    },
                    "old_string": {
                        "type": "string",
                        "description": "The exact text to find and replace.",
                    },
                    "new_string": {
                        "type": "string",
                        "description": "The replacement text.",
                    },
                    "near_line": {
                        "type": "integer",
                        "description": "When old_string matches multiple locations, pick the one nearest this line number.",
                    },
                },
                "required": ["path", "old_string", "new_string"],
            },
        },
    },
    # math — run Python code for computation.
    {
        "type": "function",
        "function": {
            "name": "math",
            "description": (
                "Execute Python code for math/computation. "
                "Code MUST use print() to produce output. "
                "Available: sympy, numpy, scipy (with scipy.special, "
                "scipy.optimize, scipy.integrate, scipy.linalg), math, "
                "fractions, itertools, functools, collections, decimal, "
                "operator, random, re, string. "
                "Common sympy names (symbols, solve, simplify, expand, factor, "
                "sqrt, Rational, Matrix, integrate, diff, etc.) are pre-imported. "
                "Example: x = symbols('x'); print(solve(x**2 - 4, x))"
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {
                        "type": "string",
                        "description": "Python code to execute. Must use print() for output.",
                    },
                },
                "required": ["code"],
            },
        },
    },
    # man — read a manual page.
    {
        "type": "function",
        "function": {
            "name": "man",
            "description": (
                "Read a man page. Use this instead of bash('man ...') or "
                "web_search. Returns the full formatted manual entry."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "page": {
                        "type": "string",
                        "description": "The man page name (e.g. 'grep', 'socket', 'printf').",
                    },
                    "section": {
                        "type": "string",
                        "description": "Manual section (e.g. '1' commands, '2' syscalls, '3' library). Optional.",
                    },
                },
                "required": ["page"],
            },
        },
    },
    # web_fetch — fetch a URL and extract an answer from it.
    {
        "type": "function",
        "function": {
            "name": "web_fetch",
            "description": (
                "Fetch a URL and extract specific information from it. "
                "You must provide a question or extraction guidance — "
                "the page is fetched, analyzed, and only relevant "
                "information is returned (not raw page content)."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "url": {
                        "type": "string",
                        "description": "The URL to fetch (must start with http:// or https://).",
                    },
                    "question": {
                        "type": "string",
                        "description": "What to extract or answer from the page content.",
                    },
                },
                "required": ["url", "question"],
            },
        },
    },
    # web_search — text web search.
    {
        "type": "function",
        "function": {
            "name": "web_search",
            "description": (
                "Search the web using a text query. Returns ranked results "
                "with titles, URLs, and content snippets."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query.",
                    },
                    "max_results": {
                        "type": "integer",
                        "description": "Max results to return (default 5, max 20).",
                    },
                    "topic": {
                        "type": "string",
                        "description": "Search topic: general, news, or finance (default general).",
                    },
                },
                "required": ["query"],
            },
        },
    },
    # task — delegate work to a sub-agent (excluded from both sub-agent lists).
    {
        "type": "function",
        "function": {
            "name": "task",
            "description": (
                "Delegate a general-purpose task to an autonomous sub-agent. "
                "The agent inherits all tools and can read, write, edit, search, "
                "and run commands. Use task for work that requires file modifications "
                "or command execution. Provide a clear, self-contained prompt."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Complete task description for the sub-agent.",
                    },
                },
                "required": ["prompt"],
            },
        },
    },
    # plan — planning sub-agent (excluded from both sub-agent lists).
    {
        "type": "function",
        "function": {
            "name": "plan",
            "description": (
                "Plan before implementing. An autonomous agent explores the "
                "codebase and writes a structured plan to .plan.md. "
                "Use plan BEFORE writing code — when the user asks to build, "
                "add, refactor, or change something that touches multiple files "
                "or has unclear scope. The plan identifies files to modify, "
                "existing patterns to reuse, and risks to consider."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "What to plan — the goal, constraints, and scope.",
                    },
                },
                "required": ["prompt"],
            },
        },
    },
    # remember — store a persistent key/value memory.
    {
        "type": "function",
        "function": {
            "name": "remember",
            "description": (
                "Save a persistent memory. Memories persist across sessions. "
                "Use to remember IPs, paths, commands, conventions, or any "
                "fact worth recalling later."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "key": {
                        "type": "string",
                        "description": "Short identifier (e.g. 'user_name').",
                    },
                    "value": {
                        "type": "string",
                        "description": "Content to remember.",
                    },
                },
                "required": ["key", "value"],
            },
        },
    },
    # recall — list/search memories and history (no required parameters).
    {
        "type": "function",
        "function": {
            "name": "recall",
            "description": (
                "Search memories and past conversations. "
                "With no query, lists all saved memories. "
                "With a query, searches both memories and conversation history."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "Search term or phrase. Omit to list all memories.",
                    },
                    "limit": {
                        "type": "integer",
                        "description": "Max conversation results to return (default 20).",
                    },
                },
            },
        },
    },
    # forget — delete a persistent memory by key.
    {
        "type": "function",
        "function": {
            "name": "forget",
            "description": (
                "Remove a persistent memory by key. Use when the user asks "
                "to forget, remove, or delete a stored memory."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "key": {
                        "type": "string",
                        "description": "The memory key to remove (e.g. 'user_name').",
                    },
                },
                "required": ["key"],
            },
        },
    },
]
|
|
|
|
# Tools available to read-only sub-agents (plan).
|
|
# Excludes agent tools (no recursion) and write tools (security boundary).
|
|
# Subset of TOOLS, in TOOLS order, limited to the read/lookup tool names below.
AGENT_TOOLS = [
    spec
    for spec in TOOLS
    if spec["function"]["name"]
    in {"read_file", "search", "math", "man", "web_fetch", "web_search"}
]
|
|
|
|
# Tools available to the task sub-agent (read + write + execute).
|
|
# Still excludes agent tools (no recursion).
|
|
# Subset of TOOLS, in TOOLS order: the read/lookup tools plus bash,
# write_file and edit_file.
TASK_AGENT_TOOLS = [
    spec
    for spec in TOOLS
    if spec["function"]["name"]
    in {
        "read_file",
        "search",
        "math",
        "man",
        "web_fetch",
        "web_search",
        "bash",
        "write_file",
        "edit_file",
    }
]
|
|
|
|
# ─── Edit helpers ─────────────────────────────────────────────────────────
|
|
|
|
|
|
def _find_occurrences(content: str, old_string: str) -> list[int]:
|
|
"""Return 1-based line numbers where each occurrence of old_string starts."""
|
|
if not old_string:
|
|
return []
|
|
# Build a prefix-sum of line starts for O(1) line-number lookup.
|
|
line_starts = [0]
|
|
for i, ch in enumerate(content):
|
|
if ch == "\n":
|
|
line_starts.append(i + 1)
|
|
results = []
|
|
start = 0
|
|
while True:
|
|
idx = content.find(old_string, start)
|
|
if idx == -1:
|
|
break
|
|
# bisect: find the line containing idx
|
|
lo, hi = 0, len(line_starts) - 1
|
|
while lo < hi:
|
|
mid = (lo + hi + 1) // 2
|
|
if line_starts[mid] <= idx:
|
|
lo = mid
|
|
else:
|
|
hi = mid - 1
|
|
results.append(lo + 1) # 1-based
|
|
start = idx + 1
|
|
return results
|
|
|
|
|
|
def _pick_nearest(content: str, old_string: str, near_line: int) -> int:
|
|
"""Return the char index of the occurrence of old_string nearest to near_line."""
|
|
line_starts = [0]
|
|
for i, ch in enumerate(content):
|
|
if ch == "\n":
|
|
line_starts.append(i + 1)
|
|
|
|
best_idx = -1
|
|
best_dist = float("inf")
|
|
start = 0
|
|
while True:
|
|
idx = content.find(old_string, start)
|
|
if idx == -1:
|
|
break
|
|
# Find line number for this occurrence
|
|
lo, hi = 0, len(line_starts) - 1
|
|
while lo < hi:
|
|
mid = (lo + hi + 1) // 2
|
|
if line_starts[mid] <= idx:
|
|
lo = mid
|
|
else:
|
|
hi = mid - 1
|
|
line_num = lo + 1
|
|
dist = abs(line_num - near_line)
|
|
if dist < best_dist:
|
|
best_dist = dist
|
|
best_idx = idx
|
|
start = idx + 1
|
|
return best_idx
|
|
|
|
|
|
# ─── Sandboxed Python executor ────────────────────────────────────────────
|
|
|
|
# Builtin names the math sandbox refuses: _ASTValidator.visit_Call rejects
# direct calls to them, and _math_exec_in_process leaves them out of the
# builtins dict handed to the executed code.
_MATH_BLOCKED_BUILTINS = {
    "open",
    "exec",
    "eval",
    "compile",
    "input",
    "breakpoint",
    "memoryview",
    "globals",
    "locals",
    "vars",
}
|
|
|
|
# Top-level module names the math sandbox refuses to import. Checked against
# the first dotted component of the import name, both statically
# (_ASTValidator) and at runtime (the __import__ wrapper in
# _math_exec_in_process).
_MATH_BLOCKED_MODULES = {
    "os",
    "sys",
    "subprocess",
    "shutil",
    "pathlib",
    "socket",
    "http",
    "urllib",
    "requests",
    "pickle",
    "marshal",
    "shelve",
    "dbm",
    "sqlite3",
    "ctypes",
    "multiprocessing",
    "threading",
    "asyncio",
    "concurrent",
    "signal",
    "pty",
    "tty",
    "termios",
    "fcntl",
    "resource",
    "syslog",
    "tempfile",
    "io",
    "builtins",
    "__builtin__",
    "importlib",
}
|
|
|
|
|
|
class _ASTValidator(ast.NodeVisitor):
|
|
"""Validates AST for dangerous constructs."""
|
|
|
|
def __init__(self):
|
|
self.errors: list[str] = []
|
|
|
|
def visit_Import(self, node):
|
|
for alias in node.names:
|
|
if alias.name.split(".")[0] in _MATH_BLOCKED_MODULES:
|
|
self.errors.append(f"Import of '{alias.name}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_ImportFrom(self, node):
|
|
if node.module and node.module.split(".")[0] in _MATH_BLOCKED_MODULES:
|
|
self.errors.append(f"Import from '{node.module}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_Call(self, node):
|
|
if isinstance(node.func, ast.Name) and node.func.id in _MATH_BLOCKED_BUILTINS:
|
|
self.errors.append(f"Call to '{node.func.id}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
def visit_Attribute(self, node):
|
|
if node.attr.startswith("__") and node.attr.endswith("__"):
|
|
if node.attr not in {"__name__", "__doc__", "__class__"}:
|
|
self.errors.append(f"Access to '{node.attr}' is not allowed")
|
|
self.generic_visit(node)
|
|
|
|
|
|
def _validate_math_code(code: str) -> list[str]:
|
|
"""Validate code for dangerous constructs. Returns list of errors."""
|
|
try:
|
|
tree = ast.parse(code)
|
|
except SyntaxError as e:
|
|
lines = code.split("\n")
|
|
msg = f"Syntax error on line {e.lineno}: {e.msg}"
|
|
if e.lineno and e.lineno <= len(lines):
|
|
msg += f"\n {e.lineno}: {lines[e.lineno - 1]}"
|
|
if e.offset:
|
|
msg += f"\n {' ' * (e.offset - 1)}^"
|
|
return [msg]
|
|
except (ValueError, UnicodeError) as e:
|
|
return [f"Code contains invalid characters: {e}"]
|
|
v = _ASTValidator()
|
|
v.visit(tree)
|
|
return v.errors
|
|
|
|
|
|
def _math_exec_in_process(code: str, result_queue: multiprocessing.Queue):
    """Run ``code`` with restricted builtins and report the outcome on a queue.

    Used as the ``target`` of the child process started by
    ``_execute_math_sandboxed``. Exactly one ``(status, output)`` tuple is put
    on ``result_queue``: ``("success", text)`` or ``("error", text)``.

    Output rules on success: if the executed code leaves a non-None global
    named ``result``, it is reported (after any printed text); otherwise the
    captured stdout is reported; otherwise a hint to use print().

    Args:
        code: Python source to execute.
        result_queue: Queue-like object; only ``put`` is used.
    """
    import signal as _signal
    import sys as _sys
    from io import StringIO

    # Restore default signal dispositions so the parent's terminate()/Ctrl-C
    # actually stops this process.
    _signal.signal(_signal.SIGTERM, _signal.SIG_DFL)
    _signal.signal(_signal.SIGINT, _signal.SIG_DFL)
    # The int<->str digit limit only exists on interpreters that ship it;
    # calling it unguarded would kill the child before anything is reported.
    if hasattr(_sys, "set_int_max_str_digits"):
        _sys.set_int_max_str_digits(100_000)

    try:
        captured = StringIO()
        _sys.stdout = captured

        def _safe_import(name, *args, **kwargs):
            # Runtime counterpart of the static import check.
            if name.split(".")[0] in _MATH_BLOCKED_MODULES:
                raise ImportError(f"Import of '{name}' is blocked")
            return original_import(name, *args, **kwargs)

        # __builtins__ is a dict in imported modules and a module in __main__.
        original_import = (
            __builtins__["__import__"]
            if isinstance(__builtins__, dict)
            else __builtins__.__import__
        )
        safe_builtins = (
            {k: v for k, v in __builtins__.items() if k not in _MATH_BLOCKED_BUILTINS}
            if isinstance(__builtins__, dict)
            else {
                k: getattr(__builtins__, k)
                for k in dir(__builtins__)
                if k not in _MATH_BLOCKED_BUILTINS and not k.startswith("_")
            }
        )
        safe_builtins["__import__"] = _safe_import

        # Pre-import safe modules
        import math, fractions, itertools, functools, operator
        import collections, decimal, random, re, string

        ns: dict = {
            "__builtins__": safe_builtins,
            "math": math,
            "fractions": fractions,
            "Fraction": fractions.Fraction,
            "itertools": itertools,
            "functools": functools,
            "operator": operator,
            "collections": collections,
            "decimal": decimal,
            "Decimal": decimal.Decimal,
            "random": random,
            "re": re,
            "string": string,
        }

        try:
            import sympy

            ns["sympy"] = sympy
            for name in (
                "symbols",
                "Symbol",
                "solve",
                "simplify",
                "expand",
                "factor",
                "Eq",
                "sqrt",
                "Rational",
                "pi",
                "E",
                "I",
                "oo",
                "sin",
                "cos",
                "tan",
                "exp",
                "log",
                "factorial",
                "binomial",
                "gcd",
                "lcm",
                "prime",
                "isprime",
                "factorint",
                "divisors",
                "totient",
                "mod_inverse",
                "Matrix",
                "integrate",
                "diff",
                "limit",
                "series",
                "Sum",
                "Product",
                "floor",
                "ceiling",
                "Abs",
            ):
                # Skip names a given sympy release does not export instead of
                # failing the whole run with AttributeError.
                if hasattr(sympy, name):
                    ns[name] = getattr(sympy, name)
        except ImportError:
            pass

        try:
            import numpy as _np

            ns["np"] = ns["numpy"] = _np
        except ImportError:
            pass

        try:
            import scipy, scipy.special, scipy.optimize, scipy.integrate, scipy.linalg

            ns["scipy"] = scipy
            ns["special"] = scipy.special
            ns["optimize"] = scipy.optimize
            ns["comb"] = scipy.special.comb
            ns["perm"] = scipy.special.perm
            ns["gamma"] = scipy.special.gamma
            ns["beta"] = scipy.special.beta
        except ImportError:
            pass

        # NOTE(security): exec of model-supplied code; the AST validation and
        # restricted builtins are the only containment visible here.
        exec(code, ns)

        printed = captured.getvalue()
        result_var = ns.get("result")
        if result_var is not None:
            out = (
                f"{printed.rstrip()}\nresult = {result_var}"
                if printed
                else str(result_var)
            )
        elif printed:
            out = printed.rstrip()
        else:
            out = "No output. Add print() to see results."
        result_queue.put(("success", out))

    except Exception as e:
        result_queue.put(
            ("error", f"{type(e).__name__}: {e}\n{traceback.format_exc()}")
        )
    finally:
        # Always hand the real stdout back, whichever path was taken.
        _sys.stdout = _sys.__stdout__
|
|
|
|
|
|
def _auto_print_wrap(code: str) -> str:
|
|
"""If code has no print/result and the last statement is an expression, wrap it in print()."""
|
|
# Skip if code already has print() or assigns to 'result'
|
|
if "print(" in code or re.search(r"\bresult\s*=", code):
|
|
return code
|
|
try:
|
|
tree = ast.parse(code)
|
|
except SyntaxError:
|
|
return code
|
|
if not tree.body:
|
|
return code
|
|
last = tree.body[-1]
|
|
if isinstance(last, ast.Expr):
|
|
# Get the source of the last expression and wrap in print()
|
|
lines = code.split("\n")
|
|
last_line_start = last.lineno - 1 # 0-based
|
|
last_line_end = last.end_lineno # 1-based, exclusive after slicing
|
|
expr_lines = lines[last_line_start:last_line_end]
|
|
expr_text = "\n".join(expr_lines)
|
|
prefix = lines[:last_line_start]
|
|
wrapped = prefix + [f"print({expr_text})"]
|
|
return "\n".join(wrapped)
|
|
return code
|
|
|
|
|
|
def _execute_math_sandboxed(code: str, timeout: float = 30.0) -> tuple[str, bool]:
    """Validate ``code`` and run it in a child process with a time limit.

    Args:
        code: Python source; a trailing bare expression is auto-wrapped in
            ``print()`` before validation.
        timeout: Seconds to wait for the child to report a result.

    Returns:
        ``(output, is_error)``. ``is_error`` is True for validation failures,
        timeouts, a child that exits without reporting, and errors raised by
        the executed code.
    """
    import time
    from queue import Empty

    code = _auto_print_wrap(code)
    errors = _validate_math_code(code)
    if errors:
        return "Validation errors:\n" + "\n".join(f"- {e}" for e in errors), True

    result_queue: multiprocessing.Queue = multiprocessing.Queue()
    proc = multiprocessing.Process(
        target=_math_exec_in_process, args=(code, result_queue)
    )
    proc.start()

    # Read the result BEFORE joining: a child that has queued a large payload
    # cannot exit until the parent drains the pipe, so join-then-get can
    # deadlock and be misreported as a timeout.
    result = None
    timed_out = False
    deadline = time.monotonic() + timeout
    try:
        while result is None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                # Short polls let us notice a child that died without output.
                result = result_queue.get(timeout=min(0.1, remaining))
            except Empty:
                if not proc.is_alive():
                    # The child is gone; allow for data flushed just before exit.
                    try:
                        result = result_queue.get(timeout=0.1)
                    except Empty:
                        pass
                    break
        timed_out = result is None and proc.is_alive()
    finally:
        # Give a finished child a moment to exit, then escalate.
        proc.join(timeout=1.0 if result is not None else 0)
        if proc.is_alive():
            proc.terminate()
            proc.join(timeout=1.0)
            if proc.is_alive():
                proc.kill()
                proc.join()
        result_queue.close()
        result_queue.join_thread()

    if result is None:
        if timed_out:
            return f"Execution timed out after {timeout}s", True
        return "Execution failed with no output", True

    status, output = result
    return output, status == "error"
|
|
|
|
|
|
# ─── HTML stripper (for web_fetch) ─────────────────────────────────────────
|
|
|
|
_RE_TAGS = re.compile(r"<[^>]+>")
|
|
_RE_WS = re.compile(r"[ \t]+")
|
|
_RE_BLANKLINES = re.compile(r"\n{3,}")
|
|
|
|
|
|
def _strip_html(html: str) -> str:
|
|
"""Convert HTML to plain text: strip tags, decode entities, collapse whitespace."""
|
|
text = _RE_TAGS.sub("", html)
|
|
text = _html_unescape(text)
|
|
text = _RE_WS.sub(" ", text)
|
|
text = _RE_BLANKLINES.sub("\n\n", text)
|
|
return text.strip()
|
|
|
|
|
|
# ─── Safety ────────────────────────────────────────────────────────────────
|
|
|
|
# Soft guardrail — catches common accidental destructive commands but is
|
|
# trivially bypassable (e.g. extra spaces, shell variable expansion).
|
|
# The user approval prompt is the primary security boundary.
|
|
# Substrings that cause is_command_blocked() to reject a command. Matching is
# a plain case-sensitive ``in`` test against the whole command string.
BLOCKED_PATTERNS = [
    "rm -rf /",
    "rm -rf /*",
    "mkfs",
    "shutdown",
    "reboot",
    "halt",
    "poweroff",
    "dd if=",
    ":(){ :|:& };:",  # fork bomb
    "> /dev/sda",
    "mv / ",
    "chmod -R 777 /",
    "chown -R ",
]
|
|
|
|
|
|
def _sanitize_command(cmd: str) -> str:
|
|
"""Replace common unicode look-alikes that break the shell."""
|
|
return (
|
|
cmd.replace("\u2018", "'") # left single curly quote
|
|
.replace("\u2019", "'") # right single curly quote
|
|
.replace("\u201c", '"') # left double curly quote
|
|
.replace("\u201d", '"') # right double curly quote
|
|
.replace("\u2013", "-") # en dash
|
|
.replace("\u2014", "-") # em dash
|
|
)
|
|
|
|
|
|
def is_command_blocked(cmd: str) -> str | None:
    """Check ``cmd`` against BLOCKED_PATTERNS.

    Returns a reason string naming the first pattern (in list order) that
    occurs as a substring of the stripped command, or None if none match.
    """
    candidate = cmd.strip()
    hit = next((p for p in BLOCKED_PATTERNS if p in candidate), None)
    if hit is None:
        return None
    return f"Blocked: command matches dangerous pattern '{hit}'"
|
|
|
|
|
|
# ─── Persistence (memories, history), API key, readline setup ─────────────
|
|
|
|
# Default SQLite database path: a dotfile in the working directory at import time.
PCODE_DB = os.path.join(os.getcwd(), ".pcode_memories.db")
# When set, _open_db() uses this path instead of PCODE_DB.
_db_override: str | None = None
# Paths whose schema has already been created by _open_db() in this process.
_db_initialized: set[str] = set()
# Set by _open_db(); cleared by _save_message() if an FTS insert fails.
_fts5_available: bool = False

# Cache for _get_tavily_key(): the key (or None) and whether lookup already ran.
_tavily_key: str | None = None
_tavily_key_loaded: bool = False
|
|
|
|
|
|
def _get_tavily_key() -> str | None:
    """Return the Tavily API key, resolving it at most once per process.

    Lookup order: the file ``~/.config/pcode/tavily_key`` (if it exists, is
    readable and non-empty), then the ``TAVILY_API_KEY`` environment
    variable. The outcome — including "no key" — is cached in module globals.
    """
    global _tavily_key, _tavily_key_loaded
    if _tavily_key_loaded:
        return _tavily_key
    _tavily_key_loaded = True

    from_file = ""
    key_file = os.path.expanduser("~/.config/pcode/tavily_key")
    if os.path.isfile(key_file):
        try:
            with open(key_file) as fh:
                from_file = fh.read().strip()
        except OSError:
            # Unreadable file: fall through to the environment variable.
            from_file = ""

    resolved = from_file or os.environ.get("TAVILY_API_KEY", "").strip()
    if resolved:
        _tavily_key = resolved
    return _tavily_key
|
|
|
|
|
|
def _open_db() -> sqlite3.Connection:
    """Open the pcode database, creating tables on first use per path.

    The path is ``_db_override`` when set, otherwise ``PCODE_DB``. On the
    first open of a given path this creates the ``memories`` and
    ``conversations`` tables plus the session index, then tries to set up an
    FTS5 index over conversations; ``_fts5_available`` records whether that
    worked.

    Returns:
        An open connection; the caller is responsible for closing it.

    Raises:
        sqlite3.Error: if connecting or creating the base schema fails. The
            connection is closed before the error propagates.
    """
    global _fts5_available
    path = _db_override or PCODE_DB
    conn = sqlite3.connect(path)
    if path in _db_initialized:
        return conn

    try:
        conn.execute(
            "CREATE TABLE IF NOT EXISTS memories "
            "(key TEXT PRIMARY KEY, value TEXT NOT NULL, "
            "created TEXT NOT NULL, updated TEXT NOT NULL)"
        )
        conn.execute(
            "CREATE TABLE IF NOT EXISTS conversations "
            "(id INTEGER PRIMARY KEY AUTOINCREMENT, "
            "session_id TEXT NOT NULL, timestamp TEXT NOT NULL, "
            "role TEXT NOT NULL, content TEXT, "
            "tool_name TEXT, tool_args TEXT)"
        )
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_conv_session ON conversations(session_id)"
        )
        try:
            # Check if FTS table already exists
            fts_exists = conn.execute(
                "SELECT 1 FROM sqlite_master "
                "WHERE type='table' AND name='conversations_fts'"
            ).fetchone()
            if not fts_exists:
                conn.execute(
                    "CREATE VIRTUAL TABLE conversations_fts "
                    "USING fts5(content, content=conversations, content_rowid=id)"
                )
                conn.execute(
                    "INSERT INTO conversations_fts(conversations_fts) VALUES('rebuild')"
                )
            conn.commit()
            _fts5_available = True
        except sqlite3.Error:
            # FTS5 not compiled in (or index setup failed): callers fall back
            # to LIKE search. Discard any half-finished transaction.
            conn.rollback()
            _fts5_available = False
    except BaseException:
        # Do not leak the handle when schema creation fails.
        conn.close()
        raise

    _db_initialized.add(path)
    return conn
|
|
|
|
|
|
def _normalize_key(key: str) -> str:
|
|
"""Normalize a memory key for consistent lookup."""
|
|
return key.lower().replace("-", "_").replace(" ", "_")
|
|
|
|
|
|
def _load_memories() -> list[tuple[str, str]]:
    """Fetch every stored memory as ``(key, value)`` rows ordered by key.

    Best-effort: any failure while opening or querying the database yields
    an empty list.
    """
    try:
        conn = _open_db()
        try:
            cursor = conn.execute("SELECT key, value FROM memories ORDER BY key")
            rows = cursor.fetchall()
        finally:
            conn.close()
        return rows
    except Exception:
        return []
|
|
|
|
|
|
def _save_message(
    session_id: str,
    role: str,
    content: str | None,
    tool_name: str | None = None,
    tool_args: str | None = None,
) -> None:
    """Append one message to the ``conversations`` table (best-effort).

    When FTS is available and ``content`` is non-empty the row is also added
    to ``conversations_fts``; if that insert fails, FTS is switched off for
    the rest of the process. Errors are swallowed so logging can never break
    the chat session.
    """
    global _fts5_available
    row = (session_id, role, content, tool_name, tool_args)
    try:
        conn = _open_db()
        try:
            conn.execute(
                "INSERT INTO conversations (session_id, timestamp, role, content, "
                "tool_name, tool_args) VALUES (?, datetime('now'), ?, ?, ?, ?)",
                row,
            )
            should_index = _fts5_available and content
            if should_index:
                try:
                    (new_id,) = conn.execute("SELECT last_insert_rowid()").fetchone()
                    conn.execute(
                        "INSERT INTO conversations_fts(rowid, content) VALUES (?, ?)",
                        (new_id, content),
                    )
                except Exception:
                    _fts5_available = False  # degrade to LIKE for rest of session
            conn.commit()
        finally:
            conn.close()
    except Exception:
        pass  # Don't let logging failures break the session
|
|
|
|
|
|
def _escape_like(s: str) -> str:
|
|
"""Escape LIKE metacharacters for use with ESCAPE '\\'."""
|
|
return s.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
|
|
|
|
|
def _fts5_query(query: str) -> str:
|
|
"""Convert a plain search string into a safe FTS5 query.
|
|
|
|
Quotes each term so FTS5 special characters (*, -, etc.) are treated
|
|
as literals, then joins with implicit AND. Embedded double quotes
|
|
are doubled per FTS5 quoting convention.
|
|
"""
|
|
terms = query.split()
|
|
safe = []
|
|
for t in terms:
|
|
if t:
|
|
safe.append(f'"{t.replace(chr(34), chr(34) + chr(34))}"')
|
|
return " ".join(safe)
|
|
|
|
|
|
def _search_history(query: str, limit: int = 20) -> list[tuple]:
    """Search logged conversation messages for ``query``.

    Uses the FTS5 index (ordered by rank) when available, otherwise a LIKE
    substring match ordered newest-first. At most ``min(limit, 100)`` rows
    are returned, each ``(timestamp, session_id, role, content, tool_name)``.
    A blank query, or any database error, yields an empty list.
    """
    if not (query and query.strip()):
        return []
    try:
        conn = _open_db()
        try:
            capped = min(limit, 100)
            if _fts5_available:
                sql = (
                    "SELECT c.timestamp, c.session_id, c.role, c.content, c.tool_name "
                    "FROM conversations_fts f "
                    "JOIN conversations c ON c.id = f.rowid "
                    "WHERE conversations_fts MATCH ? "
                    "ORDER BY f.rank ASC LIMIT ?"
                )
                params = (_fts5_query(query), capped)
            else:
                sql = (
                    "SELECT timestamp, session_id, role, content, tool_name "
                    "FROM conversations WHERE content LIKE ? ESCAPE '\\' "
                    "ORDER BY timestamp DESC LIMIT ?"
                )
                params = (f"%{_escape_like(query)}%", capped)
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()
    except Exception:
        return []
|
|
|
|
|
|
def _search_history_recent(limit: int = 20) -> list[tuple]:
    """Return the newest logged messages, most recent first.

    At most ``min(limit, 100)`` rows of
    ``(timestamp, session_id, role, content, tool_name)``; an empty list on
    any database error.
    """
    try:
        conn = _open_db()
        try:
            cursor = conn.execute(
                "SELECT timestamp, session_id, role, content, tool_name "
                "FROM conversations ORDER BY timestamp DESC LIMIT ?",
                (min(limit, 100),),
            )
            rows = cursor.fetchall()
        finally:
            conn.close()
        return rows
    except Exception:
        return []
|
|
|
|
|
|
# Slash commands offered by the readline tab-completer (_completer).
SLASH_COMMANDS = [
    "/persona",
    "/instructions",
    "/clear",
    "/new",
    "/history",
    "/model",
    "/raw",
    "/reason",
    "/compact",
    "/creative",
    "/debug",
    "/help",
    "/exit",
    "/quit",
    "/q",
]
|
|
|
|
|
|
def _completer(text, state):
|
|
"""Tab-complete slash commands."""
|
|
if text.startswith("/"):
|
|
matches = [c for c in SLASH_COMMANDS if c.startswith(text)]
|
|
else:
|
|
matches = []
|
|
if state < len(matches):
|
|
return matches[state] + " "
|
|
return None
|
|
|
|
|
|
def setup_readline():
    """Configure readline: history length, slash-command completer, Tab binding."""
    history_limit = 1000
    readline.set_history_length(history_limit)
    readline.set_completer(_completer)
    # No delimiters: the completer receives the entire input line as ``text``.
    readline.set_completer_delims("")
    readline.parse_and_bind("tab: complete")
|
|
|
|
|
|
# ─── Spinner ───────────────────────────────────────────────────────────────
|
|
|
|
|
|
class Spinner:
    """Animated terminal spinner shown while waiting.

    Usable as a context manager or via explicit ``start()`` / ``stop()``.
    ``stop()`` may be called early (e.g. when the first streamed token
    arrives) to erase the spinner line before other output is printed.
    """

    FRAMES = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]

    def __init__(self, message="Thinking"):
        self._thread = None
        self._stopped = True
        self._stop_event = threading.Event()
        self.message = message

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *_):
        self.stop()

    def start(self):
        """Launch the animation on a daemon thread."""
        self._stop_event.clear()
        self._stopped = False
        worker = threading.Thread(target=self._spin, daemon=True)
        self._thread = worker
        worker.start()

    def stop(self):
        """Stop the spinner and clear its line. Safe to call multiple times."""
        if not self._stopped:
            self._stopped = True
            self._stop_event.set()
            if self._thread:
                self._thread.join()
            # Return to column 0 and erase to end of line.
            sys.stdout.write("\r\033[K")
            sys.stdout.flush()

    def _spin(self):
        """Thread body: redraw the next frame until the stop event is set."""
        tick = 0
        while True:
            if self._stop_event.is_set():
                break
            glyph = self.FRAMES[tick % len(self.FRAMES)]
            sys.stdout.write(f"\r{DIM}{glyph} {self.message}...{RESET}")
            sys.stdout.flush()
            tick += 1
            # Doubles as the frame delay and a prompt wake-up on stop().
            self._stop_event.wait(0.08)
|
|
|
|
|
|
# ─── Chat Session ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class ChatSession:
|
|
def __init__(
    self,
    client: OpenAI,
    model: str,
    persona: str | None,
    instructions: str | None,
    temperature: float,
    max_tokens: int,
    tool_timeout: int,
    reasoning_effort: str = "medium",
    context_window: int = 131072,
):
    """Store the session configuration and build the prefix messages."""
    # -- Caller-supplied configuration --
    self.client, self.model = client, model
    self.persona, self.instructions = persona, instructions
    self.temperature, self.max_tokens = temperature, max_tokens
    self.tool_timeout = tool_timeout
    self.reasoning_effort = reasoning_effort
    self.context_window = context_window

    # -- Runtime toggles --
    self.show_reasoning = True
    self.debug = False
    self.auto_approve = False
    self.creative_mode = False

    # -- Conversation state --
    self._session_id = uuid.uuid4().hex[:12]
    self._read_files: set[str] = set()
    self.md = MarkdownRenderer()
    self.messages: list[dict] = []

    # -- Token accounting --
    self._last_usage: dict[str, int] | None = None
    self._chars_per_token = 4.0  # calibrated from API usage
    self._msg_tokens: list[int] = []  # parallel to self.messages
    self._system_tokens = 0  # tokens for system_messages
    self._assistant_pending_tokens = 0

    # Must run last: it reads persona, instructions, reasoning_effort and
    # creative_mode set above.
    self._init_system_messages()
|
|
|
|
def _init_system_messages(self):
    """Build the system/developer prefix messages.

    System message format matches training distribution:
        Persona: X (optional)
        Knowledge cutoff: X (from base model pretraining)
        Current date: X (from base model pretraining)
        Reasoning: X (low/medium/high)
        # Valid channels: ... (dynamic based on tools/reasoning)
        Calls to these tools... (only when tools present)

    Developer message uses # Instructions header when combined
    with tool definitions (tool defs appended by chat template).

    This method does not construct the system message itself; it prepares
    the chat-template kwargs and one developer message. It (re)sets
    ``self.system_messages``, ``self._chat_template_kwargs_base``,
    ``self._chat_template_kwargs`` and ``self._agent_system_messages``.
    """
    self.system_messages = []

    # -- Chat template kwargs --
    # The chat template builds the system message (persona, knowledge
    # cutoff, date, reasoning, channels) from kwargs. We pass persona
    # via model_identity; reasoning_effort is already passed.
    self._chat_template_kwargs_base = {
        "reasoning_effort": self.reasoning_effort,
    }
    self._chat_template_kwargs = dict(self._chat_template_kwargs_base)
    if self.persona:
        self._chat_template_kwargs["model_identity"] = f"Persona: {self.persona}"

    # -- Developer message --
    if self.creative_mode:
        dev_parts = [
            "# Instructions",
            "",
            "You are a creative writing partner. Use the analysis channel to "
            "think through structure, voice, and intent before drafting.",
            "",
            "Craft principles:",
            "- Ground scenes in concrete sensory detail — what is seen, heard, felt.",
            "- Vary rhythm. Short sentences hit hard. Longer ones carry the reader "
            "through texture and nuance, building toward something.",
            "- Dialogue should do at least two things: reveal character AND advance "
            "plot or tension. Cut anything that's just exchanging information.",
            "- Earn your abstractions. Don't say 'she felt sad' — show the thing "
            "that makes the reader feel it.",
            "- Trust subtext. Leave room for the reader.",
            "",
            "Match the user's genre and tone. If they want literary fiction, write "
            "literary fiction. If they want pulp, write pulp with conviction. "
            "Never condescend to the form.",
        ]
    else:
        # A single multi-line entry: the adjacent literals concatenate.
        dev_parts = [
            "Always respond with tool calls, not just text.\n\n"
            "TOOL PATTERNS:\n\n"
            "Modify existing file → read_file then edit_file:\n"
            " read_file(path='config.py') → "
            "edit_file(path='config.py')\n\n"
            "Create new file → write_file:\n"
            " write_file(path='hello.py', content='...')\n\n"
            "Find something across files → search:\n"
            " search(query='test_')\n\n"
            "Complex or multi-step task → plan first:\n"
            " plan(prompt='refactor database from API')\n\n"
            "Run a command, git, or tests → bash:\n"
            " bash(command='git log -5')\n"
            " bash(command='pytest')\n\n"
            "Retrieve a URL → web_fetch:\n"
            " web_fetch(url='https://example.com')\n\n"
            "Look up documentation → man:\n"
            " man(page='tar')",
        ]
    if self.instructions:
        dev_parts.append("")
        dev_parts.append(self.instructions)
    memories = _load_memories()
    if memories:
        # Only the count is injected; the model must call recall for contents.
        dev_parts.append("")
        dev_parts.append(
            f"REMINDER: You currently have {len(memories)} memories stored. "
            "Use recall to see them."
        )
    self.system_messages.append(
        {"role": "developer", "content": "\n".join(dev_parts)}
    )
    # Agent prefix: shallow copy of the prefix built above, including the
    # memory-count reminder when one was added.
    self._agent_system_messages = list(self.system_messages)
|
|
|
|
def _full_messages(self) -> list[dict]:
|
|
"""System messages + conversation history."""
|
|
return self.system_messages + self.messages
|
|
|
|
def send(self, user_input: str):
    """Run one user turn, looping while the model keeps requesting tools.

    Appends the user message, then repeatedly streams a completion, records
    it, and, when it contains tool calls, executes them and feeds the
    results back. The loop ends on the first reply without tool calls.
    On KeyboardInterrupt, trailing tool results and unanswered tool-call
    messages are removed before the interrupt is re-raised.
    """
    memory_tools = ("remember", "forget", "recall")

    def estimate(char_count):
        # Token estimate from the currently calibrated chars/token ratio.
        return max(1, int(char_count / self._chars_per_token))

    def drop_last():
        # Pop the newest message and keep _msg_tokens in sync with it.
        self.messages.pop()
        if self._msg_tokens:
            self._msg_tokens.pop()

    self.messages.append({"role": "user", "content": user_input})
    self._msg_tokens.append(estimate(len(user_input)))
    _save_message(self._session_id, "user", user_input)

    try:
        while True:
            request_msgs = self._full_messages()
            if self.debug:
                self._debug_print_request(request_msgs)

            # Creative mode sends no tool definitions at all.
            tool_kwargs = {} if self.creative_mode else {"tools": TOOLS}

            spinner = Spinner("Thinking")
            spinner.start()
            try:
                stream = self.client.chat.completions.create(
                    model=self.model,
                    messages=request_msgs,
                    max_completion_tokens=self.max_tokens,
                    temperature=self.temperature,
                    stream=True,
                    stream_options={"include_usage": True},
                    extra_body={
                        "chat_template_kwargs": self._chat_template_kwargs,
                    },
                    **tool_kwargs,
                )
                assistant_msg = self._stream_response(stream, spinner)
            finally:
                spinner.stop()

            # Recalibrate before appending so the table covers the prompt only.
            self._update_token_table(assistant_msg)
            self.messages.append(assistant_msg)
            self._msg_tokens.append(
                self._assistant_pending_tokens
                or estimate(self._msg_char_count(assistant_msg))
            )

            # Log assistant message to conversation history
            reply_text = assistant_msg.get("content", "")
            tool_calls = assistant_msg.get("tool_calls")
            if reply_text:
                _save_message(self._session_id, "assistant", reply_text)
            for call in tool_calls or ():
                fn = call.get("function", {})
                tool_name = fn.get("name", "")
                if tool_name in memory_tools:
                    continue
                _save_message(
                    self._session_id,
                    "tool_call",
                    None,
                    tool_name,
                    fn.get("arguments", ""),
                )

            if not tool_calls:
                self._print_status_line()
                # Auto-compact when prompt exceeds 80% of context window
                usage = self._last_usage
                if usage and usage["prompt_tokens"] > self.context_window * 0.8:
                    print(
                        dim("\n[Auto-compacting: prompt exceeds 80% of context window]")
                    )
                    self._compact_messages(auto=True)
                break

            # Execute tool calls (potentially in parallel)
            results, user_feedback = self._execute_tools(tool_calls)
            # Map tool_call_id → tool name for logging
            names_by_id = {
                c["id"]: c.get("function", {}).get("name", "") for c in tool_calls
            }
            for call_id, output in results:
                self.messages.append(
                    {"role": "tool", "tool_call_id": call_id, "content": output}
                )
                self._msg_tokens.append(estimate(len(output)))
                # Log tool result (skip memory tools to avoid noise)
                source_tool = names_by_id.get(call_id, "")
                if source_tool not in memory_tools:
                    _save_message(
                        self._session_id,
                        "tool_result",
                        output[:2000],
                        source_tool,
                    )
            # Inject user feedback from approval prompt (e.g. "y, use full path")
            if user_feedback:
                self.messages.append({"role": "user", "content": user_feedback})
                self._msg_tokens.append(estimate(len(user_feedback)))
    except KeyboardInterrupt:
        # Remove any partial tool results, then the originating assistant
        # message with unanswered tool_calls — keep _msg_tokens in sync
        while self.messages and self.messages[-1]["role"] == "tool":
            drop_last()
        while (
            self.messages
            and self.messages[-1]["role"] == "assistant"
            and self.messages[-1].get("tool_calls")
        ):
            drop_last()
        raise
|
|
|
|
@staticmethod
|
|
def _strip_reasoning(text: str) -> str:
|
|
"""Remove <think>/<reasoning> tags and their content."""
|
|
for open_t, close_t in [
|
|
("<think>", "</think>"),
|
|
("<reasoning>", "</reasoning>"),
|
|
]:
|
|
while open_t in text:
|
|
start = text.find(open_t)
|
|
end = text.find(close_t, start)
|
|
if end != -1:
|
|
text = text[:start] + text[end + len(close_t) :]
|
|
else:
|
|
text = text[:start]
|
|
return text.strip()
|
|
|
|
# Tags that delimit reasoning blocks in content stream.
|
|
# Checked in order; first match wins.
|
|
_THINK_OPEN_TAGS = ("<think>", "<reasoning>")
|
|
_THINK_CLOSE_TAGS = ("</think>", "</reasoning>")
|
|
_MAX_TAG_LEN = max(len(t) for t in _THINK_OPEN_TAGS + _THINK_CLOSE_TAGS)
|
|
|
|
def _stream_response(self, stream, spinner: Spinner | None = None) -> dict:
    """Stream response, printing tokens as they arrive.

    Handles two reasoning delivery mechanisms:
    1. vLLM's `reasoning_content` field (when --reasoning-parser is set)
    2. <think>...</think> tags in regular content (common default)

    Stops the spinner on the first delta that carries reasoning, content
    or tool calls.

    Side effects: writes to stdout, feeds ``self.md``, and stores the usage
    figures from the stream in ``self._last_usage`` when they are present.

    Returns the complete assistant message as a dict suitable for
    appending to self.messages: ``role``, ``content`` (None when no content
    was produced) and, if any were streamed, ``tool_calls`` ordered by
    index. Reasoning text is displayed but is not part of the returned dict.
    """
    content_parts = []
    reasoning_parts = []  # collected, but not included in the returned message
    tool_calls_acc: dict[int, dict] = {}  # tool-call index → accumulated call
    first_token = True
    in_think = False  # inside a <think>...</think> block
    path1_reasoning = False  # last reasoning came via reasoning_content field
    pending = ""  # buffer for partial tag detection
    dim_active = False  # tracks whether DIM ANSI code is active on terminal

    def _set_dim(active: bool):
        """Ensure terminal DIM state matches desired state."""
        nonlocal dim_active
        if active and not dim_active:
            # With reasoning hidden nothing is written, only the flag changes.
            if self.show_reasoning:
                sys.stdout.write(f"{DIM}")
                sys.stdout.flush()
            dim_active = True
        elif not active and dim_active:
            sys.stdout.write(f"{RESET}")
            sys.stdout.flush()
            dim_active = False

    def _flush_text(text: str, is_reasoning: bool):
        """Print text with appropriate styling."""
        if not text:
            return
        if is_reasoning:
            reasoning_parts.append(text)
            if self.show_reasoning:
                _set_dim(True)
                sys.stdout.write(text)
                sys.stdout.flush()
        else:
            _set_dim(False)
            content_parts.append(text)
            # The markdown renderer may hold text back until a line completes.
            rendered = self.md.feed(text)
            if rendered:
                sys.stdout.write(rendered)
                sys.stdout.flush()

    def _drain_pending():
        """Process the pending buffer, flushing content and detecting tags."""
        nonlocal pending, in_think

        while pending:
            if in_think:
                # Look for any close tag
                best_idx, best_tag = None, None
                for tag in self._THINK_CLOSE_TAGS:
                    idx = pending.find(tag)
                    if idx != -1 and (best_idx is None or idx < best_idx):
                        best_idx, best_tag = idx, tag

                if best_idx is not None:
                    _flush_text(pending[:best_idx], True)
                    pending = pending[best_idx + len(best_tag) :]
                    in_think = False
                    _set_dim(False)
                    if self.show_reasoning:
                        sys.stdout.write("\n")
                        sys.stdout.flush()
                    continue

                # No close tag found — check if tail could be a partial tag
                safe = len(pending) - self._MAX_TAG_LEN
                if safe > 0:
                    _flush_text(pending[:safe], True)
                    pending = pending[safe:]
                # Keep the held-back tail until more text arrives.
                break
            else:
                # Look for any open tag
                best_idx, best_tag = None, None
                for tag in self._THINK_OPEN_TAGS:
                    idx = pending.find(tag)
                    if idx != -1 and (best_idx is None or idx < best_idx):
                        best_idx, best_tag = idx, tag

                if best_idx is not None:
                    _flush_text(pending[:best_idx], False)
                    pending = pending[best_idx + len(best_tag) :]
                    in_think = True
                    _set_dim(True)
                    continue

                # No open tag found — flush all but potential partial tag
                safe = len(pending) - self._MAX_TAG_LEN
                if safe > 0:
                    _flush_text(pending[:safe], False)
                    pending = pending[safe:]
                break

    def _stop_spinner_once():
        """Stop the spinner on first real content. Call is idempotent."""
        nonlocal first_token
        if first_token and spinner:
            spinner.stop()
            first_token = False

    for chunk in stream:
        # Capture usage from final chunk (stream_options.include_usage)
        if hasattr(chunk, "usage") and chunk.usage is not None:
            u = chunk.usage
            pt = getattr(u, "prompt_tokens", None)
            ct = getattr(u, "completion_tokens", None)
            tt = getattr(u, "total_tokens", None)
            if pt is not None and ct is not None:
                self._last_usage = {
                    "prompt_tokens": pt,
                    "completion_tokens": ct,
                    "total_tokens": tt or (pt + ct),
                }
        # Chunks without choices carry nothing to render.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta

        if self.debug:
            # Dump every field of the raw delta, including non-standard extras.
            extras = dict(delta.model_extra) if delta.model_extra else {}
            parts = []
            if delta.role:
                parts.append(f"role={delta.role}")
            if delta.content:
                parts.append(f"content={delta.content!r}")
            if delta.tool_calls:
                parts.append(f"tool_calls=...")
            for k, v in extras.items():
                if v is not None:
                    parts.append(f"{k}={v!r}")
            if parts:
                sys.stdout.write(f"{GRAY}[delta: {', '.join(parts)}]{RESET}\n")
                sys.stdout.flush()

        # Path 1: reasoning field (vLLM sends as "reasoning" or "reasoning_content")
        rc = getattr(delta, "reasoning", None) or getattr(
            delta, "reasoning_content", None
        )
        if rc:
            _stop_spinner_once()
            reasoning_parts.append(rc)
            in_think = True
            path1_reasoning = True
            if self.show_reasoning:
                _set_dim(True)
                sys.stdout.write(rc)
                sys.stdout.flush()
            # The rest of a delta that carries reasoning is not processed.
            continue

        # Path 2: regular content (may contain <think> tags)
        if delta.content:
            _stop_spinner_once()
            # Close reasoning dim if transitioning from Path 1 reasoning
            if path1_reasoning:
                path1_reasoning = False
                in_think = False
                _set_dim(False)
                if self.show_reasoning:
                    sys.stdout.write("\n")
                    sys.stdout.flush()
            pending += delta.content
            _drain_pending()

        # Handle tool call deltas
        if delta.tool_calls:
            _stop_spinner_once()
            # Close reasoning dim if transitioning from reasoning
            if in_think:
                in_think = False
                _set_dim(False)
                if self.show_reasoning:
                    sys.stdout.write("\n")
                    sys.stdout.flush()
            for tc_delta in delta.tool_calls:
                idx = tc_delta.index
                if idx not in tool_calls_acc:
                    tool_calls_acc[idx] = {
                        "id": "",
                        "type": "function",
                        "function": {"name": "", "arguments": ""},
                    }
                tc = tool_calls_acc[idx]
                if tc_delta.id:
                    tc["id"] = tc_delta.id
                if tc_delta.function:
                    # The name is replaced; argument fragments are concatenated.
                    if tc_delta.function.name:
                        tc["function"]["name"] = tc_delta.function.name
                    if tc_delta.function.arguments:
                        tc["function"]["arguments"] += tc_delta.function.arguments

    # Flush any remaining buffered text
    if pending:
        _flush_text(pending, in_think)
        pending = ""

    # Flush markdown renderer's remaining partial line and reset state
    remainder = self.md.flush()
    if remainder:
        sys.stdout.write(remainder)
        sys.stdout.flush()
    self.md.in_code_block = False

    # Close reasoning styling if still open
    _set_dim(False)
    if in_think and self.show_reasoning:
        sys.stdout.write("\n")
        sys.stdout.flush()

    # Ensure newline after content
    if content_parts:
        sys.stdout.write("\n")
        sys.stdout.flush()

    # Build assistant message dict
    msg: dict = {"role": "assistant"}

    content = "".join(content_parts)
    if content:
        msg["content"] = content
    else:
        msg["content"] = None

    if tool_calls_acc:
        msg["tool_calls"] = [tool_calls_acc[i] for i in sorted(tool_calls_acc)]

    return msg
|
|
|
|
# Shared lock, reached via self._print_lock; _display_and_approve holds it
# while writing tool headers and previews.
_print_lock = threading.Lock()
|
|
|
|
# ── Debug ──────────────────────────────────────────────────────────
|
|
|
|
def _debug_print_request(self, msgs: list[dict]):
    """Dump the outgoing request to stdout in gray (used in debug mode).

    Prints the request settings, then one line per message with long
    content shortened and newlines escaped, followed by one indented line
    per tool call the message carries.
    """
    emit = sys.stdout.write
    rule = f"{GRAY}{'─' * 60}{RESET}"
    tool_count = 0 if self.creative_mode else len(TOOLS)

    emit(f"\n{rule}\n")
    emit(
        f"{GRAY}[request] model={self.model} "
        f"max_tokens={self.max_tokens} temp={self.temperature} "
        f"reasoning={self.reasoning_effort} "
        f"tools={tool_count}{RESET}\n"
    )
    emit(f"{GRAY}[request] {len(msgs)} messages:{RESET}\n")

    for position, message in enumerate(msgs):
        role = message["role"]
        body = message.get("content") or ""
        calls = message.get("tool_calls")
        answered_id = message.get("tool_call_id")

        # Keep head and tail of long bodies so the dump stays readable.
        if len(body) > 300:
            body = body[:200] + f"...({len(body)} chars)..." + body[-50:]
        # Escape newlines so each message stays on one line.
        shown = body.replace("\n", "\\n")

        label = f" [{position}] {role}"
        if answered_id:
            label += f" (tool_call_id={answered_id})"
        emit(f"{GRAY}{label}: {shown}{RESET}\n")

        for call in calls or ():
            fn_name = call.get("function", {}).get("name", "?")
            fn_args = call.get("function", {}).get("arguments", "")
            if len(fn_args) > 200:
                fn_args = fn_args[:150] + f"...({len(fn_args)} chars)"
            emit(f"{GRAY} → {fn_name}({fn_args}){RESET}\n")

    emit(f"{rule}\n\n")
    sys.stdout.flush()
|
|
|
|
# ── Token tracking & status ────────────────────────────────────────
|
|
|
|
def _msg_char_count(self, msg: dict) -> int:
|
|
"""Count characters in a message, including tool call arguments."""
|
|
n = len(msg.get("content") or "")
|
|
for tc in msg.get("tool_calls", []):
|
|
n += len(tc.get("function", {}).get("name", ""))
|
|
n += len(tc.get("function", {}).get("arguments", ""))
|
|
return n
|
|
|
|
def _update_token_table(self, assistant_msg: dict):
|
|
"""Update per-message token estimates using API usage data."""
|
|
if not self._last_usage:
|
|
return
|
|
|
|
prompt_tok = self._last_usage["prompt_tokens"]
|
|
compl_tok = self._last_usage["completion_tokens"]
|
|
|
|
# Calibrate chars_per_token ratio from actual usage.
|
|
# Include tool definition chars (prompt_tok includes tool schema tokens).
|
|
all_msgs = self._full_messages() # system + self.messages (before append)
|
|
tool_def_chars = sum(len(json.dumps(t)) for t in TOOLS)
|
|
total_chars = sum(self._msg_char_count(m) for m in all_msgs) + tool_def_chars
|
|
if total_chars > 0 and prompt_tok > 0:
|
|
self._chars_per_token = total_chars / prompt_tok
|
|
|
|
# Compute system_tokens (stable after first call)
|
|
sys_chars = sum(self._msg_char_count(m) for m in self.system_messages)
|
|
self._system_tokens = max(1, int(sys_chars / self._chars_per_token))
|
|
|
|
# Re-estimate all message token counts with calibrated ratio
|
|
self._msg_tokens = [
|
|
max(1, int(self._msg_char_count(m) / self._chars_per_token))
|
|
for m in self.messages
|
|
]
|
|
|
|
# Stash completion_tokens for the assistant message about to be appended
|
|
self._assistant_pending_tokens = compl_tok
|
|
|
|
def _print_status_line(self):
|
|
"""Print a dim inline status line with token usage and settings."""
|
|
if not self._last_usage:
|
|
return
|
|
total_tok = (
|
|
self._last_usage["prompt_tokens"] + self._last_usage["completion_tokens"]
|
|
)
|
|
ctx = self.context_window
|
|
pct = total_tok / ctx * 100 if ctx > 0 else 0
|
|
parts = [f"{total_tok:,} / {ctx:,} tokens ({pct:.0f}%)"]
|
|
if self.reasoning_effort != "medium":
|
|
parts.append(f"reasoning: {self.reasoning_effort}")
|
|
sys.stdout.write(f"\n {DIM}[{' · '.join(parts)}]{RESET}\n")
|
|
sys.stdout.flush()
|
|
|
|
# ── Conversation compaction ──────────────────────────────────────────
|
|
|
|
def _format_messages_for_summary(self, messages: list[dict]) -> str:
|
|
"""Format messages into a readable string for the summarization prompt."""
|
|
# Build tool_call_id → tool_name lookup for labeling tool results
|
|
tc_names: dict[str, str] = {}
|
|
for m in messages:
|
|
for tc in m.get("tool_calls", []):
|
|
tc_id = tc.get("id", "")
|
|
tc_name = tc.get("function", {}).get("name", "unknown")
|
|
if tc_id:
|
|
tc_names[tc_id] = tc_name
|
|
|
|
parts = []
|
|
for m in messages:
|
|
role = m["role"].upper()
|
|
content = m.get("content") or ""
|
|
|
|
if m.get("tool_calls"):
|
|
calls = []
|
|
for tc in m["tool_calls"]:
|
|
name = tc.get("function", {}).get("name", "?")
|
|
args = tc.get("function", {}).get("arguments", "")
|
|
calls.append(f"{name}({args})")
|
|
content += "\n[Called: " + ", ".join(calls) + "]"
|
|
|
|
# Label tool results with the tool name
|
|
if role == "TOOL":
|
|
tc_id = m.get("tool_call_id", "")
|
|
name = tc_names.get(tc_id, "tool")
|
|
role = f"TOOL[{name}]"
|
|
|
|
if content:
|
|
if len(content) > 2000:
|
|
content = content[:1000] + "\n...[truncated]...\n" + content[-500:]
|
|
parts.append(f"{role}: {content}")
|
|
return "\n\n".join(parts)
|
|
|
|
def _compact_messages(self, auto: bool = False):
    """Compact conversation history by summarizing all messages.

    Summarizes the conversation via a separate model call, budget-fitted to
    80% of the context window, then replaces ``self.messages`` with a
    two-message stub (a user marker and the assistant-role summary).

    When auto=True (triggered by context limit), appends a continuation
    hint with the last user message so the model can resume seamlessly.

    The history is left untouched when there are fewer than two messages,
    when nothing fits the budget, when the model call fails, or when the
    model returns an empty summary.
    """
    if len(self.messages) < 2:
        print(dim("Not enough messages to compact."))
        return

    # Find the last user message for the continuation hint
    last_user_content = None
    if auto:
        for m in reversed(self.messages):
            if m["role"] == "user":
                last_user_content = m.get("content") or ""
                break

    to_summarize = self.messages

    # Budget: fit as many messages as possible into summary request
    summary_max_tokens = 4096
    prompt_budget = (
        int(self.context_window * 0.8) - summary_max_tokens - self._system_tokens
    )
    selected = []
    running = 0
    for i, msg in enumerate(to_summarize):
        msg_tok = (
            self._msg_tokens[i]
            if i < len(self._msg_tokens)
            else max(1, int(self._msg_char_count(msg) / self._chars_per_token))
        )
        if running + msg_tok > prompt_budget:
            # NOTE(review): messages from here on are not summarized, yet the
            # whole history is replaced below, so they are lost. Confirm that
            # this is intended.
            break
        selected.append(msg)
        running += msg_tok

    if not selected:
        print(dim("Messages too large to fit in summary context."))
        return

    # Build summary prompt and call model
    formatted = self._format_messages_for_summary(selected)
    summary_msgs = [
        {
            "role": "developer",
            "content": (
                "# Conversation Compactor\n\n"
                "Your output REPLACES the conversation history — the assistant "
                "will continue from your summary with no access to the original messages.\n\n"
                "1. **Output format** — use these exact sections, omit any that are empty:\n"
                " - **## Decisions**: Choices made (architecture, libraries, approaches).\n"
                " - **## Files**: Files read, created, or modified, with brief notes.\n"
                " - **## Key code**: Exact function names, class names, variable names, "
                "and short code snippets the assistant will need. "
                "Preserve identifiers verbatim — do NOT paraphrase.\n"
                " - **## Tool results**: Important tool outputs (errors, search matches, "
                "file contents) that inform ongoing work.\n"
                " - **## Open tasks**: What the user asked for that is not yet done, "
                "with enough context to continue.\n"
                " - **## User preferences**: Workflow preferences, constraints, or "
                "instructions the user stated.\n\n"
                "2. **Density rules:**\n"
                " - Every token should carry information.\n"
                " - Preserve exact paths, identifiers, and numbers — never paraphrase these.\n"
                " - Omit pleasantries, acknowledgments, and reasoning that led to dead ends.\n"
                " - If a tool call's result was an error that was later resolved, "
                "keep only the resolution.\n\n"
                "3. **Common mistakes to avoid:**\n"
                " - Paraphrasing file paths, function names, or variable names\n"
                " - Including dead-end explorations or superseded decisions\n"
                " - Omitting the open tasks section when work remains\n"
                " - Being verbose — this is a summary, not a transcript"
            ),
        },
        {
            "role": "user",
            "content": ("Compact the following conversation:\n\n" + formatted),
        },
    ]

    spinner = Spinner("Compacting")
    spinner.start()
    try:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=summary_msgs,
            max_completion_tokens=summary_max_tokens,
            temperature=0.3,
            stream=False,
            extra_body={
                "chat_template_kwargs": {
                    **self._chat_template_kwargs_base,
                    "reasoning_effort": "low",
                }
            },
        )
        summary = response.choices[0].message.content or ""
        # Strip any <think>/<reasoning> tags the summarizer may emit
        summary = self._strip_reasoning(summary)
    except Exception as e:
        # Clear the spinner line first so the error is not drawn over it.
        spinner.stop()
        print(red(f"Compaction failed: {e}"))
        return
    finally:
        spinner.stop()

    # An empty summary would wipe the whole conversation; keep it instead.
    if not summary:
        print(red("Compaction failed: model returned an empty summary"))
        return

    # Append continuation hint for auto-compact
    if last_user_content:
        # Truncate very long user messages
        if len(last_user_content) > 500:
            last_user_content = last_user_content[:400] + "..."
        summary += (
            "\n\n## Continue\n"
            f"The user's last message was: {last_user_content}\n"
            "Continue assisting from where we left off."
        )

    # Replace messages
    before_tokens = self._system_tokens + sum(self._msg_tokens)
    summary_user = {"role": "user", "content": "[Conversation summary]"}
    summary_asst = {"role": "assistant", "content": summary}
    self.messages = [summary_user, summary_asst]
    # File contents are gone after compaction — force re-read before edit_file
    self._read_files.clear()

    # Rebuild token table
    su_tok = max(1, int(self._msg_char_count(summary_user) / self._chars_per_token))
    sa_tok = max(1, int(self._msg_char_count(summary_asst) / self._chars_per_token))
    self._msg_tokens = [su_tok, sa_tok]
    after_tokens = self._system_tokens + sum(self._msg_tokens)

    print(dim(f"[compacted: ~{before_tokens:,} \u2192 ~{after_tokens:,} tokens]"))
    print(dim("─" * 60))
    for line in summary.splitlines():
        print(dim(f" {line}"))
    print(dim("─" * 60))
|
|
|
|
# ── Three-phase tool execution ──────────────────────────────────────
|
|
#
|
|
# Phase 1 — prepare: parse args, validate, build preview text (serial)
|
|
# Phase 2 — approve: display all previews, single prompt (serial)
|
|
# Phase 3 — execute: run approved tools (parallel if multiple)
|
|
|
|
def _execute_tools(
|
|
self, tool_calls: list[dict]
|
|
) -> tuple[list[tuple[str, str]], str | None]:
|
|
"""Execute tool calls with batch preview and approval.
|
|
|
|
Returns (results, user_feedback) where user_feedback is an optional
|
|
message the user typed alongside their approval (e.g. "y, use full path").
|
|
"""
|
|
# Phase 1: prepare all tool calls
|
|
items = [self._prepare_tool(tc) for tc in tool_calls]
|
|
|
|
# Phase 2: display previews and prompt
|
|
user_feedback = self._display_and_approve(items)
|
|
|
|
# Phase 3: execute
|
|
def run_one(item: dict) -> tuple[str, str]:
|
|
if item.get("error"):
|
|
return item["call_id"], item["error"]
|
|
if item.get("denied"):
|
|
return item["call_id"], item.get("denial_msg", "Denied by user")
|
|
return item["execute"](item)
|
|
|
|
if len(items) == 1:
|
|
results = [run_one(items[0])]
|
|
else:
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
|
|
results = list(pool.map(run_one, items))
|
|
|
|
# Post-plan gate: prompt user on main thread after plan completes
|
|
for i, item in enumerate(items):
|
|
if (
|
|
item.get("func_name") == "plan"
|
|
and not item.get("error")
|
|
and not item.get("denied")
|
|
and not self.auto_approve
|
|
):
|
|
cid, output = results[i]
|
|
# Show the plan content so user can review before deciding
|
|
sys.stdout.write(f"\n{DIM}{'─' * 60}{RESET}\n")
|
|
for line in output.splitlines():
|
|
sys.stdout.write(f" {line}\n")
|
|
sys.stdout.write(f"{DIM}{'─' * 60}{RESET}\n")
|
|
sys.stdout.flush()
|
|
try:
|
|
prompt_text = (
|
|
f" \001{BOLD}\002Plan ready.\001{RESET}\002 "
|
|
f"\001{DIM}\002[enter to approve, or give feedback]"
|
|
f"\001{RESET}\002 "
|
|
)
|
|
resp = input(prompt_text).strip()
|
|
except EOFError:
|
|
resp = ""
|
|
except KeyboardInterrupt:
|
|
resp = "reject"
|
|
if resp.lower() in ("n", "no", "reject"):
|
|
output += (
|
|
"\n\n---\nUser REJECTED this plan. Do not proceed "
|
|
"with implementation. Ask the user what they want instead."
|
|
)
|
|
elif resp:
|
|
output += f"\n\n---\nUser feedback on this plan: {resp}"
|
|
results[i] = (cid, output)
|
|
|
|
return results, user_feedback
|
|
|
|
def _prepare_tool(self, tc: dict) -> dict:
|
|
"""Parse a tool call and prepare preview info for display."""
|
|
call_id = tc["id"]
|
|
func_name = tc["function"]["name"]
|
|
raw_args = tc["function"]["arguments"]
|
|
|
|
# Map tool name → primary argument key for bare-string fallback
|
|
_primary_key = {
|
|
"bash": "command",
|
|
"math": "code",
|
|
"read_file": "path",
|
|
"search": "query",
|
|
"write_file": "content",
|
|
"edit_file": "old_string",
|
|
"man": "page",
|
|
"web_fetch": "url",
|
|
"web_search": "query",
|
|
"task": "prompt",
|
|
"plan": "prompt",
|
|
"remember": "key",
|
|
"recall": "query",
|
|
"forget": "key",
|
|
}
|
|
|
|
try:
|
|
args = json.loads(raw_args)
|
|
except json.JSONDecodeError as exc:
|
|
args = None
|
|
# Fallback 1: regex-extract a known key from malformed JSON
|
|
for key in (
|
|
"command",
|
|
"code",
|
|
"content",
|
|
"page",
|
|
"path",
|
|
"pattern",
|
|
"prompt",
|
|
"query",
|
|
"url",
|
|
):
|
|
m = re.search(rf'"{key}"\s*:\s*"((?:[^"\\]|\\.)*)"', raw_args)
|
|
if m:
|
|
try:
|
|
val = json.loads('"' + m.group(1) + '"')
|
|
except (json.JSONDecodeError, Exception):
|
|
val = m.group(1)
|
|
args = {key: val}
|
|
break
|
|
# Fallback 2: bare string (no JSON wrapper at all)
|
|
if (
|
|
args is None
|
|
and raw_args.strip()
|
|
and not raw_args.strip().startswith("{")
|
|
):
|
|
pk = _primary_key.get(func_name)
|
|
if pk:
|
|
args = {pk: raw_args}
|
|
if args is None:
|
|
preview = raw_args[:300] + ("..." if len(raw_args) > 300 else "")
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": func_name,
|
|
"header": f"✗ {func_name}: {exc}",
|
|
"preview": f" {RED}{preview}{RESET}",
|
|
"needs_approval": False,
|
|
"error": f"JSON parse error: {exc}\nRaw arguments: {raw_args[:500]}",
|
|
}
|
|
|
|
preparers = {
|
|
"bash": self._prepare_bash,
|
|
"read_file": self._prepare_read_file,
|
|
"search": self._prepare_search,
|
|
"write_file": self._prepare_write_file,
|
|
"edit_file": self._prepare_edit_file,
|
|
"math": self._prepare_math,
|
|
"man": self._prepare_man,
|
|
"web_fetch": self._prepare_web_fetch,
|
|
"web_search": self._prepare_web_search,
|
|
"task": self._prepare_task,
|
|
"plan": self._prepare_plan,
|
|
"remember": self._prepare_remember,
|
|
"recall": self._prepare_recall,
|
|
"forget": self._prepare_forget,
|
|
}
|
|
preparer = preparers.get(func_name)
|
|
if not preparer:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": func_name,
|
|
"header": f"✗ Unknown tool: {func_name}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Unknown tool: {func_name}",
|
|
}
|
|
return preparer(call_id, args)
|
|
|
|
def _display_and_approve(self, items: list[dict]) -> str | None:
|
|
"""Display all tool previews and prompt for batch approval.
|
|
|
|
Returns an optional user feedback message (text after y/n),
|
|
or None if no feedback was given.
|
|
"""
|
|
pending = [
|
|
it for it in items if it.get("needs_approval") and not it.get("error")
|
|
]
|
|
|
|
with self._print_lock:
|
|
# Print all headers and previews
|
|
for item in items:
|
|
if item.get("error"):
|
|
sys.stdout.write(f" {red(item['header'])}\n")
|
|
else:
|
|
sys.stdout.write(f" {yellow(item['header'])}\n")
|
|
if item.get("preview"):
|
|
sys.stdout.write(item["preview"] + "\n")
|
|
sys.stdout.flush()
|
|
|
|
if not pending or self.auto_approve:
|
|
return None
|
|
|
|
# Prompt
|
|
try:
|
|
if len(pending) == 1:
|
|
label = pending[0].get("approval_label", pending[0]["func_name"])
|
|
prompt_text = (
|
|
f" \001{BOLD}\002Allow {label}?\001{RESET}\002 "
|
|
f"\001{DIM}\002[y/n/a(lways), optional message]\001{RESET}\002 "
|
|
)
|
|
else:
|
|
labels = ", ".join(
|
|
it.get("approval_label", it["func_name"]) for it in pending
|
|
)
|
|
prompt_text = (
|
|
f" \001{BOLD}\002Allow {len(pending)} tools ({labels})?\001{RESET}\002 "
|
|
f"\001{DIM}\002[y/n/a(lways), optional message]\001{RESET}\002 "
|
|
)
|
|
resp = input(prompt_text).strip()
|
|
except (EOFError, KeyboardInterrupt):
|
|
sys.stdout.write("\n")
|
|
resp = "n"
|
|
|
|
# Parse decision and optional feedback: "y, use absolute path"
|
|
decision = resp.lower()
|
|
feedback = None
|
|
for sep in (",", " "):
|
|
if sep in resp:
|
|
decision = resp[: resp.index(sep)].strip().lower()
|
|
feedback = resp[resp.index(sep) + 1 :].strip() or None
|
|
break
|
|
|
|
if decision in ("a", "always"):
|
|
self.auto_approve = True
|
|
elif decision not in ("y", "yes"):
|
|
denial_msg = "Denied by user"
|
|
if feedback:
|
|
denial_msg += f": {feedback}"
|
|
for item in pending:
|
|
item["denied"] = True
|
|
item["denial_msg"] = denial_msg
|
|
return None # feedback already in denial_msg
|
|
|
|
return feedback
|
|
|
|
# ── Prepare methods (build preview, validate, no side effects) ────
|
|
|
|
def _prepare_bash(self, call_id: str, args: dict) -> dict:
    """Validate a bash tool call and build its approval record.

    The raw ``command`` argument is passed through ``_sanitize_command``;
    an empty result, or a non-empty verdict from ``is_command_blocked``,
    yields an error record that needs no approval. Otherwise the record
    asks for approval and carries the full command for ``_exec_bash``.

    Args:
        call_id: Identifier of the tool call, echoed back in the record.
        args: Parsed tool arguments; only ``command`` is read.

    Returns:
        A prepared-tool dict (error record or runnable record).
    """

    def rejected(header: str, error: str) -> dict:
        # Shared shape of both failure records.
        return {
            "call_id": call_id,
            "func_name": "bash",
            "header": header,
            "preview": "",
            "needs_approval": False,
            "error": error,
        }

    command = _sanitize_command(args.get("command", ""))
    if not command:
        return rejected("✗ bash: empty command", "Error: empty command")

    block_reason = is_command_blocked(command)
    if block_reason:
        return rejected(f"✗ {block_reason}", block_reason)

    # The header shows only the first line plus a count of the newlines.
    first_line, newline, _rest = command.partition("\n")
    summary = first_line
    if newline:
        summary += f" ... ({command.count(chr(10))} more lines)"

    return {
        "call_id": call_id,
        "func_name": "bash",
        "header": f"⚙ bash: {summary}",
        "preview": "",
        "needs_approval": True,
        "approval_label": "bash",
        "execute": self._exec_bash,
        "command": command,
    }
|
|
|
|
def _prepare_read_file(self, call_id: str, args: dict) -> dict:
|
|
path = args.get("path", "")
|
|
if not path:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "read_file",
|
|
"header": "✗ read_file: missing path",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: missing path",
|
|
}
|
|
path = os.path.expanduser(path)
|
|
resolved = os.path.realpath(path)
|
|
offset = args.get("offset") # 1-based line number, or None
|
|
limit = args.get("limit") # max lines, or None
|
|
# Coerce to int safely (model may send strings or floats)
|
|
try:
|
|
if offset is not None:
|
|
offset = int(offset)
|
|
if limit is not None:
|
|
limit = int(limit)
|
|
except (ValueError, TypeError):
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "read_file",
|
|
"header": "✗ read_file: invalid offset/limit",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": (
|
|
f"Error: offset/limit must be integers "
|
|
f"(got offset={args.get('offset')!r}, "
|
|
f"limit={args.get('limit')!r})"
|
|
),
|
|
}
|
|
if offset is not None and offset < 1:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "read_file",
|
|
"header": "✗ read_file: offset must be >= 1",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Error: offset must be >= 1 (got {offset})",
|
|
}
|
|
if limit is not None and limit < 1:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "read_file",
|
|
"header": "✗ read_file: limit must be >= 1",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Error: limit must be >= 1 (got {limit})",
|
|
}
|
|
# Register early so a same-batch edit_file can pass the read guard.
|
|
# If the file doesn't exist, _exec_read_file returns an error and
|
|
# _exec_edit_file's re-read will also fail naturally.
|
|
self._read_files.add(resolved)
|
|
# Build header showing range if specified
|
|
header = f"⚙ read_file: {path}"
|
|
if offset is not None or limit is not None:
|
|
start = offset or 1
|
|
if limit is not None:
|
|
header += f" (lines {start}-{start + limit - 1})"
|
|
else:
|
|
header += f" (from line {start})"
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "read_file",
|
|
"header": header,
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"execute": self._exec_read_file,
|
|
"path": path,
|
|
"offset": offset,
|
|
"limit": limit,
|
|
}
|
|
|
|
def _prepare_search(self, call_id: str, args: dict) -> dict:
|
|
pattern = args.get("query", "")
|
|
if not pattern:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "search",
|
|
"header": "✗ search: missing query",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: missing query",
|
|
}
|
|
path = os.path.expanduser(args.get("path", "") or ".")
|
|
preview = f" {DIM}/{pattern}/ in {path}{RESET}"
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "search",
|
|
"header": f"⚙ search: /{pattern}/ in {path}",
|
|
"preview": preview,
|
|
"needs_approval": False,
|
|
"execute": self._exec_search,
|
|
"pattern": pattern,
|
|
"path": path,
|
|
}
|
|
|
|
def _prepare_write_file(self, call_id: str, args: dict) -> dict:
|
|
path = args.get("path", "")
|
|
content = args.get("content", "")
|
|
if not path:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "write_file",
|
|
"header": "✗ write_file: missing path",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: missing path",
|
|
}
|
|
path = os.path.expanduser(path)
|
|
resolved = os.path.realpath(path)
|
|
exists = os.path.exists(resolved)
|
|
is_overwrite = exists and resolved not in self._read_files
|
|
|
|
# Build preview
|
|
preview_parts = []
|
|
if is_overwrite:
|
|
preview_parts.append(
|
|
f" {YELLOW}Warning: overwriting existing file not previously read{RESET}"
|
|
)
|
|
text = content[:500]
|
|
if len(content) > 500:
|
|
text += f"\n... ({len(content)} chars total)"
|
|
preview_parts.append(f"{DIM}{textwrap.indent(text, ' ')}{RESET}")
|
|
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "write_file",
|
|
"header": f"⚙ write_file: {path} ({len(content)} chars)",
|
|
"preview": "\n".join(preview_parts),
|
|
"needs_approval": True,
|
|
"approval_label": "overwrite_file" if is_overwrite else "write_file",
|
|
"execute": self._exec_write_file,
|
|
"path": path,
|
|
"resolved": resolved,
|
|
"content": content,
|
|
}
|
|
|
|
def _prepare_edit_file(self, call_id: str, args: dict) -> dict:
|
|
path = args.get("path", "")
|
|
old_string = args.get("old_string", "")
|
|
new_string = args.get("new_string", "")
|
|
near_line = args.get("near_line")
|
|
if isinstance(near_line, str):
|
|
try:
|
|
near_line = int(near_line)
|
|
except ValueError:
|
|
near_line = None
|
|
if not path:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": "✗ edit_file: missing path",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: missing path",
|
|
}
|
|
if not old_string:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": "✗ edit_file: missing old_string",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: missing old_string",
|
|
}
|
|
path = os.path.expanduser(path)
|
|
resolved = os.path.realpath(path)
|
|
|
|
if resolved not in self._read_files:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": f"✗ edit_file: {path}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Error: must read_file {path} before editing it",
|
|
}
|
|
|
|
# Pre-read to validate and build diff preview
|
|
try:
|
|
with open(path, "r") as f:
|
|
content = f.read()
|
|
occurrences = _find_occurrences(content, old_string)
|
|
if len(occurrences) == 0:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": f"✗ edit_file: {path}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Error: old_string not found in {path}",
|
|
}
|
|
if len(occurrences) > 1 and near_line is None:
|
|
line_list = ", ".join(str(ln) for ln in occurrences)
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": f"✗ edit_file: {path}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": (
|
|
f"Error: old_string found {len(occurrences)} times "
|
|
f"at lines {line_list} — use near_line to pick one"
|
|
),
|
|
}
|
|
except FileNotFoundError:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": f"✗ edit_file: {path}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Error: {path} not found",
|
|
}
|
|
except Exception as e:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": f"✗ edit_file: {path}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": f"Error editing {path}: {e}",
|
|
}
|
|
|
|
# Build diff preview
|
|
preview_parts = []
|
|
old_preview = old_string[:200] + ("..." if len(old_string) > 200 else "")
|
|
new_preview = new_string[:200] + ("..." if len(new_string) > 200 else "")
|
|
for line in old_preview.splitlines():
|
|
preview_parts.append(f" {RED}- {line}{RESET}")
|
|
if new_string:
|
|
for line in new_preview.splitlines():
|
|
preview_parts.append(f" {GREEN}+ {line}{RESET}")
|
|
else:
|
|
preview_parts.append(
|
|
f" {YELLOW}(deletion — {len(old_string)} chars removed){RESET}"
|
|
)
|
|
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "edit_file",
|
|
"header": f"⚙ edit_file: {path}",
|
|
"preview": "\n".join(preview_parts),
|
|
"needs_approval": True,
|
|
"approval_label": "edit_file",
|
|
"execute": self._exec_edit_file,
|
|
"path": path,
|
|
"resolved": resolved,
|
|
"old_string": old_string,
|
|
"new_string": new_string,
|
|
"near_line": near_line,
|
|
}
|
|
|
|
def _prepare_math(self, call_id: str, args: dict) -> dict:
|
|
code = args.get("code", "")
|
|
if isinstance(code, list):
|
|
code = "\n".join(code)
|
|
if not code:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "math",
|
|
"header": "✗ math: empty code",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: no code provided",
|
|
}
|
|
# Show code preview
|
|
display = code[:300]
|
|
if len(code) > 300:
|
|
display += f"\n... ({len(code)} chars total)"
|
|
preview = f"{DIM}{textwrap.indent(display, ' ')}{RESET}"
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "math",
|
|
"header": f"⚙ math: ({len(code)} chars)",
|
|
"preview": preview,
|
|
"needs_approval": True,
|
|
"approval_label": "math",
|
|
"execute": self._exec_math,
|
|
"code": code,
|
|
}
|
|
|
|
def _prepare_man(self, call_id: str, args: dict) -> dict:
    """Prepare a man/info page lookup.

    Args:
        call_id: Identifier of the tool call, echoed back in the record.
        args: Parsed tool arguments. ``page`` is required; ``section`` is
            optional. Both are converted with ``str()`` so a numeric value
            (e.g. ``"section": 3``) does not crash validation.

    Returns:
        An error record when the page name is empty, starts with ``-``, or
        contains anything other than letters, digits, ``.``, ``_`` and
        ``-``. Otherwise a record (no approval needed) for ``_exec_man``.
        A section that is not a digit 1-9 with an optional lowercase
        letter is silently dropped.
    """

    def fail(header: str, preview: str, error: str) -> dict:
        return {
            "call_id": call_id,
            "func_name": "man",
            "header": header,
            "preview": preview,
            "needs_approval": False,
            "error": error,
        }

    page = str(args.get("page") or "").strip()
    if not page:
        return fail("✗ man: empty page", "", "Error: no page name provided")

    # Sanitize: only allow alphanumeric, dash, underscore, dot. A leading
    # dash is refused as well: this lookup runs without approval, and a
    # name such as "-Pfoo" could be read as an option instead of a page
    # (NOTE(review): _exec_man is not visible here; this is defensive).
    if page.startswith("-") or not re.match(r"^[a-zA-Z0-9._-]+$", page):
        return fail(
            "✗ man: invalid page name",
            f" {RED}{page}{RESET}",
            f"Error: invalid page name {page!r}",
        )

    section = str(args.get("section") or "").strip()
    if section and not re.match(r"^[1-9][a-z]?$", section):
        section = ""

    label = f"{page}({section})" if section else page
    return {
        "call_id": call_id,
        "func_name": "man",
        "header": f"⚙ man: {label}",
        "preview": f" {DIM}{label}{RESET}",
        "needs_approval": False,
        "execute": self._exec_man,
        "page": page,
        "section": section,
    }
|
|
|
|
def _prepare_web_fetch(self, call_id: str, args: dict) -> dict:
|
|
url = args.get("url", "").strip()
|
|
question = args.get("question", "").strip()
|
|
if not url:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_fetch",
|
|
"header": "✗ web_fetch: empty url",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: no URL provided",
|
|
}
|
|
if not question:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_fetch",
|
|
"header": "✗ web_fetch: empty question",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: no question provided",
|
|
}
|
|
if not url.startswith(("http://", "https://")):
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_fetch",
|
|
"header": "✗ web_fetch: invalid url",
|
|
"preview": f" {RED}{url}{RESET}",
|
|
"needs_approval": False,
|
|
"error": f"Error: URL must start with http:// or https:// (got {url!r})",
|
|
}
|
|
# SSRF protection: reject private/link-local/metadata IPs
|
|
try:
|
|
hostname = urlparse(url).hostname
|
|
if hostname:
|
|
addr = ipaddress.ip_address(socket.gethostbyname(hostname))
|
|
if addr.is_private or addr.is_loopback or addr.is_link_local:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_fetch",
|
|
"header": "✗ web_fetch: blocked (private network)",
|
|
"preview": f" {RED}{url}{RESET}",
|
|
"needs_approval": False,
|
|
"error": f"Error: URL resolves to private/internal address ({addr})",
|
|
}
|
|
except (socket.gaierror, ValueError):
|
|
pass # DNS failure handled later during actual fetch
|
|
q_preview = question[:200] + ("..." if len(question) > 200 else "")
|
|
preview = f" {DIM}{url}\n Q: {q_preview}{RESET}"
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_fetch",
|
|
"header": f"⚙ web_fetch: {url[:80]}",
|
|
"preview": preview,
|
|
"needs_approval": True,
|
|
"approval_label": "web_fetch",
|
|
"execute": self._exec_web_fetch,
|
|
"url": url,
|
|
"question": question,
|
|
}
|
|
|
|
def _prepare_web_search(self, call_id: str, args: dict) -> dict:
|
|
"""Prepare a web search via Tavily for approval."""
|
|
query = (args.get("query") or "").strip()
|
|
if not query:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_search",
|
|
"header": "✗ web_search: empty query",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: no query provided",
|
|
}
|
|
if not _get_tavily_key():
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_search",
|
|
"header": "✗ web_search: no API key",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": (
|
|
"Error: Tavily API key not configured. "
|
|
"Set it in ~/.config/pcode/tavily_key or $TAVILY_API_KEY. "
|
|
"Use web_fetch with a direct URL as an alternative."
|
|
),
|
|
}
|
|
try:
|
|
max_results = min(max(int(args.get("max_results") or 5), 1), 20)
|
|
except (ValueError, TypeError):
|
|
max_results = 5
|
|
topic = args.get("topic", "general") or "general"
|
|
if topic not in ("general", "news", "finance"):
|
|
topic = "general"
|
|
q_preview = query[:200] + ("..." if len(query) > 200 else "")
|
|
preview = f" {DIM}{q_preview}{RESET}"
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "web_search",
|
|
"header": f"⚙ web_search: {query[:80]}",
|
|
"preview": preview,
|
|
"needs_approval": True,
|
|
"approval_label": "web_search",
|
|
"execute": self._exec_web_search,
|
|
"query": query,
|
|
"max_results": max_results,
|
|
"topic": topic,
|
|
}
|
|
|
|
def _prepare_task(self, call_id: str, args: dict) -> dict:
|
|
"""Prepare a general-purpose sub-agent task for approval."""
|
|
prompt = (args.get("prompt") or "").strip()
|
|
if not prompt:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "task",
|
|
"header": "✗ task: empty prompt",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: empty prompt",
|
|
}
|
|
preview_text = prompt[:300] + ("..." if len(prompt) > 300 else "")
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "task",
|
|
"header": "⚙ task (autonomous agent)",
|
|
"preview": f" {DIM}{preview_text}{RESET}",
|
|
"needs_approval": True,
|
|
"approval_label": "task",
|
|
"execute": self._exec_task,
|
|
"prompt": prompt,
|
|
}
|
|
|
|
def _prepare_plan(self, call_id: str, args: dict) -> dict:
|
|
"""Prepare a planning agent for approval."""
|
|
prompt = (args.get("prompt") or "").strip()
|
|
if not prompt:
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "plan",
|
|
"header": "✗ plan: empty prompt",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"error": "Error: empty prompt",
|
|
}
|
|
preview_text = prompt[:300] + ("..." if len(prompt) > 300 else "")
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "plan",
|
|
"header": "⚙ plan (planning agent)",
|
|
"preview": f" {DIM}{preview_text}{RESET}",
|
|
"needs_approval": True,
|
|
"approval_label": "plan",
|
|
"execute": self._exec_plan,
|
|
"prompt": prompt,
|
|
}
|
|
|
|
def _prepare_remember(self, call_id: str, args: dict) -> dict:
    """Build the execution record for saving a memory entry.

    Args:
        call_id: Identifier of the tool call, echoed back in the record.
        args: Parsed tool arguments; ``key`` (stripped, then passed
            through ``_normalize_key``) and ``value`` (stripped) are both
            required.

    Returns:
        An error record when either the key or the value ends up empty,
        otherwise a record (no approval needed) for ``_exec_remember``.
    """
    raw_key = args.get("key") or ""
    raw_value = args.get("value") or ""
    key = _normalize_key(raw_key.strip())
    value = raw_value.strip()

    if key and value:
        return {
            "call_id": call_id,
            "func_name": "remember",
            "header": f"⚙ remember: {key}",
            "preview": "",
            "needs_approval": False,
            "execute": self._exec_remember,
            "key": key,
            "value": value,
        }

    return {
        "call_id": call_id,
        "func_name": "remember",
        "header": "✗ remember: requires key and value",
        "preview": "",
        "needs_approval": False,
        "error": "Error: both 'key' and 'value' are required",
    }
|
|
|
|
def _prepare_forget(self, call_id: str, args: dict) -> dict:
    """Build the execution record for deleting a memory entry.

    Args:
        call_id: Identifier of the tool call, echoed back in the record.
        args: Parsed tool arguments; ``key`` is stripped and passed
            through ``_normalize_key``.

    Returns:
        An error record when the resulting key is empty, otherwise a
        record (no approval needed) for ``_exec_forget``.
    """
    raw_key = args.get("key") or ""
    key = _normalize_key(raw_key.strip())

    if key:
        return {
            "call_id": call_id,
            "func_name": "forget",
            "header": f"⚙ forget: {key}",
            "preview": "",
            "needs_approval": False,
            "execute": self._exec_forget,
            "key": key,
        }

    return {
        "call_id": call_id,
        "func_name": "forget",
        "header": "✗ forget: empty key",
        "preview": "",
        "needs_approval": False,
        "error": "Error: key is required",
    }
|
|
|
|
def _prepare_recall(self, call_id: str, args: dict) -> dict:
|
|
"""Prepare a recall action."""
|
|
query = (args.get("query") or "").strip()
|
|
limit = args.get("limit", 20)
|
|
if isinstance(limit, str):
|
|
try:
|
|
limit = int(limit)
|
|
except ValueError:
|
|
limit = 20
|
|
return {
|
|
"call_id": call_id,
|
|
"func_name": "recall",
|
|
"header": f"⚙ recall{': ' + query[:80] if query else ''}",
|
|
"preview": "",
|
|
"needs_approval": False,
|
|
"execute": self._exec_recall,
|
|
"query": query,
|
|
"limit": min(limit, 50),
|
|
}
|
|
|
|
# ── Execute methods (do the work, display output) ─────────────────
|
|
|
|
def _exec_bash(self, item: dict) -> tuple[str, str]:
|
|
"""Execute a bash command via temp script."""
|
|
call_id, command = item["call_id"], item["command"]
|
|
try:
|
|
with tempfile.NamedTemporaryFile(mode="w", suffix=".sh", delete=False) as f:
|
|
f.write(command)
|
|
script_path = f.name
|
|
try:
|
|
result = subprocess.run(
|
|
["bash", script_path],
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=self.tool_timeout,
|
|
)
|
|
finally:
|
|
os.unlink(script_path)
|
|
output = result.stdout
|
|
if result.stderr:
|
|
output += ("\n" if output else "") + result.stderr
|
|
output = output.strip()
|
|
original_len = len(output)
|
|
|
|
if original_len > 10_000:
|
|
output = (
|
|
output[:5_000]
|
|
+ f"\n\n... [{original_len - 10_000} chars truncated] ...\n\n"
|
|
+ output[-5_000:]
|
|
)
|
|
|
|
with self._print_lock:
|
|
preview = output[:500]
|
|
if original_len > 500:
|
|
preview += f"\n ... ({original_len} chars total)"
|
|
if preview:
|
|
indented = textwrap.indent(preview, " ")
|
|
sys.stdout.write(f"{DIM}{indented}{RESET}\n")
|
|
sys.stdout.flush()
|
|
|
|
if result.returncode != 0:
|
|
output += f"\n[exit code: {result.returncode}]"
|
|
|
|
return call_id, output if output else "(no output)"
|
|
|
|
except subprocess.TimeoutExpired:
|
|
msg = f"Command timed out after {self.tool_timeout}s"
|
|
with self._print_lock:
|
|
sys.stdout.write(f" {red(msg)}\n")
|
|
sys.stdout.flush()
|
|
return call_id, msg
|
|
except Exception as e:
|
|
msg = f"Error executing command: {e}"
|
|
with self._print_lock:
|
|
sys.stdout.write(f" {red(msg)}\n")
|
|
sys.stdout.flush()
|
|
return call_id, msg
|
|
|
|
def _exec_read_file(self, item: dict) -> tuple[str, str]:
|
|
"""Read a file and return numbered lines, optionally sliced."""
|
|
call_id, path = item["call_id"], item["path"]
|
|
offset = item.get("offset") # 1-based, or None
|
|
limit = item.get("limit") # max lines, or None
|
|
resolved = os.path.realpath(path)
|
|
|
|
try:
|
|
with open(path, "r") as f:
|
|
all_lines = f.readlines()
|
|
except FileNotFoundError:
|
|
self._read_files.discard(resolved)
|
|
return call_id, f"Error: {path} not found"
|
|
except Exception as e:
|
|
self._read_files.discard(resolved)
|
|
return call_id, f"Error reading {path}: {e}"
|
|
|
|
self._read_files.add(resolved)
|
|
total_lines = len(all_lines)
|
|
|
|
# Slice if offset/limit specified
|
|
start = max(1, offset or 1)
|
|
if limit is not None:
|
|
lines = all_lines[start - 1 : start - 1 + limit]
|
|
else:
|
|
lines = all_lines[start - 1 :]
|
|
|
|
numbered = []
|
|
for i, line in enumerate(lines, start=start):
|
|
numbered.append(f"{i:>4}\t{line.rstrip()}")
|
|
output = "\n".join(numbered)
|
|
original_len = len(output)
|
|
|
|
if original_len > 10_000:
|
|
output = (
|
|
output[:5_000]
|
|
+ f"\n\n... [{original_len - 10_000} chars truncated] ...\n\n"
|
|
+ output[-5_000:]
|
|
)
|
|
|
|
with self._print_lock:
|
|
desc = f"{len(lines)} lines"
|
|
if offset is not None or limit is not None:
|
|
end = start + len(lines) - 1
|
|
desc += f" (lines {start}-{end} of {total_lines})"
|
|
sys.stdout.write(f" {DIM}{desc}{RESET}\n")
|
|
sys.stdout.flush()
|
|
|
|
return call_id, output if output else "(empty file)"
|
|
|
|
def _exec_search(self, item: dict) -> tuple[str, str]:
    """Search file contents for a regex pattern using grep.

    Runs ``grep -rn -I -E -m 200 --color=never -- PATTERN PATH`` with
    ``self.tool_timeout`` as the timeout and prints a one-line summary.

    Args:
        item: Prepared record; ``call_id``, ``pattern`` and ``path`` are
            read.

    Returns:
        ``(call_id, text)``. Exit status 1 yields ``"(no matches)"``. A
        status above 1 with no matches yields grep's stderr (or a generic
        message). A status above 1 WITH matches keeps the matches and
        appends grep's error text after them. Output longer than 10,000
        characters keeps only its first and last 5,000. A timeout or any
        other failure is returned as a message rather than raised.
    """
    call_id = item["call_id"]
    pattern, path = item["pattern"], item["path"]
    try:
        result = subprocess.run(
            [
                "grep",
                "-rn",
                "-I",
                "-E",
                "-m",
                "200",  # max matches per file
                "--color=never",  # no ANSI codes in output
                "--",  # prevents the pattern being read as a flag
                pattern,
                path,
            ],
            capture_output=True,
            text=True,
            # File names and matched lines may not be valid in the locale
            # encoding; substitute instead of failing the whole search.
            errors="replace",
            timeout=self.tool_timeout,
        )
        matches = result.stdout.strip()
        problems = result.stderr.strip()

        # Count matches BEFORE truncation
        if result.returncode == 1:
            output, match_count = "(no matches)", 0
        elif result.returncode > 1 and not matches:
            output = problems or f"grep error (exit {result.returncode})"
            match_count = 0
        else:
            # grep exits >1 when it hit an error (e.g. an unreadable file)
            # even if other files matched; keep those matches.
            output = matches
            match_count = matches.count("\n") + 1 if matches else 0

        original_len = len(output)
        if original_len > 10_000:
            output = (
                output[:5_000]
                + f"\n\n... [{original_len - 10_000} chars truncated] ...\n\n"
                + output[-5_000:]
            )

        # Appended after truncation so the error note is never cut off.
        if result.returncode > 1 and matches:
            note = problems[:1_000] or f"exit {result.returncode}"
            output += f"\n\n[grep also reported errors]\n{note}"

        with self._print_lock:
            desc = f"{match_count} matches" if match_count else "no matches"
            if original_len > 500:
                desc += f" ({original_len} chars)"
            sys.stdout.write(f" {DIM}{desc}{RESET}\n")
            sys.stdout.flush()

        return call_id, output

    except subprocess.TimeoutExpired:
        msg = f"Search timed out after {self.tool_timeout}s"
        with self._print_lock:
            sys.stdout.write(f" {red(msg)}\n")
            sys.stdout.flush()
        return call_id, msg
    except Exception as e:
        msg = f"Search error: {e}"
        with self._print_lock:
            sys.stdout.write(f" {red(msg)}\n")
            sys.stdout.flush()
        return call_id, msg
|
|
|
|
# Tools the agent can auto-execute without user approval (read-only).
|
|
# bash, write_file, edit_file are excluded — the user approval prompt
|
|
# is the primary security boundary against prompt injection.
|
|
_AGENT_AUTO_TOOLS = {
|
|
"read_file",
|
|
"search",
|
|
"math",
|
|
"man",
|
|
"web_fetch",
|
|
"web_search",
|
|
}
|
|
|
|
def _run_agent(
    self,
    agent_messages: list[dict],
    label: str = "agent",
    tools: list[dict] | None = None,
    auto_tools: set[str] | None = None,
    reasoning_effort: str | None = None,
) -> str:
    """Run an autonomous agent loop.

    Each turn sends the conversation to the model. A reply without tool
    calls ends the loop; otherwise every requested tool is handled in
    order and its result appended as a ``tool`` message. After 20 turns
    the agent is asked for a final answer with no tools offered.

    Args:
        agent_messages: Pre-built message list (system + developer + user).
            It is extended in place with assistant and tool messages.
        label: Display prefix for progress lines ("agent" or "plan").
        tools: Tool definitions to send to the API. Defaults to AGENT_TOOLS.
        auto_tools: Names of tools that run without approval. Defaults to
            _AGENT_AUTO_TOOLS. Other tools in ``tools`` go through
            ``_display_and_approve`` first.
        reasoning_effort: Override reasoning effort for this agent.

    Returns:
        Final content string from the agent, or "(no output)" when the
        model returned no content.
    """
    if tools is None:
        tools = AGENT_TOOLS
    if auto_tools is None:
        auto_tools = self._AGENT_AUTO_TOOLS
    max_tool_turns = 20

    kwargs = dict(self._chat_template_kwargs_base)
    if reasoning_effort:
        kwargs["reasoning_effort"] = reasoning_effort

    def _api_call(messages, _tools=tools):
        request = {
            "model": self.model,
            "messages": messages,
            "max_completion_tokens": self.max_tokens,
            "temperature": self.temperature,
            "extra_body": {
                "chat_template_kwargs": kwargs,
            },
        }
        # Leave `tools` out entirely when there are none: an empty array
        # is rejected by some OpenAI-compatible servers.
        if _tools:
            request["tools"] = _tools
        return self.client.chat.completions.create(**request)

    # Loop-invariant: the names this agent was actually offered.
    tool_names = {t["function"]["name"] for t in tools}

    for turn in range(max_tool_turns):
        response = _api_call(agent_messages)
        choice = response.choices[0]
        assistant_msg = choice.message

        # Build message dict for agent history
        msg_dict = {
            "role": "assistant",
            "content": assistant_msg.content or "",
        }
        if assistant_msg.tool_calls:
            msg_dict["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in assistant_msg.tool_calls
            ]
        agent_messages.append(msg_dict)

        if not assistant_msg.tool_calls:
            content = assistant_msg.content or "(no output)"
            with self._print_lock:
                sys.stdout.write(
                    f" {DIM}[{label} done] {len(content)} chars{RESET}\n"
                )
                sys.stdout.flush()
            return content

        # Execute tools sequentially (not parallel) to avoid
        # concurrent _read_files mutation from worker threads.
        for tc in assistant_msg.tool_calls:
            tool_name = tc.function.name

            # Guard 1: block recursive agent calls.
            if tool_name in ("task", "plan"):
                output = "Error: agents cannot spawn further agents"
            # Guard 2: tool not in this agent's API tool list.
            elif tool_name not in tool_names:
                output = (
                    f"Error: tool '{tool_name}' is not available in "
                    f"agent mode. "
                    f"Available: {', '.join(sorted(tool_names))}"
                )
            else:
                tc_dict = {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": tc.function.arguments,
                    },
                }
                prepared = self._prepare_tool(tc_dict)

                with self._print_lock:
                    lbl = prepared.get("header", tool_name)
                    sys.stdout.write(
                        f" {DIM}[{label} turn {turn + 1}] {lbl}{RESET}\n"
                    )
                    sys.stdout.flush()

                if prepared.get("error"):
                    output = prepared["error"]
                # Auto-execute tools in the auto_tools set.
                elif tool_name in auto_tools:
                    _, output = prepared["execute"](prepared)
                # Tools not in auto_tools require user approval.
                elif "execute" in prepared:
                    feedback = self._display_and_approve([prepared])
                    if prepared.get("denied"):
                        output = prepared.get("denial_msg", "Denied by user")
                    else:
                        _, output = prepared["execute"](prepared)
                        # Text typed after "y" at the prompt is meant for
                        # the model; pass it on instead of dropping it.
                        if feedback:
                            output = f"{output}\n\n[User feedback: {feedback}]"
                else:
                    output = f"Unknown tool: {tool_name}"

            agent_messages.append(
                {
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": output,
                }
            )

    # Exhausted tool turns — force a final synthesis response.
    with self._print_lock:
        sys.stdout.write(
            f" {DIM}[{label}] turn limit reached, requesting synthesis...{RESET}\n"
        )
        sys.stdout.flush()
    agent_messages.append(
        {
            "role": "user",
            "content": (
                "You have reached the tool call limit. "
                "Provide your complete response now using "
                "the information you have gathered so far."
            ),
        }
    )
    response = _api_call(agent_messages, _tools=[])
    content = response.choices[0].message.content or "(no output)"
    with self._print_lock:
        sys.stdout.write(f" {DIM}[{label} done] {len(content)} chars{RESET}\n")
        sys.stdout.flush()
    return content
|
|
|
|
# Task agent read-only tools auto-execute without approval.
|
|
# bash, write_file, edit_file are in TASK_AGENT_TOOLS (the API tool list)
|
|
# but NOT here — they go through _display_and_approve for user approval.
|
|
_TASK_AUTO_TOOLS = {
|
|
"read_file",
|
|
"search",
|
|
"math",
|
|
"man",
|
|
"web_fetch",
|
|
"web_search",
|
|
}
|
|
|
|
def _exec_task(self, item: dict) -> tuple[str, str]:
    """Hand a prompt to a general-purpose autonomous sub-agent.

    The agent conversation is ``self._agent_system_messages`` followed by
    a developer instruction and the user's prompt. It runs through
    ``_run_agent`` with ``TASK_AGENT_TOOLS`` as the offered tools and
    ``self._TASK_AUTO_TOOLS`` as the set that needs no approval.

    Args:
        item: Prepared record; ``call_id`` and ``prompt`` are read.

    Returns:
        ``(call_id, text)`` where text is the agent's final content, an
        interruption notice on Ctrl-C, or ``"Task error: ..."`` for any
        other exception (which is also printed).
    """
    call_id = item["call_id"]
    prompt = item["prompt"]

    developer_text = (
        "# Task Agent\n\n"
        "You are an autonomous task agent with full tool access. "
        "You can use bash, read_file, write_file, edit_file, search, "
        "math, web_fetch, and web_search.\n\n"
        "1. **Follow through on actions:** Do not describe changes — "
        "use the tools to make them. After read_file, call edit_file "
        "or write_file.\n\n"
        "2. **Tool selection:**\n"
        " - Use read_file before edit_file on existing files.\n"
        " - Use write_file for new files (not bash).\n"
        " - Use bash for shell commands (git, python, tests).\n"
        " - Use search to find code across files.\n\n"
        "3. **Complete the task fully.** Do not ask follow-up "
        "questions — execute the work as described in the prompt."
    )
    conversation = [
        *self._agent_system_messages,
        {"role": "developer", "content": developer_text},
        {"role": "user", "content": prompt},
    ]

    try:
        answer = self._run_agent(
            conversation,
            label="task",
            tools=TASK_AGENT_TOOLS,
            auto_tools=self._TASK_AUTO_TOOLS,
        )
    except KeyboardInterrupt:
        return call_id, "(task interrupted by user)"
    except Exception as e:
        with self._print_lock:
            sys.stdout.write(f" {DIM}[task error] {e}{RESET}\n")
            sys.stdout.flush()
        return call_id, f"Task error: {e}"
    return call_id, answer
|
|
|
|
def _exec_plan(self, item: dict) -> tuple[str, str]:
    """Run a planning agent and save its plan to ``.plan.md``.

    Args:
        item: Prepared tool call; must provide ``call_id`` and ``prompt``.

    Returns:
        ``(call_id, text)``. ``text`` is the plan produced by the agent, or
        a short message when the run was interrupted or raised. The plan is
        returned even when it could not be written to disk.
    """
    call_id, prompt = item["call_id"], item["prompt"]
    plan_path = ".plan.md"

    plan_instruction = {
        "role": "developer",
        "content": (
            "# Planning Agent\n\n"
            "1. **Exploration first:** Use read_file and search to understand the "
            "codebase before writing any plan. Do not guess at file structure or "
            "assume patterns — verify them.\n\n"
            "2. **Output format** — use these exact sections:\n"
            " - **## Goal**: What we're building and why (1-2 sentences).\n"
            " - **## Current State**: Relevant existing code and patterns found "
            "during exploration. Reference specific file paths and line numbers.\n"
            " - **## Plan**: Numbered steps with specific files and functions "
            "to create or modify. Each step should be actionable.\n"
            " - **## Risks**: Edge cases, breaking changes, or unknowns.\n\n"
            "3. **Specificity rules:**\n"
            " - Reference file paths, line numbers, and function names — not vague areas.\n"
            " - Each plan step should name the exact file(s) and describe the change.\n"
            " - Identify dependencies between steps (what must happen first).\n\n"
            "4. **Common mistakes to avoid:**\n"
            " - Writing a plan without reading the codebase first\n"
            ' - Vague steps like "update the code" without naming files or functions\n'
            " - Ignoring existing patterns or conventions in the codebase\n"
            " - Planning changes to files that don't exist"
        ),
    }
    agent_messages = list(self._agent_system_messages) + [
        plan_instruction,
        {"role": "user", "content": prompt},
    ]

    try:
        content = self._run_agent(
            agent_messages, label="plan", reasoning_effort="high"
        )
    except KeyboardInterrupt:
        return call_id, "(plan interrupted by user)"
    except Exception as e:
        with self._print_lock:
            sys.stdout.write(f" {DIM}[plan error] {e}{RESET}\n")
            sys.stdout.flush()
        return call_id, f"Plan error: {e}"

    # Write to file separately — always return content even if write fails.
    # The encoding is pinned to UTF-8 so the plan (which routinely contains
    # non-ASCII punctuation) does not depend on the platform locale, and
    # UnicodeError is caught as well because an encoding failure is a
    # ValueError, not an OSError, and would otherwise discard the plan.
    try:
        with open(plan_path, "w", encoding="utf-8") as f:
            f.write(content)
        with self._print_lock:
            sys.stdout.write(f" {DIM}Plan written to {plan_path}{RESET}\n")
            sys.stdout.flush()
    except (OSError, UnicodeError) as e:
        with self._print_lock:
            sys.stdout.write(
                f" {DIM}[plan] could not write {plan_path}: {e}{RESET}\n"
            )
            sys.stdout.flush()

    return call_id, content
|
|
|
|
def _exec_remember(self, item: dict) -> tuple[str, str]:
    """Store (or overwrite) a persistent key/value memory.

    Returns ``(call_id, message)``; the message reports whether the key was
    newly saved or updated, or carries the error text on failure.
    """
    call_id = item["call_id"]
    key = item["key"]
    value = item["value"]

    # Upsert that keeps the original ``created`` timestamp when the key
    # already exists and always refreshes ``updated``.
    upsert_sql = (
        "INSERT OR REPLACE INTO memories (key, value, created, updated) "
        "VALUES (?, ?, COALESCE("
        " (SELECT created FROM memories WHERE key = ?), "
        " datetime('now')"
        "), datetime('now'))"
    )

    try:
        conn = _open_db()
        try:
            # Look up the previous value first so the message can show it.
            previous = conn.execute(
                "SELECT value FROM memories WHERE key = ?", (key,)
            ).fetchone()
            conn.execute(upsert_sql, (key, value, key))
            conn.commit()
            # Rebuild the system messages after the memory table changed.
            self._init_system_messages()

            msg = (
                f"Updated memory: {key} = {value} (was: {previous[0]})"
                if previous
                else f"Saved memory: {key} = {value}"
            )
            with self._print_lock:
                sys.stdout.write(f" {DIM}{msg}{RESET}\n")
                sys.stdout.flush()
            return call_id, msg
        finally:
            conn.close()
    except Exception as exc:
        return call_id, f"Error: {exc}"
|
|
|
|
def _exec_forget(self, item: dict) -> tuple[str, str]:
    """Delete the persistent memory stored under ``item["key"]``.

    Returns ``(call_id, message)``; the message says whether the key was
    removed, was not found, or carries the error text on failure.
    """
    call_id = item["call_id"]
    key = item["key"]
    try:
        conn = _open_db()
        try:
            deleted = conn.execute("DELETE FROM memories WHERE key = ?", (key,))
            conn.commit()
            if deleted.rowcount != 0:
                # Something was removed: rebuild the system messages.
                self._init_system_messages()
                msg = f"Forgot: {key}"
            else:
                msg = f"Error: memory '{key}' not found"
            with self._print_lock:
                sys.stdout.write(f" {DIM}{msg}{RESET}\n")
                sys.stdout.flush()
            return call_id, msg
        finally:
            conn.close()
    except Exception as exc:
        return call_id, f"Error: {exc}"
|
|
|
|
def _exec_recall(self, item: dict) -> tuple[str, str]:
    """Search stored memories and conversation history.

    With no search terms every memory is listed. With terms, memories whose
    key or value contains every term are returned, followed by matching
    conversation-history rows.

    Args:
        item: Prepared tool call; must provide ``call_id``, ``query`` and
            ``limit`` (passed through to ``_search_history``).

    Returns:
        ``(call_id, text)`` where ``text`` is the formatted result, a
        "no results" notice, or a memory-lookup failure message.
    """
    call_id = item["call_id"]
    query, limit = item["query"], item["limit"]
    # A whitespace-only query has no terms. Building a WHERE clause from
    # zero terms would be invalid SQL, so such a query is handled exactly
    # like a missing one.
    terms = query.split() if query else []
    parts: list[str] = []

    # Memories: list all (no terms) or search (every term must match).
    try:
        conn = _open_db()
        try:
            if not terms:
                rows = conn.execute(
                    "SELECT key, value FROM memories ORDER BY key"
                ).fetchall()
            else:
                clause = "(key LIKE ? ESCAPE '\\' OR value LIKE ? ESCAPE '\\')"
                params: list[str] = []
                for term in terms:
                    pattern = f"%{_escape_like(term)}%"
                    params.extend([pattern, pattern])
                rows = conn.execute(
                    "SELECT key, value FROM memories WHERE "
                    + " AND ".join([clause] * len(terms))
                    + " ORDER BY key",
                    params,
                ).fetchall()
        finally:
            conn.close()
    except Exception as e:
        # Best effort: a broken memory store must not prevent the history
        # search below, but the failure is reported instead of being hidden
        # behind a misleading "no results".
        parts.append(f"Memory lookup failed: {e}")
    else:
        if rows:
            parts.append(
                "Memories:\n" + "\n".join(f" {k}={v}" for k, v in rows)
            )
        elif not terms:
            parts.append("No memories stored.")

    # Conversations: only when there is something to search for.
    if terms:
        conv_rows = _search_history(query, limit)
        if conv_rows:
            lines = []
            for ts, sid, role, content, tool_name in conv_rows:
                label = f"{role}({tool_name})" if tool_name else role
                # Cap each row at 500 characters and mark the cut.
                text = (content or "")[:500]
                if content and len(content) > 500:
                    text += "..."
                lines.append(f"[{ts} {sid}] {label}: {text}")
            parts.append(
                f"Conversations ({len(conv_rows)} matches):\n" + "\n".join(lines)
            )

    output = "\n\n".join(parts) if parts else f"No results for '{query}'."
    with self._print_lock:
        sys.stdout.write(f" {DIM}{output}{RESET}\n")
        sys.stdout.flush()
    return call_id, output
|
|
|
|
def _exec_write_file(self, item: dict) -> tuple[str, str]:
|
|
"""Write content to a file, creating parent directories as needed."""
|
|
call_id = item["call_id"]
|
|
path, content, resolved = item["path"], item["content"], item["resolved"]
|
|
try:
|
|
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
|
|
with open(path, "w") as f:
|
|
f.write(content)
|
|
self._read_files.add(resolved)
|
|
return call_id, f"Wrote {len(content)} chars to {path}"
|
|
except Exception as e:
|
|
return call_id, f"Error writing {path}: {e}"
|
|
|
|
def _exec_edit_file(self, item: dict) -> tuple[str, str]:
|
|
"""Replace an exact string in a file (re-reads to avoid TOCTOU).
|
|
|
|
When near_line is set, picks the occurrence nearest that line
|
|
instead of requiring uniqueness.
|
|
"""
|
|
call_id = item["call_id"]
|
|
path, old_string, new_string = (
|
|
item["path"],
|
|
item["old_string"],
|
|
item["new_string"],
|
|
)
|
|
near_line = item.get("near_line")
|
|
try:
|
|
with open(path, "r") as f:
|
|
content = f.read()
|
|
occurrences = _find_occurrences(content, old_string)
|
|
if len(occurrences) == 0:
|
|
return (
|
|
call_id,
|
|
f"Error: old_string no longer found in {path} (file changed)",
|
|
)
|
|
if len(occurrences) > 1 and near_line is None:
|
|
line_list = ", ".join(str(ln) for ln in occurrences)
|
|
return (
|
|
call_id,
|
|
f"Error: old_string found {len(occurrences)} times "
|
|
f"at lines {line_list} (file changed)",
|
|
)
|
|
if near_line is not None and len(occurrences) > 1:
|
|
# Replace only the occurrence nearest to near_line
|
|
idx = _pick_nearest(content, old_string, near_line)
|
|
content = content[:idx] + new_string + content[idx + len(old_string) :]
|
|
else:
|
|
content = content.replace(old_string, new_string, 1)
|
|
with open(path, "w") as f:
|
|
f.write(content)
|
|
return call_id, f"Edited {path}: replaced 1 occurrence"
|
|
except Exception as e:
|
|
return call_id, f"Error writing {path}: {e}"
|
|
|
|
def _exec_math(self, item: dict) -> tuple[str, str]:
    """Run ``item["code"]`` through ``_execute_math_sandboxed``.

    Output over 10,000 characters is cut to its first and last 5,000 with
    a truncation notice in between; a preview of up to 500 characters is
    echoed to the terminal. Returns ``(call_id, text)``.
    """
    call_id, code = item["call_id"], item["code"]
    output, is_error = _execute_math_sandboxed(code, timeout=self.tool_timeout)
    total_chars = len(output)

    if total_chars > 10_000:
        notice = f"\n\n... [{total_chars - 10_000} chars truncated] ...\n\n"
        output = "".join((output[:5_000], notice, output[-5_000:]))

    # Terminal preview: red for errors, dim otherwise.
    preview = output[:500]
    if total_chars > 500:
        preview += f"\n ... ({total_chars} chars total)"
    with self._print_lock:
        if preview:
            tint = RED if is_error else DIM
            shifted = textwrap.indent(preview, " ")
            sys.stdout.write(f"{tint}{shifted}{RESET}\n")
            sys.stdout.flush()

    if is_error:
        return call_id, f"Error:\n{output}"
    return call_id, output or "(no output)"
|
|
|
|
def _exec_man(self, item: dict) -> tuple[str, str]:
    """Look up a manual page with ``man``, falling back to ``info``.

    Args:
        item: Prepared tool call with ``call_id``, ``page`` and optionally
            ``section``. The section is passed to ``man`` only; the
            ``info`` fallback is called with the page alone.

    Returns:
        ``(call_id, text)`` where ``text`` is the page (capped at 30,000
        characters), or a message when no page was found, the lookup timed
        out, or ``man`` is not installed and ``info`` found nothing.
    """
    call_id = item["call_id"]
    page = item["page"]
    section = item.get("section", "")

    # "--" ends option parsing, so a page or section that begins with "-"
    # is always treated as a name and never as a command-line option.
    man_cmd = ["man", "--"]
    if section:
        man_cmd.append(section)
    man_cmd.append(page)

    text = None  # stays None until one of the tools yields a page
    man_missing = False
    try:
        # Try man first.
        try:
            result = subprocess.run(
                man_cmd,
                capture_output=True,
                text=True,
                timeout=10,
                env={**os.environ, "MANWIDTH": "80", "MAN_KEEP_FORMATTING": "0"},
            )
        except FileNotFoundError:
            # man is not installed; info may still have the page.
            man_missing = True
        else:
            if result.returncode == 0 and result.stdout.strip():
                # Strip formatting: backspace overstrikes and ANSI escapes
                text = re.sub(r".\x08", "", result.stdout)
                text = re.sub(r"\x1b\[[0-9;]*m", "", text)

        # Fall back to info.
        if text is None:
            try:
                result = subprocess.run(
                    ["info", "--", page],
                    capture_output=True,
                    text=True,
                    timeout=10,
                )
            except FileNotFoundError:
                # info is not installed either; reported below.
                pass
            else:
                if result.returncode == 0 and result.stdout.strip():
                    text = result.stdout
    except subprocess.TimeoutExpired:
        return call_id, "Error: man page lookup timed out"

    if text is None:
        if man_missing:
            return call_id, "Error: man command not available"
        msg = f"No man or info page found for '{page}'"
        with self._print_lock:
            sys.stdout.write(f" {DIM}{msg}{RESET}\n")
            sys.stdout.flush()
        return call_id, msg

    # Truncate very long pages
    if len(text) > 30_000:
        total = len(text)
        text = (
            text[:30_000]
            + f"\n\n... [truncated: showing first 30000 of {total} chars. "
            f"Use a specific section number to narrow results.]"
        )

    with self._print_lock:
        sys.stdout.write(f" {DIM}{len(text)} chars{RESET}\n")
        sys.stdout.flush()

    return call_id, text
|
|
|
|
def _exec_web_fetch(self, item: dict) -> tuple[str, str]:
    """Fetch a URL, then answer a question about it with a model call.

    Args:
        item: Prepared tool call with ``call_id``, ``url`` and optionally
            ``question`` (defaults to a request for a summary).

    Returns:
        ``(call_id, text)`` where ``text`` is the model's answer, or a
        message describing why the fetch or the extraction failed.
    """
    call_id, url = item["call_id"], item["url"]
    question = item.get("question", "Summarize the key content of this page.")

    def _fail(msg: str) -> tuple[str, str]:
        """Show a fetch failure in red and return it as the tool result."""
        with self._print_lock:
            sys.stdout.write(f" {red(msg)}\n")
            sys.stdout.flush()
        return call_id, msg

    # Phase 1: fetch the URL
    try:
        req = Request(url, headers={"User-Agent": "chat.py/1.0"})
        with urlopen(req, timeout=self.tool_timeout) as resp:
            ct = resp.headers.get_content_type() or ""
            charset = resp.headers.get_content_charset() or "utf-8"
            raw = resp.read(10 * 1024 * 1024)  # 10 MB cap

        try:
            text = raw.decode(charset, errors="replace")
        except LookupError:
            # The server advertised a charset Python does not know. Decode
            # as UTF-8 (with replacement) instead of failing the fetch.
            text = raw.decode("utf-8", errors="replace")

        if "html" in ct:
            text = _strip_html(text)

    except (URLError, TimeoutError) as e:
        return _fail(f"Fetch failed: {e}")
    except ValueError as e:
        return _fail(f"Invalid URL: {e}")
    except Exception as e:
        return _fail(f"Error fetching URL: {e}")

    if not text.strip():
        return call_id, "(empty response from URL)"

    original_len = len(text)
    with self._print_lock:
        sys.stdout.write(
            f" {DIM}fetched {original_len} chars, extracting...{RESET}\n"
        )
        sys.stdout.flush()

    # Phase 2: truncate for summarization context (keep head and tail)
    max_content = 50_000
    if len(text) > max_content:
        text = (
            text[: max_content // 2]
            + f"\n\n... [{len(text) - max_content} chars omitted] ...\n\n"
            + text[-(max_content // 2) :]
        )

    # Phase 3: summarization API call
    try:
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    "role": "system",
                    "content": (
                        "You are a web content extraction assistant. "
                        "Answer the user's question using ONLY the "
                        "provided page content. Be concise and factual. "
                        "If the content doesn't contain the answer, say so."
                    ),
                },
                {
                    "role": "user",
                    "content": (
                        f"Page URL: {url}\n"
                        f"Page content ({original_len} chars):\n\n"
                        f"{text}\n\n---\n"
                        f"Question: {question}"
                    ),
                },
            ],
            max_completion_tokens=2000,
            temperature=0.2,
        )
        answer = response.choices[0].message.content or "(no answer)"
    except Exception as e:
        answer = (
            f"Extraction failed (page was fetched but summarization errored): {e}"
        )

    with self._print_lock:
        preview = answer[:300]
        if len(answer) > 300:
            preview += "..."
        indented = textwrap.indent(preview, " ")
        sys.stdout.write(f"{DIM}{indented}{RESET}\n")
        sys.stdout.flush()

    return call_id, answer
|
|
|
|
def _exec_web_search(self, item: dict) -> tuple[str, str]:
    """Query the Tavily search API and format the answer and hits as text.

    Returns ``(call_id, text)``; a request failure is reported in the text
    rather than raised.
    """
    call_id = item["call_id"]
    query = item["query"]
    max_results = item.get("max_results", 5)
    topic = item.get("topic", "general")
    api_key = _get_tavily_key()

    request_body = {
        "query": query,
        "max_results": max_results,
        "topic": topic,
        "include_answer": True,
    }
    request = Request(
        "https://api.tavily.com/search",
        data=json.dumps(request_body).encode(),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
        method="POST",
    )

    try:
        with urlopen(request, timeout=self.tool_timeout) as resp:
            data = json.loads(resp.read().decode())
    except Exception as exc:
        failure = f"Tavily search failed: {exc}"
        with self._print_lock:
            sys.stdout.write(f" {red(failure)}\n")
            sys.stdout.flush()
        return call_id, failure

    def _render(rank, hit):
        """Format one hit as a numbered markdown link plus a snippet."""
        title = hit.get("title", "")
        link = hit.get("url", "")
        snippet = (hit.get("content") or "")[:500]
        return f"{rank}. [{title}]({link})\n {snippet}"

    sections: list[str] = []
    answer = (data.get("answer") or "").strip()
    if answer:
        sections.append(f"Answer: {answer}")

    hits = data.get("results") or []
    if hits:
        sections.append(
            "\n".join(_render(rank, hit) for rank, hit in enumerate(hits, 1))
        )

    output = "\n\n".join(sections) if sections else f"No results for '{query}'."

    # Echo a short preview to the terminal.
    preview = output[:400]
    if len(output) > 400:
        preview += "..."
    with self._print_lock:
        sys.stdout.write(f"{DIM}{textwrap.indent(preview, ' ')}{RESET}\n")
        sys.stdout.flush()

    return call_id, output
|
|
|
|
def handle_command(self, cmd_line: str) -> bool:
    """Handle a slash command typed at the prompt.

    Args:
        cmd_line: The raw input line, e.g. ``"/reason high"``. The command
            word is matched case-insensitively; the rest is its argument.

    Returns:
        True when the session should exit, False otherwise (including for
        unknown commands and for input that contains no command at all).
    """
    parts = cmd_line.strip().split(None, 1)
    if not parts:
        # Nothing but whitespace: there is no command to dispatch.
        return False
    cmd = parts[0].lower()
    arg = parts[1] if len(parts) > 1 else ""

    if cmd in ("/exit", "/quit", "/q"):
        return True

    elif cmd == "/persona":
        if not arg:
            if self.persona:
                print(f"Current persona: {cyan(self.persona)}")
            else:
                print("No persona set. Usage: /persona <name>")
        else:
            self.persona = arg.strip()
            self._init_system_messages()
            print(f"Switched persona to {cyan(self.persona)}")

    elif cmd == "/instructions":
        if not arg:
            if self.instructions:
                # Only mark the text as cut when it really was cut.
                shown = self.instructions[:100]
                if len(self.instructions) > 100:
                    shown += "..."
                print(f"Current instructions: {shown}")
            else:
                print("No instructions set. Usage: /instructions <text>")
        else:
            self.instructions = arg.strip()
            self._init_system_messages()
            print("Instructions updated.")

    elif cmd in ("/clear", "/new"):
        self.messages.clear()
        self._read_files.clear()
        self.md = MarkdownRenderer()
        self._last_usage = None
        self._msg_tokens = []
        self._chars_per_token = 4.0
        # Best effort: clearing the in-memory state must succeed even when
        # the history database is unavailable. The connection is closed in
        # a finally block so a failing DELETE cannot leak it.
        try:
            db = _open_db()
            try:
                db.execute(
                    "DELETE FROM conversations WHERE session_id = ?",
                    (self._session_id,),
                )
                db.commit()
            finally:
                db.close()
        except Exception:
            pass
        print("History cleared.")

    elif cmd == "/history":
        query = arg.strip() if arg else None
        if query:
            rows = _search_history(query, limit=20)
            if rows:
                print(f"Found {len(rows)} result(s) for {query!r}:\n")
            else:
                print(f"No results for {query!r}")
        else:
            # Show recent conversations (last 20 messages)
            rows = _search_history_recent(limit=20)
            if rows:
                print("Recent history:\n")
            else:
                print("No conversation history yet.")
        for ts, sid, role, content, tool_name in rows or ():
            label = tool_name if tool_name else role
            text = (content or "")[:200]
            print(f" {dim(ts)} {dim(sid)} {bold(label)}: {text}")

    elif cmd == "/model":
        print(f"Model: {cyan(self.model)}")

    elif cmd == "/raw":
        self.show_reasoning = not self.show_reasoning
        state = "on" if self.show_reasoning else "off"
        print(f"Reasoning display: {bold(state)}")

    elif cmd == "/reason":
        valid = ("low", "medium", "high")
        aliases = {"med": "medium", "lo": "low", "hi": "high"}
        if not arg:
            print(f"Reasoning effort: {cyan(self.reasoning_effort)}")
        else:
            value = aliases.get(arg.lower(), arg.lower())
            if value in valid:
                self.reasoning_effort = value
                self._init_system_messages()
                print(f"Reasoning effort set to {cyan(self.reasoning_effort)}")
            else:
                print(f"Invalid. Choose from: {', '.join(valid)}")

    elif cmd == "/compact":
        self._compact_messages()

    elif cmd == "/creative":
        self.creative_mode = not self.creative_mode
        self._init_system_messages()
        # Clear history when toggling ON if it contains tool messages,
        # because the API rejects tool-call history without tool definitions
        if self.creative_mode and any(
            m.get("tool_calls") or m.get("role") == "tool" for m in self.messages
        ):
            self.messages.clear()
            self._read_files.clear()
            self._msg_tokens.clear()
            print(
                dim(
                    "[history cleared — creative mode is incompatible with tool history]"
                )
            )
        state = "on" if self.creative_mode else "off"
        print(
            f"Creative mode: {bold(state)} (tools {'disabled' if self.creative_mode else 'enabled'})"
        )

    elif cmd == "/debug":
        self.debug = not self.debug
        state = "on" if self.debug else "off"
        print(f"Debug mode: {bold(state)} (prints raw SSE deltas)")

    elif cmd == "/help":
        help_lines = [
            bold("Slash Commands:"),
            "/persona <name>         Set persona (system message)",
            "/instructions <text>    Set developer instructions",
            "/clear, /new            Clear conversation history",
            "/history [query]        Search conversation history (or show recent)",
            "/model                  Show current model",
            "/raw                    Toggle reasoning content display",
            "/reason [low|med|high]  Set/show reasoning effort",
            "/compact                Compact conversation (summarize old messages)",
            "/creative               Toggle creative writing mode (no tools)",
            "/debug                  Toggle raw SSE delta logging",
            "/help                   Show this help",
            "/exit                   Exit (also: Ctrl+D)",
        ]
        print("\n".join(help_lines))

    else:
        print(f"Unknown command: {cmd}. Type /help for available commands.")

    return False
|
|
|
|
|
|
# ─── Model auto-detection ─────────────────────────────────────────────────
|
|
|
|
|
|
def detect_model(client: OpenAI) -> str:
    """Pick a model id from the server's model listing.

    Returns the first model the server reports. When several are listed,
    all of them and the chosen one are printed. Exits the process with
    status 1 when the listing is empty or the server cannot be queried.
    """
    try:
        model_ids = [entry.id for entry in client.models.list().data]
        if not model_ids:
            print(red("No models found at server. Use --model to specify."))
            sys.exit(1)

        chosen = model_ids[0]
        if len(model_ids) > 1:
            # Multiple models — pick first, but inform user
            print(f"Available models: {', '.join(model_ids)}")
            print(f"Using: {bold(chosen)} (override with --model)")
        return chosen
    except Exception as exc:
        # sys.exit above raises SystemExit, which is not an Exception
        # subclass, so it passes through this handler untouched.
        print(red(f"Could not connect to server: {exc}"))
        print("Is vLLM running? Start it or use --base-url to point elsewhere.")
        sys.exit(1)
|
|
|
|
|
|
# ─── Main ──────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def main():
    """Parse CLI options, start a chat session and run the prompt loop."""
    parser = argparse.ArgumentParser(
        description="Interactive CLI for vLLM models with tool calling.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=textwrap.dedent(
            """\
            Examples:
            python3 chat.py                          # auto-detect model
            python3 chat.py --persona lawful_evil    # with persona
            python3 chat.py --model kappa_20b_131k   # explicit model
            python3 chat.py --temperature 0.7        # lower temperature
            """
        ),
    )

    # Every option as (flag, add_argument keyword arguments), registered
    # in this order.
    option_specs = (
        (
            "--base-url",
            {
                "default": "http://localhost:8000/v1",
                "help": "vLLM API base URL (default: http://localhost:8000/v1)",
            },
        ),
        (
            "--model",
            {
                "default": None,
                "help": "Model name (default: auto-detect from server)",
            },
        ),
        (
            "--persona",
            {
                "default": None,
                "help": "Persona name injected as system message",
            },
        ),
        (
            "--instructions",
            {
                "default": None,
                "help": "Developer instructions injected as developer message",
            },
        ),
        (
            "--temperature",
            {
                "type": float,
                "default": 0.5,
                "help": "Sampling temperature (default: 0.5)",
            },
        ),
        (
            "--max-tokens",
            {
                "type": int,
                "default": 32768,
                "help": "Max completion tokens (default: 32768)",
            },
        ),
        (
            "--tool-timeout",
            {
                "type": int,
                "default": 30,
                "help": "Bash command timeout in seconds (default: 30)",
            },
        ),
        (
            "--reasoning-effort",
            {
                "default": "medium",
                "choices": ["low", "medium", "high"],
                "help": "Reasoning effort level (default: medium)",
            },
        ),
        (
            "--context-window",
            {
                "type": int,
                "default": 131072,
                "help": "Context window size in tokens (default: 131072)",
            },
        ),
        (
            "--skip-permissions",
            {
                "action": "store_true",
                "help": "Auto-approve all tool calls (no confirmation prompts)",
            },
        ),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    # Set up readline
    setup_readline()

    # Create client
    client = OpenAI(
        base_url=args.base_url,
        api_key=os.environ.get("OPENAI_API_KEY", "dummy"),
    )

    # Use the model given on the command line, otherwise ask the server.
    model = args.model or detect_model(client)

    # Create session
    session = ChatSession(
        client=client,
        model=model,
        persona=args.persona,
        instructions=args.instructions,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        tool_timeout=args.tool_timeout,
        reasoning_effort=args.reasoning_effort,
        context_window=args.context_window,
    )
    if args.skip_permissions:
        session.auto_approve = True

    # Print banner
    print(f"\n{bold('Chat')} with {cyan(model)}")
    if args.persona:
        print(f"Persona: {cyan(args.persona)}")
    print("Type /help for commands, /exit or Ctrl+D to quit.\n")

    # Prompt string — short display name with any path prefix stripped.
    display_name = model.rsplit("/", 1)[-1]
    if len(display_name) > 30:
        display_name = display_name[:27] + "..."
    # The escape sequences are wrapped in \001 ... \002 markers.
    prompt = f"\001{BOLD}\002[{display_name}]\001{RESET}\002 > "

    # Main loop
    while True:
        try:
            user_input = input(prompt)
        except (EOFError, KeyboardInterrupt):
            print()
            break

        user_input = user_input.strip()
        if not user_input:
            continue

        if user_input.startswith("/"):
            if session.handle_command(user_input):
                break
            continue

        try:
            session.send(user_input)
        except KeyboardInterrupt:
            print(f"\n{yellow('Interrupted.')}")
        except Exception as exc:
            print(f"\n{red(f'Error: {exc}')}")

    print("Goodbye.")


if __name__ == "__main__":
    main()
|