Initialize project; model provided by the ModelHub XC community

Model: sweepai/sweep-next-edit-v2-7B
Source: Original Platform
ModelHub XC
2026-04-11 08:28:57 +08:00
commit 5da26c07e9
16 changed files with 152614 additions and 0 deletions

inference.py Normal file

@@ -0,0 +1,386 @@
"""
Minimal reproducible inference script for sweep-next-edit-v2-7B.
This model predicts the next edit a developer will make given:
- the current file contents
- recent changes (diffs)
- the cursor position
- (optional) retrieval chunks from other files
Usage:
python inference.py
Requires: transformers, torch, accelerate
pip install transformers torch accelerate
"""
import torch
from dataclasses import dataclass
from transformers import AutoModelForCausalLM, AutoTokenizer
MODEL_ID = "sweepai/sweep-next-edit-v2-7B"
# --- Prompt template (from sweepai/autocomplete/next_edit_autocomplete.py) ---
PROMPT_TEMPLATE = """<|file_sep|>{file_path}
{initial_file}{retrieval_results}
{recent_changes}
<|file_sep|>original/{file_path}:{start_line}:{end_line}
{prev_section}
<|file_sep|>current/{file_path}:{start_line}:{end_line}
{code_block}
<|file_sep|>updated/{file_path}:{start_line}:{end_line}
{prefill}"""
DIFF_FORMAT = """<|file_sep|>{file_path}:{start_line}:{end_line}
original:
{old_code}
updated:
{new_code}"""
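# Illustrative rendering of DIFF_FORMAT with hypothetical values:
#   <|file_sep|>example.py:1:1
#   original:
#   x = 1
#   updated:
#   x = 2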
STOP_TOKENS = ["<|endoftext|>", "<|file_sep|>"]
MAX_NEW_TOKENS = 1024
@dataclass
class FileChunk:
"""A chunk of code from another file, used for cross-file context (retrieval)."""
file_path: str
content: str
def to_string(self) -> str:
return f"<|file_sep|>{self.file_path}\n{self.content}\n"
def compute_prefill(
code_block: str,
relative_cursor: int,
changes_above_cursor: bool = False,
) -> str:
"""
Compute the prefill string — the portion of the updated code block that we
feed to the model so it only has to generate starting from the edit point.
The model's job is to produce the full "updated" code block. But most of it
is unchanged — only a small region near the cursor is different. So we
"prefill" the output with the unchanged prefix, and the model just continues
from there.
Two strategies depending on what the user just did:
changes_above_cursor=True (last action was an insertion):
The user just inserted text above the cursor. The lines above the cursor
may have shifted, so we can't trust them as a prefill — the model might
need to edit them. We only prefill the very first line of the code block
(plus any blank lines after it), giving the model freedom to rewrite
everything from line 2 onward.
        Example: code_block is 11 lines, cursor on line 11.
            Prefill = line 1 + any trailing blank lines = "    if n <= 0:\\n"
            Model generates lines 2-11.
changes_above_cursor=False (last action was NOT an insertion):
The user did something else (navigation, deletion, etc). The lines above
the cursor are likely stable, so we prefill up to the cursor line. This
constrains the model to only edit at/below the cursor.
We prefill everything before the cursor's line (up to the last newline
before cursor position), so the model starts generating from the cursor
line itself.
        Example: code_block is 11 lines, cursor on line 11 col 0.
            Prefill = lines 1-10 (everything up to the last \\n before cursor).
            Model generates line 11.
"""
if changes_above_cursor:
# --- Insertion mode: only prefill first line + trailing newlines ---
prefill = code_block[:relative_cursor]
prefilled_lines = prefill.splitlines(True)
NUM_LINES_ABOVE = 1
before_split = "".join(prefilled_lines[:NUM_LINES_ABOVE])
after_split = "".join(prefilled_lines[NUM_LINES_ABOVE:])
# Append consecutive newlines (blank lines) but stop at first real char.
# This preserves blank-line structure without constraining the model
# to keep the original code on those lines.
for char in after_split:
if char == "\n":
before_split += "\n"
else:
break
return before_split
else:
# --- Default mode: prefill up to the cursor line ---
prefix_before_cursor = code_block[:relative_cursor]
if "\n" not in prefix_before_cursor:
# Cursor is on the first line — no prefill possible
return ""
prefill_end = prefix_before_cursor.rfind("\n") + 1
return code_block[:prefill_end]
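
# Illustrative behaviour of compute_prefill on a hypothetical three-line block with the
# cursor at the start of "c = 3" (relative_cursor=12):
#   compute_prefill("a = 1\nb = 2\nc = 3\n", 12, changes_above_cursor=False)
#       -> "a = 1\nb = 2\n"   (prefill everything above the cursor line)
#   compute_prefill("a = 1\nb = 2\nc = 3\n", 12, changes_above_cursor=True)
#       -> "a = 1\n"          (only the first line is trusted after an insertion)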
def is_pure_insertion_above_cursor(
code_block: str, completion: str, relative_cursor: int
) -> bool:
"""
Reject completions that only insert new lines above the cursor without
actually editing the cursor line. These are low-value predictions —
the model is just guessing what new code to add rather than fixing
an existing reference.
"""
current_line_index = len(code_block[:relative_cursor].splitlines(True))
code_block_lines = code_block.splitlines(True)
cursor_line = code_block_lines[current_line_index - 1]
if code_block.strip() == completion.strip():
return False
if not cursor_line.strip():
return False
prefix_lines = code_block_lines[:current_line_index - 1]
prefix = "".join(prefix_lines)
suffix_lines = code_block_lines[current_line_index:]
suffix = "".join(suffix_lines)
# If completion = prefix + NEW STUFF + cursor_line + suffix, it's a pure
# insertion above cursor (nothing at/below cursor changed).
if completion.startswith(prefix) and completion.endswith(cursor_line + suffix):
return True
return False
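
# Illustrative checks with a hypothetical block "x = 1\ny = 2\nz = 3\n" and the cursor
# inside "y = 2" (relative_cursor=7):
#   completion "x = 1\nw = 0\ny = 2\nz = 3\n" only inserts a line above the cursor
#       -> True  (rejected as low-value)
#   completion "x = 1\ny = 22\nz = 3\n" edits the cursor line itself
#       -> False (kept)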
def build_prompt(
file_path: str,
file_contents: str,
cursor_position: int,
recent_changes: str = "",
retrieval_chunks: list[FileChunk] | None = None,
file_chunks: list[FileChunk] | None = None,
changes_above_cursor: bool = False,
num_lines_before: int = 10,
num_lines_after: int = 10,
) -> tuple[str, str, int, int]:
"""
Build the model prompt from file contents and cursor position.
Args:
file_path: Path of the file being edited.
file_contents: Full contents of the file after the user's latest edit.
cursor_position: Character offset of the cursor in file_contents.
recent_changes: Formatted diff string of recent changes (use DIFF_FORMAT).
        retrieval_chunks: Cross-file context chunks (e.g. related functions from
            other files). Placed after the file context and BEFORE recent_changes
            in the prompt for optimal KV cache reuse.
file_chunks: Additional file context chunks. Prepended to the prompt.
changes_above_cursor: Whether the user's last action was an insertion.
Controls the prefill strategy (see compute_prefill).
num_lines_before: Lines of code to include before cursor in the block.
num_lines_after: Lines of code to include after cursor in the block.
Returns:
(formatted_prompt, code_block, block_start_index, relative_cursor)
"""
lines = file_contents.splitlines(True)
# Find cursor line
pos = 0
cursor_line = 0
for i, line in enumerate(lines):
if pos + len(line) > cursor_position:
cursor_line = i
break
pos += len(line)
else:
cursor_line = len(lines) - 1
# Extract code block around cursor
block_start = max(0, cursor_line - num_lines_before)
block_end = min(len(lines), cursor_line + num_lines_after + 1)
code_block = "".join(lines[block_start:block_end])
block_start_index = sum(len(l) for l in lines[:block_start])
# Relative cursor position within code block
relative_cursor = cursor_position - block_start_index
# Insert <|cursor|> marker into the "current" version
code_block_with_cursor = (
code_block[:relative_cursor]
+ "<|cursor|>"
+ code_block[relative_cursor:]
)
# prev_section = code_block without cursor (the "original" version)
prev_section = code_block
# Compute prefill based on whether last action was an insertion
prefill = compute_prefill(code_block, relative_cursor, changes_above_cursor)
# initial_file: broad context around cursor from the file (up to ~300 lines)
context_start = max(0, cursor_line - 150)
context_end = min(len(lines), cursor_line + 150)
initial_file = "".join(lines[context_start:context_end])
# Format retrieval results (cross-file context)
retrieval_results = ""
if retrieval_chunks:
retrieval_results = "".join(
f"\n{chunk.to_string()}" for chunk in retrieval_chunks
)
start_line = block_start + 1
end_line = block_end
formatted = PROMPT_TEMPLATE.format(
file_path=file_path,
initial_file=initial_file,
retrieval_results=retrieval_results,
recent_changes=recent_changes,
prev_section=prev_section,
code_block=code_block_with_cursor,
start_line=start_line,
end_line=end_line,
prefill=prefill,
)
# Prepend file chunks (other open files for context)
if file_chunks:
formatted = "".join(c.to_string() for c in file_chunks) + formatted
return formatted, code_block, block_start_index, relative_cursor
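
# Minimal usage sketch with a hypothetical three-line file and the cursor in the middle
# of line 2 (main() below walks through a fuller example):
#   prompt, block, block_start, rel = build_prompt("t.py", "a = 1\nb = 2\nc = 3\n", 8)
#   # block == the whole 3-line file, block_start == 0, rel == 8, and the prompt embeds
#   # the block with "<|cursor|>" spliced in at offset 8.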
def generate(model, tokenizer, prompt: str, device: str = "cuda") -> str:
"""Run inference and return the completion (the predicted updated code block)."""
inputs = tokenizer(prompt, return_tensors="pt").to(device)
stop_token_ids = [
tokenizer.convert_tokens_to_ids(t)
for t in STOP_TOKENS
if t in tokenizer.get_vocab()
]
eos_ids = list(set(stop_token_ids + [tokenizer.eos_token_id]))
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
do_sample=False, # greedy (temperature=0)
eos_token_id=eos_ids,
pad_token_id=tokenizer.eos_token_id,
)
new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
completion = tokenizer.decode(new_tokens, skip_special_tokens=False)
# Strip stop tokens from output
for stop in STOP_TOKENS:
if stop in completion:
completion = completion[: completion.index(stop)]
return completion
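

# Optional variant, not used by main(): stream tokens to stdout as they are produced.
# This is a sketch assuming transformers' TextStreamer utility; stop-token trimming is
# omitted for brevity.
def generate_streaming(model, tokenizer, prompt: str, device: str = "cuda") -> None:
    """Print the predicted updated code block token-by-token instead of returning it."""
    from transformers import TextStreamer

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    with torch.no_grad():
        model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,  # greedy, same as generate()
            streamer=streamer,
            pad_token_id=tokenizer.eos_token_id,
        )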
def main():
# --- Example: predict the next edit ---
file_path = "example.py"
file_contents = """\
def fibonacci(n):
if n <= 0:
return 0
elif n == 1:
return 1
else:
        return fibonacci(n - 1) + fibonacci(n - 2)


def main():
for i in range(10):
print(fibonacci(i))
"""
    # Simulate: the user just renamed fibonacci -> fib (lines 1 and 7);
    # the cursor is now on line 12 (the call site that still says fibonacci).
edited_contents = file_contents.replace(
"return fibonacci(n - 1) + fibonacci(n - 2)",
"return fib(n - 1) + fib(n - 2)",
).replace(
"def fibonacci(n):",
"def fib(n):",
)
# Cursor is on the print line that still references "fibonacci"
cursor_line_text = " print(fibonacci(i))"
cursor_position = edited_contents.index(cursor_line_text)
# Recent change as a diff
recent_changes = DIFF_FORMAT.format(
file_path=file_path,
start_line=1,
end_line=7,
old_code="def fibonacci(n):\n return fibonacci(n - 1) + fibonacci(n - 2)",
new_code="def fib(n):\n return fib(n - 1) + fib(n - 2)",
)
# Example retrieval chunk: a related function from another file
retrieval_chunks = [
FileChunk(
file_path="utils.py",
content="def fib_memo(n, memo={}):\n if n in memo:\n return memo[n]\n memo[n] = fib_memo(n-1) + fib_memo(n-2)\n return memo[n]",
)
]
# The rename was NOT an insertion, so changes_above_cursor=False.
# This means the prefill will include everything up to the cursor line,
# constraining the model to only edit at/below the cursor.
prompt, code_block, block_start, relative_cursor = build_prompt(
file_path=file_path,
file_contents=edited_contents,
cursor_position=cursor_position,
recent_changes=recent_changes,
retrieval_chunks=retrieval_chunks,
changes_above_cursor=False,
)
print("=" * 60)
print("PROMPT")
print("=" * 60)
print(prompt)
print()
# --- Load model and run inference ---
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Loading model {MODEL_ID} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
dtype=torch.bfloat16,
device_map=device,
trust_remote_code=True,
)
print("Running inference...")
completion = generate(model, tokenizer, prompt, device=device)
# Check for pure insertion above cursor (low-value prediction)
if is_pure_insertion_above_cursor(code_block, completion, relative_cursor):
print("Rejected: model only inserted above cursor without editing cursor line.")
return
print("=" * 60)
print("MODEL OUTPUT (predicted updated code block)")
print("=" * 60)
print(completion)
print()
# Show the diff
print("=" * 60)
print("DIFF")
print("=" * 60)
print(f"Original code block:\n{code_block}")
print(f"Updated code block:\n{completion}")
if __name__ == "__main__":
main()