""" Minimal reproducible inference script for sweep-next-edit-v2-7B. This model predicts the next edit a developer will make given: - the current file contents - recent changes (diffs) - the cursor position - (optional) retrieval chunks from other files Usage: python inference.py Requires: transformers, torch, accelerate pip install transformers torch accelerate """ import torch from dataclasses import dataclass from transformers import AutoModelForCausalLM, AutoTokenizer MODEL_ID = "sweepai/sweep-next-edit-v2-7B" # --- Prompt template (from sweepai/autocomplete/next_edit_autocomplete.py) --- PROMPT_TEMPLATE = """<|file_sep|>{file_path} {initial_file}{retrieval_results} {recent_changes} <|file_sep|>original/{file_path}:{start_line}:{end_line} {prev_section} <|file_sep|>current/{file_path}:{start_line}:{end_line} {code_block} <|file_sep|>updated/{file_path}:{start_line}:{end_line} {prefill}""" DIFF_FORMAT = """<|file_sep|>{file_path}:{start_line}:{end_line} original: {old_code} updated: {new_code}""" STOP_TOKENS = ["<|endoftext|>", "<|file_sep|>"] MAX_NEW_TOKENS = 1024 @dataclass class FileChunk: """A chunk of code from another file, used for cross-file context (retrieval).""" file_path: str content: str def to_string(self) -> str: return f"<|file_sep|>{self.file_path}\n{self.content}\n" def compute_prefill( code_block: str, relative_cursor: int, changes_above_cursor: bool = False, ) -> str: """ Compute the prefill string — the portion of the updated code block that we feed to the model so it only has to generate starting from the edit point. The model's job is to produce the full "updated" code block. But most of it is unchanged — only a small region near the cursor is different. So we "prefill" the output with the unchanged prefix, and the model just continues from there. Two strategies depending on what the user just did: changes_above_cursor=True (last action was an insertion): The user just inserted text above the cursor. The lines above the cursor may have shifted, so we can't trust them as a prefill — the model might need to edit them. We only prefill the very first line of the code block (plus any blank lines after it), giving the model freedom to rewrite everything from line 2 onward. Example: code_block is 11 lines, cursor on line 10. Prefill = line 1 + any trailing blank lines = " if n <= 0:\n" Model generates lines 2-11. changes_above_cursor=False (last action was NOT an insertion): The user did something else (navigation, deletion, etc). The lines above the cursor are likely stable, so we prefill up to the cursor line. This constrains the model to only edit at/below the cursor. We prefill everything before the cursor's line (up to the last newline before cursor position), so the model starts generating from the cursor line itself. Example: code_block is 11 lines, cursor on line 10 col 0. Prefill = lines 1-9 (everything up to the last \\n before cursor). Model generates lines 10-11. """ if changes_above_cursor: # --- Insertion mode: only prefill first line + trailing newlines --- prefill = code_block[:relative_cursor] prefilled_lines = prefill.splitlines(True) NUM_LINES_ABOVE = 1 before_split = "".join(prefilled_lines[:NUM_LINES_ABOVE]) after_split = "".join(prefilled_lines[NUM_LINES_ABOVE:]) # Append consecutive newlines (blank lines) but stop at first real char. # This preserves blank-line structure without constraining the model # to keep the original code on those lines. 
        for char in after_split:
            if char == "\n":
                before_split += "\n"
            else:
                break
        return before_split
    else:
        # --- Default mode: prefill up to the cursor line ---
        prefix_before_cursor = code_block[:relative_cursor]
        if "\n" not in prefix_before_cursor:
            # Cursor is on the first line, so no prefill is possible.
            return ""
        prefill_end = prefix_before_cursor.rfind("\n") + 1
        return code_block[:prefill_end]


def is_pure_insertion_above_cursor(
    code_block: str, completion: str, relative_cursor: int
) -> bool:
    """
    Reject completions that only insert new lines above the cursor without
    actually editing the cursor line. These are low-value predictions: the
    model is just guessing what new code to add rather than fixing an
    existing reference.
    """
    current_line_index = len(code_block[:relative_cursor].splitlines(True))
    code_block_lines = code_block.splitlines(True)
    cursor_line = code_block_lines[current_line_index - 1]

    if code_block.strip() == completion.strip():
        return False
    if not cursor_line.strip():
        return False

    prefix_lines = code_block_lines[:current_line_index - 1]
    prefix = "".join(prefix_lines)
    suffix_lines = code_block_lines[current_line_index:]
    suffix = "".join(suffix_lines)

    # If completion = prefix + NEW STUFF + cursor_line + suffix, it's a pure
    # insertion above the cursor (nothing at/below the cursor changed).
    if completion.startswith(prefix) and completion.endswith(cursor_line + suffix):
        return True
    return False


def build_prompt(
    file_path: str,
    file_contents: str,
    cursor_position: int,
    recent_changes: str = "",
    retrieval_chunks: list[FileChunk] | None = None,
    file_chunks: list[FileChunk] | None = None,
    changes_above_cursor: bool = False,
    num_lines_before: int = 10,
    num_lines_after: int = 10,
) -> tuple[str, str, int, int]:
    """
    Build the model prompt from file contents and cursor position.

    Args:
        file_path: Path of the file being edited.
        file_contents: Full contents of the file after the user's latest edit.
        cursor_position: Character offset of the cursor in file_contents.
        recent_changes: Formatted diff string of recent changes (use DIFF_FORMAT).
        retrieval_chunks: Cross-file context chunks (e.g. related functions from
            other files). Placed after the initial file contents (before
            recent_changes) in the prompt for optimal KV cache reuse.
        file_chunks: Additional file context chunks. Prepended to the prompt.
        changes_above_cursor: Whether the user's last action was an insertion.
            Controls the prefill strategy (see compute_prefill).
        num_lines_before: Lines of code to include before the cursor in the block.
        num_lines_after: Lines of code to include after the cursor in the block.

    Returns:
        (formatted_prompt, code_block, block_start_index, relative_cursor)
    """
    lines = file_contents.splitlines(True)

    # Find the cursor line
    pos = 0
    cursor_line = 0
    for i, line in enumerate(lines):
        if pos + len(line) > cursor_position:
            cursor_line = i
            break
        pos += len(line)
    else:
        cursor_line = len(lines) - 1

    # Extract the code block around the cursor
    block_start = max(0, cursor_line - num_lines_before)
    block_end = min(len(lines), cursor_line + num_lines_after + 1)
    code_block = "".join(lines[block_start:block_end])
    block_start_index = sum(len(l) for l in lines[:block_start])

    # Relative cursor position within the code block
    relative_cursor = cursor_position - block_start_index

    # Insert the <|cursor|> marker into the "current" version
    code_block_with_cursor = (
        code_block[:relative_cursor] + "<|cursor|>" + code_block[relative_cursor:]
    )

    # prev_section = code_block without the cursor marker (the "original" version)
    prev_section = code_block

    # Compute the prefill based on whether the last action was an insertion
    prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor)

    # initial_file: broad context around the cursor from the file (up to ~300 lines)
    context_start = max(0, cursor_line - 150)
    context_end = min(len(lines), cursor_line + 150)
    initial_file = "".join(lines[context_start:context_end])

    # Format retrieval results (cross-file context)
    retrieval_results = ""
    if retrieval_chunks:
        retrieval_results = "".join(
            f"\n{chunk.to_string()}" for chunk in retrieval_chunks
        )

    start_line = block_start + 1
    end_line = block_end

    formatted = PROMPT_TEMPLATE.format(
        file_path=file_path,
        initial_file=initial_file,
        retrieval_results=retrieval_results,
        recent_changes=recent_changes,
        prev_section=prev_section,
        code_block=code_block_with_cursor,
        start_line=start_line,
        end_line=end_line,
        prefill=prefill,
    )

    # Prepend file chunks (other open files for context)
    if file_chunks:
        formatted = "".join(c.to_string() for c in file_chunks) + formatted

    return formatted, code_block, block_start_index, relative_cursor


def generate(model, tokenizer, prompt: str, device: str = "cuda") -> str:
    """Run inference and return the completion (the predicted updated code block)."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(t)
        for t in STOP_TOKENS
        if t in tokenizer.get_vocab()
    ]
    eos_ids = list(set(stop_token_ids + [tokenizer.eos_token_id]))

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy (temperature=0)
            eos_token_id=eos_ids,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Strip stop tokens from the output
    for stop in STOP_TOKENS:
        if stop in completion:
            completion = completion[: completion.index(stop)]
    return completion


def main():
    # --- Example: predict the next edit ---
    file_path = "example.py"
    file_contents = """\
def fibonacci(n):
    if n <= 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)


def main():
    for i in range(10):
        print(fibonacci(i))
"""

    # Simulate: the user just renamed fibonacci -> fib (lines 1 and 7);
    # the cursor is now on line 12 (the call site that still says fibonacci).
    edited_contents = file_contents.replace(
        "return fibonacci(n - 1) + fibonacci(n - 2)",
        "return fib(n - 1) + fib(n - 2)",
    ).replace(
        "def fibonacci(n):",
        "def fib(n):",
    )

    # Cursor is at the start of the print line that still references "fibonacci"
    cursor_line_text = "        print(fibonacci(i))"
    cursor_position = edited_contents.index(cursor_line_text)

    # Recent change as a diff
    recent_changes = DIFF_FORMAT.format(
        file_path=file_path,
        start_line=1,
        end_line=7,
        old_code="def fibonacci(n):\n    return fibonacci(n - 1) + fibonacci(n - 2)",
        new_code="def fib(n):\n    return fib(n - 1) + fib(n - 2)",
    )

    # Example retrieval chunk: a related function from another file
    retrieval_chunks = [
        FileChunk(
            file_path="utils.py",
            content=(
                "def fib_memo(n, memo={}):\n"
                "    if n in memo:\n"
                "        return memo[n]\n"
                "    memo[n] = fib_memo(n-1) + fib_memo(n-2)\n"
                "    return memo[n]"
            ),
        )
    ]

    # The rename was NOT an insertion, so changes_above_cursor=False.
    # This means the prefill will include everything up to the cursor line,
    # constraining the model to only edit at/below the cursor.
    prompt, code_block, block_start, relative_cursor = build_prompt(
        file_path=file_path,
        file_contents=edited_contents,
        cursor_position=cursor_position,
        recent_changes=recent_changes,
        retrieval_chunks=retrieval_chunks,
        changes_above_cursor=False,
    )

    print("=" * 60)
    print("PROMPT")
    print("=" * 60)
    print(prompt)
    print()

    # --- Load model and run inference ---
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available() else "cpu"
    )
    print(f"Loading model {MODEL_ID} on {device}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )

    print("Running inference...")
    completion = generate(model, tokenizer, prompt, device=device)

    # Check for a pure insertion above the cursor (low-value prediction)
    if is_pure_insertion_above_cursor(code_block, completion, relative_cursor):
        print("Rejected: model only inserted above cursor without editing cursor line.")
        return

    print("=" * 60)
    print("MODEL OUTPUT (predicted updated code block)")
    print("=" * 60)
    print(completion)
    print()

    # Show the before/after code blocks
    print("=" * 60)
    print("DIFF")
    print("=" * 60)
    print(f"Original code block:\n{code_block}")
    print(f"Updated code block:\n{completion}")


if __name__ == "__main__":
    main()
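
# --- Applying a prediction (illustrative sketch) ---
# The script above only prints the predicted block. One way to turn it into a
# full updated file is to splice the completion over the original code block
# using the offsets returned by build_prompt. `apply_completion` is a
# hypothetical helper sketched here for illustration; it is not called above.
def apply_completion(
    file_contents: str,
    code_block: str,
    completion: str,
    block_start_index: int,
) -> str:
    """Return file_contents with the original code block replaced by the completion."""
    block_end_index = block_start_index + len(code_block)
    return (
        file_contents[:block_start_index] + completion + file_contents[block_end_index:]
    )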