Initialize project; model provided by the ModelHub XC community
Model: sweepai/sweep-next-edit-v2-7B Source: Original Platform
inference.py · 386 lines added · new file
@@ -0,0 +1,386 @@
"""
Minimal reproducible inference script for sweep-next-edit-v2-7B.

This model predicts the next edit a developer will make given:
- the current file contents
- recent changes (diffs)
- the cursor position
- (optional) retrieval chunks from other files

Usage:
    python inference.py

Requires: transformers, torch, accelerate
    pip install transformers torch accelerate
"""

from dataclasses import dataclass

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "sweepai/sweep-next-edit-v2-7B"

# --- Prompt template (from sweepai/autocomplete/next_edit_autocomplete.py) ---
PROMPT_TEMPLATE = """<|file_sep|>{file_path}
{initial_file}{retrieval_results}
{recent_changes}
<|file_sep|>original/{file_path}:{start_line}:{end_line}
{prev_section}
<|file_sep|>current/{file_path}:{start_line}:{end_line}
{code_block}
<|file_sep|>updated/{file_path}:{start_line}:{end_line}
{prefill}"""

DIFF_FORMAT = """<|file_sep|>{file_path}:{start_line}:{end_line}
original:
{old_code}
updated:
{new_code}"""
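
# For illustration (hypothetical edit, not taken from the upstream repo): renaming
# `foo` to `bar` on lines 3-4 of app.py renders with DIFF_FORMAT as:
#
#   <|file_sep|>app.py:3:4
#   original:
#   def foo():
#       return total + 1
#   updated:
#   def bar():
#       return total + 1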

STOP_TOKENS = ["<|endoftext|>", "<|file_sep|>"]
MAX_NEW_TOKENS = 1024


@dataclass
class FileChunk:
    """A chunk of code from another file, used for cross-file context (retrieval)."""
    file_path: str
    content: str

    def to_string(self) -> str:
        return f"<|file_sep|>{self.file_path}\n{self.content}\n"
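
    # For illustration (hypothetical chunk): FileChunk("utils.py", "def helper(): ...")
    # serializes via to_string() as:
    #
    #   <|file_sep|>utils.py
    #   def helper(): ...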


def compute_prefill(
    code_block: str,
    relative_cursor: int,
    changes_above_cursor: bool = False,
) -> str:
    """
    Compute the prefill string — the portion of the updated code block that we
    feed to the model so it only has to generate starting from the edit point.

    The model's job is to produce the full "updated" code block. But most of it
    is unchanged — only a small region near the cursor is different. So we
    "prefill" the output with the unchanged prefix, and the model just continues
    from there.

    Two strategies depending on what the user just did:

    changes_above_cursor=True (last action was an insertion):
        The user just inserted text above the cursor. The lines above the cursor
        may have shifted, so we can't trust them as a prefill — the model might
        need to edit them. We only prefill the very first line of the code block
        (plus any blank lines after it), giving the model freedom to rewrite
        everything from line 2 onward.

        Example: code_block is 11 lines, cursor on line 10.
        Prefill = line 1 + any trailing blank lines = "    if n <= 0:\n"
        Model generates lines 2-11.

    changes_above_cursor=False (last action was NOT an insertion):
        The user did something else (navigation, deletion, etc). The lines above
        the cursor are likely stable, so we prefill up to the cursor line. This
        constrains the model to only edit at/below the cursor.

        We prefill everything before the cursor's line (up to the last newline
        before cursor position), so the model starts generating from the cursor
        line itself.

        Example: code_block is 11 lines, cursor on line 10 col 0.
        Prefill = lines 1-9 (everything up to the last \\n before cursor).
        Model generates lines 10-11.
    """
    if changes_above_cursor:
        # --- Insertion mode: only prefill first line + trailing newlines ---
        prefill = code_block[:relative_cursor]
        prefilled_lines = prefill.splitlines(True)

        NUM_LINES_ABOVE = 1
        before_split = "".join(prefilled_lines[:NUM_LINES_ABOVE])
        after_split = "".join(prefilled_lines[NUM_LINES_ABOVE:])

        # Append consecutive newlines (blank lines) but stop at first real char.
        # This preserves blank-line structure without constraining the model
        # to keep the original code on those lines.
        for char in after_split:
            if char == "\n":
                before_split += "\n"
            else:
                break

        return before_split
    else:
        # --- Default mode: prefill up to the cursor line ---
        prefix_before_cursor = code_block[:relative_cursor]
        if "\n" not in prefix_before_cursor:
            # Cursor is on the first line — no prefill possible
            return ""
        prefill_end = prefix_before_cursor.rfind("\n") + 1
        return code_block[:prefill_end]
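
# A tiny worked example of the two strategies (hypothetical 3-line block, not part
# of the original script):
#
#   block = "a = 1\nb = 2\nc = 3\n"
#   cursor = block.index("c = 3")      # cursor sits at the start of line 3
#   compute_prefill(block, cursor)                              # -> "a = 1\nb = 2\n"
#   compute_prefill(block, cursor, changes_above_cursor=True)   # -> "a = 1\n"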


def is_pure_insertion_above_cursor(
    code_block: str, completion: str, relative_cursor: int
) -> bool:
    """
    Reject completions that only insert new lines above the cursor without
    actually editing the cursor line. These are low-value predictions —
    the model is just guessing what new code to add rather than fixing
    an existing reference.
    """
    current_line_index = len(code_block[:relative_cursor].splitlines(True))
    code_block_lines = code_block.splitlines(True)
    cursor_line = code_block_lines[current_line_index - 1]

    if code_block.strip() == completion.strip():
        return False
    if not cursor_line.strip():
        return False

    prefix_lines = code_block_lines[:current_line_index - 1]
    prefix = "".join(prefix_lines)
    suffix_lines = code_block_lines[current_line_index:]
    suffix = "".join(suffix_lines)

    # If completion = prefix + NEW STUFF + cursor_line + suffix, it's a pure
    # insertion above cursor (nothing at/below cursor changed).
    if completion.startswith(prefix) and completion.endswith(cursor_line + suffix):
        return True

    return False
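
# Example behavior (hypothetical inputs, for illustration only):
#
#   block = "def fib(n):\n    return fibonacci(n - 1)\n"
#   cursor = block.index("fibonacci")
#   # Completion only inserts a comment above the cursor line -> True (rejected):
#   is_pure_insertion_above_cursor(
#       block, "def fib(n):\n    # memoize\n    return fibonacci(n - 1)\n", cursor)
#   # Completion edits the cursor line itself -> False (kept):
#   is_pure_insertion_above_cursor(
#       block, "def fib(n):\n    return fib(n - 1)\n", cursor)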


def build_prompt(
    file_path: str,
    file_contents: str,
    cursor_position: int,
    recent_changes: str = "",
    retrieval_chunks: list[FileChunk] | None = None,
    file_chunks: list[FileChunk] | None = None,
    changes_above_cursor: bool = False,
    num_lines_before: int = 10,
    num_lines_after: int = 10,
) -> tuple[str, str, int, int]:
    """
    Build the model prompt from file contents and cursor position.

    Args:
        file_path: Path of the file being edited.
        file_contents: Full contents of the file after the user's latest edit.
        cursor_position: Character offset of the cursor in file_contents.
        recent_changes: Formatted diff string of recent changes (use DIFF_FORMAT).
        retrieval_chunks: Cross-file context chunks (e.g. related functions from
            other files). Placed AFTER recent_changes in the prompt for optimal
            KV cache reuse.
        file_chunks: Additional file context chunks. Prepended to the prompt.
        changes_above_cursor: Whether the user's last action was an insertion.
            Controls the prefill strategy (see compute_prefill).
        num_lines_before: Lines of code to include before cursor in the block.
        num_lines_after: Lines of code to include after cursor in the block.

    Returns:
        (formatted_prompt, code_block, block_start_index, relative_cursor)
    """
    lines = file_contents.splitlines(True)

    # Find cursor line
    pos = 0
    cursor_line = 0
    for i, line in enumerate(lines):
        if pos + len(line) > cursor_position:
            cursor_line = i
            break
        pos += len(line)
    else:
        cursor_line = len(lines) - 1

    # Extract code block around cursor
    block_start = max(0, cursor_line - num_lines_before)
    block_end = min(len(lines), cursor_line + num_lines_after + 1)
    code_block = "".join(lines[block_start:block_end])
    block_start_index = sum(len(l) for l in lines[:block_start])

    # Relative cursor position within code block
    relative_cursor = cursor_position - block_start_index

    # Insert <|cursor|> marker into the "current" version
    code_block_with_cursor = (
        code_block[:relative_cursor]
        + "<|cursor|>"
        + code_block[relative_cursor:]
    )

    # prev_section = code_block without cursor (the "original" version)
    prev_section = code_block

    # Compute prefill based on whether last action was an insertion
    prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor)

    # initial_file: broad context around cursor from the file (up to ~300 lines)
    context_start = max(0, cursor_line - 150)
    context_end = min(len(lines), cursor_line + 150)
    initial_file = "".join(lines[context_start:context_end])

    # Format retrieval results (cross-file context)
    retrieval_results = ""
    if retrieval_chunks:
        retrieval_results = "".join(
            f"\n{chunk.to_string()}" for chunk in retrieval_chunks
        )

    start_line = block_start + 1
    end_line = block_end

    formatted = PROMPT_TEMPLATE.format(
        file_path=file_path,
        initial_file=initial_file,
        retrieval_results=retrieval_results,
        recent_changes=recent_changes,
        prev_section=prev_section,
        code_block=code_block_with_cursor,
        start_line=start_line,
        end_line=end_line,
        prefill=prefill,
    )

    # Prepend file chunks (other open files for context)
    if file_chunks:
        formatted = "".join(c.to_string() for c in file_chunks) + formatted

    return formatted, code_block, block_start_index, relative_cursor
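
# Rough shape of the assembled prompt (sections in the order PROMPT_TEMPLATE emits
# them; <|cursor|> only appears in the "current" block):
#
#   <|file_sep|>{file_path}                          <- broad context (initial_file)
#   ...retrieval chunks (optional cross-file context)...
#   ...recent_changes diffs...
#   <|file_sep|>original/{file_path}:{start}:{end}   <- the "original" section (prev_section)
#   <|file_sep|>current/{file_path}:{start}:{end}    <- same block with <|cursor|>
#   <|file_sep|>updated/{file_path}:{start}:{end}    <- model completes this block,
#   {prefill}                                           continuing after the prefill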


def generate(model, tokenizer, prompt: str, device: str = "cuda") -> str:
    """Run inference and return the newly generated text, i.e. the model's
    continuation of the prefilled "updated" code block."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    stop_token_ids = [
        tokenizer.convert_tokens_to_ids(t)
        for t in STOP_TOKENS
        if t in tokenizer.get_vocab()
    ]
    eos_ids = list(set(stop_token_ids + [tokenizer.eos_token_id]))

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy (temperature=0)
            eos_token_id=eos_ids,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    completion = tokenizer.decode(new_tokens, skip_special_tokens=False)

    # Strip stop tokens from output
    for stop in STOP_TOKENS:
        if stop in completion:
            completion = completion[: completion.index(stop)]

    return completion
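

# --- Illustrative helper (not part of the upstream script) --------------------
# Shows how a prediction could be applied: the updated block (prefill + generated
# continuation) simply replaces the original block's character span in the file.
def apply_updated_block(
    file_contents: str,
    code_block: str,
    block_start_index: int,
    updated_block: str,
) -> str:
    """Splice the predicted updated code block back into the full file contents."""
    block_end_index = block_start_index + len(code_block)
    return (
        file_contents[:block_start_index]
        + updated_block
        + file_contents[block_end_index:]
    )
# Example (using the values produced in main below):
#   new_contents = apply_updated_block(edited_contents, code_block, block_start, updated_block)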


def main():
    # --- Example: predict the next edit ---
    file_path = "example.py"
    file_contents = """\
def fibonacci(n):
    if n <= 0:
        return 0
    elif n == 1:
        return 1
    else:
        return fibonacci(n - 1) + fibonacci(n - 2)


def main():
    for i in range(10):
        print(fibonacci(i))
"""

    # Simulate: the user just renamed fibonacci -> fib (the def on line 1 and the
    # recursive calls on line 7); the cursor is now on line 12, the call site that
    # still says fibonacci.
    edited_contents = file_contents.replace(
        "return fibonacci(n - 1) + fibonacci(n - 2)",
        "return fib(n - 1) + fib(n - 2)",
    ).replace(
        "def fibonacci(n):",
        "def fib(n):",
    )

    # Cursor is on the print line that still references "fibonacci"
    cursor_line_text = "        print(fibonacci(i))"
    cursor_position = edited_contents.index(cursor_line_text)

    # Recent change as a diff
    recent_changes = DIFF_FORMAT.format(
        file_path=file_path,
        start_line=1,
        end_line=7,
        old_code="def fibonacci(n):\n        return fibonacci(n - 1) + fibonacci(n - 2)",
        new_code="def fib(n):\n        return fib(n - 1) + fib(n - 2)",
    )

    # Example retrieval chunk: a related function from another file
    retrieval_chunks = [
        FileChunk(
            file_path="utils.py",
            content="def fib_memo(n, memo={}):\n    if n in memo:\n        return memo[n]\n    memo[n] = fib_memo(n-1) + fib_memo(n-2)\n    return memo[n]",
        )
    ]

    # The rename was NOT an insertion, so changes_above_cursor=False.
    # This means the prefill will include everything up to the cursor line,
    # constraining the model to only edit at/below the cursor.
    prompt, code_block, block_start, relative_cursor = build_prompt(
        file_path=file_path,
        file_contents=edited_contents,
        cursor_position=cursor_position,
        recent_changes=recent_changes,
        retrieval_chunks=retrieval_chunks,
        changes_above_cursor=False,
    )

    print("=" * 60)
    print("PROMPT")
    print("=" * 60)
    print(prompt)
    print()

    # --- Load model and run inference ---
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Loading model {MODEL_ID} on {device}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=torch.bfloat16,
        device_map=device,
        trust_remote_code=True,
    )

    print("Running inference...")
    completion = generate(model, tokenizer, prompt, device=device)

    # The model only generates from the edit point onward; the full updated block
    # is the prefill (the unchanged prefix baked into the prompt) plus the
    # generated continuation.
    prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor=False)
    updated_block = prefill + completion

    # Check for pure insertion above cursor (low-value prediction)
    if is_pure_insertion_above_cursor(code_block, updated_block, relative_cursor):
        print("Rejected: model only inserted above cursor without editing cursor line.")
        return

    print("=" * 60)
    print("MODEL OUTPUT (predicted updated code block)")
    print("=" * 60)
    print(updated_block)
    print()

    # Show the diff
    print("=" * 60)
    print("DIFF")
    print("=" * 60)
    print(f"Original code block:\n{code_block}")
    print(f"Updated code block:\n{updated_block}")


if __name__ == "__main__":
    main()