"""
|
|
Minimal reproducible inference script for sweep-next-edit-v2-7B.
|
|
|
|
This model predicts the next edit a developer will make given:
|
|
- the current file contents
|
|
- recent changes (diffs)
|
|
- the cursor position
|
|
- (optional) retrieval chunks from other files
|
|
|
|
Usage:
|
|
python inference.py
|
|
|
|
Requires: transformers, torch, accelerate
|
|
pip install transformers torch accelerate
|
|
"""
|
|
|
|
import torch
|
|
from dataclasses import dataclass
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
|
MODEL_ID = "sweepai/sweep-next-edit-v2-7B"
|
|
|
|
# --- Prompt template (from sweepai/autocomplete/next_edit_autocomplete.py) ---
|
|
PROMPT_TEMPLATE = """<|file_sep|>{file_path}
|
|
{initial_file}{retrieval_results}
|
|
{recent_changes}
|
|
<|file_sep|>original/{file_path}:{start_line}:{end_line}
|
|
{prev_section}
|
|
<|file_sep|>current/{file_path}:{start_line}:{end_line}
|
|
{code_block}
|
|
<|file_sep|>updated/{file_path}:{start_line}:{end_line}
|
|
{prefill}"""
|
|
|
|
DIFF_FORMAT = """<|file_sep|>{file_path}:{start_line}:{end_line}
|
|
original:
|
|
{old_code}
|
|
updated:
|
|
{new_code}"""
|
|
|
|
STOP_TOKENS = ["<|endoftext|>", "<|file_sep|>"]
|
|
MAX_NEW_TOKENS = 1024
|
|
|
|
|
|
@dataclass
|
|
class FileChunk:
|
|
"""A chunk of code from another file, used for cross-file context (retrieval)."""
|
|
file_path: str
|
|
content: str
|
|
|
|
def to_string(self) -> str:
|
|
return f"<|file_sep|>{self.file_path}\n{self.content}\n"
|
|
|
|
|
|
def compute_prefill(
|
|
code_block: str,
|
|
relative_cursor: int,
|
|
changes_above_cursor: bool = False,
|
|
) -> str:
|
|
"""
|
|
Compute the prefill string — the portion of the updated code block that we
|
|
feed to the model so it only has to generate starting from the edit point.
|
|
|
|
The model's job is to produce the full "updated" code block. But most of it
|
|
is unchanged — only a small region near the cursor is different. So we
|
|
"prefill" the output with the unchanged prefix, and the model just continues
|
|
from there.
|
|
|
|
Two strategies depending on what the user just did:
|
|
|
|
changes_above_cursor=True (last action was an insertion):
|
|
The user just inserted text above the cursor. The lines above the cursor
|
|
may have shifted, so we can't trust them as a prefill — the model might
|
|
need to edit them. We only prefill the very first line of the code block
|
|
(plus any blank lines after it), giving the model freedom to rewrite
|
|
everything from line 2 onward.
|
|
|
|
Example: code_block is 11 lines, cursor on line 10.
|
|
Prefill = line 1 + any trailing blank lines = " if n <= 0:\n"
|
|
Model generates lines 2-11.
|
|
|
|
changes_above_cursor=False (last action was NOT an insertion):
|
|
The user did something else (navigation, deletion, etc). The lines above
|
|
the cursor are likely stable, so we prefill up to the cursor line. This
|
|
constrains the model to only edit at/below the cursor.
|
|
|
|
We prefill everything before the cursor's line (up to the last newline
|
|
before cursor position), so the model starts generating from the cursor
|
|
line itself.
|
|
|
|
Example: code_block is 11 lines, cursor on line 10 col 0.
|
|
Prefill = lines 1-9 (everything up to the last \\n before cursor).
|
|
Model generates lines 10-11.
|
|
"""
|
|
if changes_above_cursor:
|
|
# --- Insertion mode: only prefill first line + trailing newlines ---
|
|
prefill = code_block[:relative_cursor]
|
|
prefilled_lines = prefill.splitlines(True)
|
|
|
|
NUM_LINES_ABOVE = 1
|
|
before_split = "".join(prefilled_lines[:NUM_LINES_ABOVE])
|
|
after_split = "".join(prefilled_lines[NUM_LINES_ABOVE:])
|
|
|
|
# Append consecutive newlines (blank lines) but stop at first real char.
|
|
# This preserves blank-line structure without constraining the model
|
|
# to keep the original code on those lines.
|
|
for char in after_split:
|
|
if char == "\n":
|
|
before_split += "\n"
|
|
else:
|
|
break
|
|
|
|
return before_split
|
|
else:
|
|
# --- Default mode: prefill up to the cursor line ---
|
|
prefix_before_cursor = code_block[:relative_cursor]
|
|
if "\n" not in prefix_before_cursor:
|
|
# Cursor is on the first line — no prefill possible
|
|
return ""
|
|
prefill_end = prefix_before_cursor.rfind("\n") + 1
|
|
return code_block[:prefill_end]
|
|
|
|
|
|
def is_pure_insertion_above_cursor(
|
|
code_block: str, completion: str, relative_cursor: int
|
|
) -> bool:
|
|
"""
|
|
Reject completions that only insert new lines above the cursor without
|
|
actually editing the cursor line. These are low-value predictions —
|
|
the model is just guessing what new code to add rather than fixing
|
|
an existing reference.
|
|
"""
|
|
current_line_index = len(code_block[:relative_cursor].splitlines(True))
|
|
code_block_lines = code_block.splitlines(True)
|
|
cursor_line = code_block_lines[current_line_index - 1]
|
|
|
|
if code_block.strip() == completion.strip():
|
|
return False
|
|
if not cursor_line.strip():
|
|
return False
|
|
|
|
prefix_lines = code_block_lines[:current_line_index - 1]
|
|
prefix = "".join(prefix_lines)
|
|
suffix_lines = code_block_lines[current_line_index:]
|
|
suffix = "".join(suffix_lines)
|
|
|
|
# If completion = prefix + NEW STUFF + cursor_line + suffix, it's a pure
|
|
# insertion above cursor (nothing at/below cursor changed).
|
|
if completion.startswith(prefix) and completion.endswith(cursor_line + suffix):
|
|
return True
|
|
|
|
return False
|
|
|
|
|
|
def build_prompt(
|
|
file_path: str,
|
|
file_contents: str,
|
|
cursor_position: int,
|
|
recent_changes: str = "",
|
|
retrieval_chunks: list[FileChunk] | None = None,
|
|
file_chunks: list[FileChunk] | None = None,
|
|
changes_above_cursor: bool = False,
|
|
num_lines_before: int = 10,
|
|
num_lines_after: int = 10,
|
|
) -> tuple[str, str, int, int]:
|
|
"""
|
|
Build the model prompt from file contents and cursor position.
|
|
|
|
Args:
|
|
file_path: Path of the file being edited.
|
|
file_contents: Full contents of the file after the user's latest edit.
|
|
cursor_position: Character offset of the cursor in file_contents.
|
|
recent_changes: Formatted diff string of recent changes (use DIFF_FORMAT).
|
|
retrieval_chunks: Cross-file context chunks (e.g. related functions from
|
|
other files). Placed AFTER recent_changes in the prompt for optimal
|
|
KV cache reuse.
|
|
file_chunks: Additional file context chunks. Prepended to the prompt.
|
|
changes_above_cursor: Whether the user's last action was an insertion.
|
|
Controls the prefill strategy (see compute_prefill).
|
|
num_lines_before: Lines of code to include before cursor in the block.
|
|
num_lines_after: Lines of code to include after cursor in the block.
|
|
|
|
Returns:
|
|
(formatted_prompt, code_block, block_start_index, relative_cursor)
|
|
"""
|
|
lines = file_contents.splitlines(True)
|
|
|
|
# Find cursor line
|
|
pos = 0
|
|
cursor_line = 0
|
|
for i, line in enumerate(lines):
|
|
if pos + len(line) > cursor_position:
|
|
cursor_line = i
|
|
break
|
|
pos += len(line)
|
|
else:
|
|
cursor_line = len(lines) - 1
|
|
|
|
# Extract code block around cursor
|
|
block_start = max(0, cursor_line - num_lines_before)
|
|
block_end = min(len(lines), cursor_line + num_lines_after + 1)
|
|
code_block = "".join(lines[block_start:block_end])
|
|
block_start_index = sum(len(l) for l in lines[:block_start])
|
|
|
|
# Relative cursor position within code block
|
|
relative_cursor = cursor_position - block_start_index
|
|
|
|
# Insert <|cursor|> marker into the "current" version
|
|
code_block_with_cursor = (
|
|
code_block[:relative_cursor]
|
|
+ "<|cursor|>"
|
|
+ code_block[relative_cursor:]
|
|
)
|
|
|
|
# prev_section = code_block without cursor (the "original" version)
|
|
prev_section = code_block
|
|
|
|
# Compute prefill based on whether last action was an insertion
|
|
prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor)
|
|
|
|
# initial_file: broad context around cursor from the file (up to ~300 lines)
|
|
context_start = max(0, cursor_line - 150)
|
|
context_end = min(len(lines), cursor_line + 150)
|
|
initial_file = "".join(lines[context_start:context_end])
|
|
|
|
# Format retrieval results (cross-file context)
|
|
retrieval_results = ""
|
|
if retrieval_chunks:
|
|
retrieval_results = "".join(
|
|
f"\n{chunk.to_string()}" for chunk in retrieval_chunks
|
|
)
|
|
|
|
start_line = block_start + 1
|
|
end_line = block_end
|
|
|
|
formatted = PROMPT_TEMPLATE.format(
|
|
file_path=file_path,
|
|
initial_file=initial_file,
|
|
retrieval_results=retrieval_results,
|
|
recent_changes=recent_changes,
|
|
prev_section=prev_section,
|
|
code_block=code_block_with_cursor,
|
|
start_line=start_line,
|
|
end_line=end_line,
|
|
prefill=prefill,
|
|
)
|
|
|
|
# Prepend file chunks (other open files for context)
|
|
if file_chunks:
|
|
formatted = "".join(c.to_string() for c in file_chunks) + formatted
|
|
|
|
return formatted, code_block, block_start_index, relative_cursor
|
|
|
|
|
|
def generate(model, tokenizer, prompt: str, device: str = "cuda") -> str:
|
|
"""Run inference and return the completion (the predicted updated code block)."""
|
|
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
|
|
|
stop_token_ids = [
|
|
tokenizer.convert_tokens_to_ids(t)
|
|
for t in STOP_TOKENS
|
|
if t in tokenizer.get_vocab()
|
|
]
|
|
eos_ids = list(set(stop_token_ids + [tokenizer.eos_token_id]))
|
|
|
|
with torch.no_grad():
|
|
outputs = model.generate(
|
|
**inputs,
|
|
max_new_tokens=MAX_NEW_TOKENS,
|
|
do_sample=False, # greedy (temperature=0)
|
|
eos_token_id=eos_ids,
|
|
pad_token_id=tokenizer.eos_token_id,
|
|
)
|
|
|
|
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
|
|
completion = tokenizer.decode(new_tokens, skip_special_tokens=False)
|
|
|
|
# Strip stop tokens from output
|
|
for stop in STOP_TOKENS:
|
|
if stop in completion:
|
|
completion = completion[: completion.index(stop)]
|
|
|
|
return completion
|
|
|
|
|
|
def main():
|
|
# --- Example: predict the next edit ---
|
|
file_path = "example.py"
|
|
file_contents = """\
|
|
def fibonacci(n):
|
|
if n <= 0:
|
|
return 0
|
|
elif n == 1:
|
|
return 1
|
|
else:
|
|
return fibonacci(n - 1) + fibonacci(n - 2)
|
|
|
|
|
|
def main():
|
|
for i in range(10):
|
|
print(fibonacci(i))
|
|
"""
|
|
|
|
# Simulate: user just renamed fibonacci -> fib on line 7,
|
|
# cursor is now on line 12 (the call site that still says fibonacci).
|
|
edited_contents = file_contents.replace(
|
|
"return fibonacci(n - 1) + fibonacci(n - 2)",
|
|
"return fib(n - 1) + fib(n - 2)",
|
|
).replace(
|
|
"def fibonacci(n):",
|
|
"def fib(n):",
|
|
)
|
|
|
|
# Cursor is on the print line that still references "fibonacci"
|
|
cursor_line_text = " print(fibonacci(i))"
|
|
cursor_position = edited_contents.index(cursor_line_text)
|
|
|
|
# Recent change as a diff
|
|
recent_changes = DIFF_FORMAT.format(
|
|
file_path=file_path,
|
|
start_line=1,
|
|
end_line=7,
|
|
old_code="def fibonacci(n):\n return fibonacci(n - 1) + fibonacci(n - 2)",
|
|
new_code="def fib(n):\n return fib(n - 1) + fib(n - 2)",
|
|
)
|
|
|
|
# Example retrieval chunk: a related function from another file
|
|
retrieval_chunks = [
|
|
FileChunk(
|
|
file_path="utils.py",
|
|
content="def fib_memo(n, memo={}):\n if n in memo:\n return memo[n]\n memo[n] = fib_memo(n-1) + fib_memo(n-2)\n return memo[n]",
|
|
)
|
|
]
|
|
|
|
# The rename was NOT an insertion, so changes_above_cursor=False.
|
|
# This means the prefill will include everything up to the cursor line,
|
|
# constraining the model to only edit at/below the cursor.
|
|
prompt, code_block, block_start, relative_cursor = build_prompt(
|
|
file_path=file_path,
|
|
file_contents=edited_contents,
|
|
cursor_position=cursor_position,
|
|
recent_changes=recent_changes,
|
|
retrieval_chunks=retrieval_chunks,
|
|
changes_above_cursor=False,
|
|
)
|
|
|
|
print("=" * 60)
|
|
print("PROMPT")
|
|
print("=" * 60)
|
|
print(prompt)
|
|
print()
|
|
|
|
# --- Load model and run inference ---
|
|
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
|
print(f"Loading model {MODEL_ID} on {device}...")
|
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
MODEL_ID,
|
|
dtype=torch.bfloat16,
|
|
device_map=device,
|
|
trust_remote_code=True,
|
|
)
|
|
|
|
print("Running inference...")
|
|
completion = generate(model, tokenizer, prompt, device=device)
|
|
|
|
# Check for pure insertion above cursor (low-value prediction)
|
|
if is_pure_insertion_above_cursor(code_block, completion, relative_cursor):
|
|
print("Rejected: model only inserted above cursor without editing cursor line.")
|
|
return
|
|
|
|
print("=" * 60)
|
|
print("MODEL OUTPUT (predicted updated code block)")
|
|
print("=" * 60)
|
|
print(completion)
|
|
print()
|
|
|
|
# Show the diff
|
|
print("=" * 60)
|
|
print("DIFF")
|
|
print("=" * 60)
|
|
print(f"Original code block:\n{code_block}")
|
|
print(f"Updated code block:\n{completion}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|