Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"docker>=7.1.0",
"pynvml>=12.0.0",
"toploc>=0.0.2",
"cydifflib>=1.2.0",
]

[project.optional-dependencies]
Expand Down
2 changes: 2 additions & 0 deletions src/genesys/verifiers/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from genesys.verifiers.math_verifier import MathVerifier
from genesys.verifiers.llm_judge_verifier import LlmJudgeVerifier
from genesys.verifiers.code_output_prediction_verifier import CodeUnderstandingVerifier
from genesys.verifiers.swe_fixer_verifier import SweFixerVerifier

VERIFIER_REGISTRY = {
"verifiable_code": CodeVerifier,
"verifiable_math": MathVerifier,
"llm_judgeable_groundtruth_similarity": LlmJudgeVerifier,
"code_output_prediction": CodeUnderstandingVerifier,
"swe_fixer": SweFixerVerifier,
}
204 changes: 204 additions & 0 deletions src/genesys/verifiers/swe_fixer_verifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
### adapted from https://github.com/InternLM/SWE-Fixer/blob/main/evaluation/code_edit.py
import json
import re
import ast
import argparse
import cydifflib

from genesys.schemas import Response
from genesys.verifiers.base_verifier import BaseVerifier


LINE_NUMBER_REGEX = re.compile(r"^\d+\s", re.MULTILINE)


def parse_json_codeblock_from_model_output(markdown_str):
# Get everything after </think>, if it exists
match = re.search(r"</think>(.*?)$", markdown_str, re.DOTALL)
answer_str = match.group(1).strip() if match else markdown_str.strip()
# Extract everything between ```json and ``` markers
match = re.search(r"```json\s*(.*?)\s*```", answer_str, re.DOTALL)
if match:
return match.group(1).strip()
else:
return answer_str.strip()


def remove_line_numbers(content):
return LINE_NUMBER_REGEX.sub("", content)


def remove_empty_lines(code):
lines = code.splitlines()
filtered_lines = [line for line in lines if line.strip() != ""]
return "\n".join(filtered_lines)


def check_syntax(code):
if not code.strip():
return False
try:
ast.parse(code)
except SyntaxError:
return False
return True


class SweFixerVerifier(BaseVerifier):
"""
Verifier for the SWE-Fixer dataset.
https://github.com/InternLM/SWE-Fixer
"""

def apply_patches(self, files_to_modify, patches):
"""
Apply a list of code-edit patches to an iterable of files and return the
fully-patched workspace.

Args:
files_to_modify (list[dict]): items from verification_info["input"]["files to be modified"]
patches (list[dict]): items structured like verification_info["output"]["edited code"]
or the model's JSON output.

Returns
-------
dict[str, str]
file-path -> patched file content
"""
workspace = {f["file"]: remove_line_numbers(f["file content"]) for f in files_to_modify}
failed_file_paths = []

for patch in patches:
file_path = patch["file"]
snippet_old = remove_line_numbers(patch["code snippet to be modified"]).strip()
snippet_new = patch["edited code snippet"].strip()

current = workspace.get(file_path, "")
if snippet_old:
if snippet_old not in current:
# Model failed to localize the code snippet to be modified
print("Model failed to localize the code snippet to be modified")
failed_file_paths.append(file_path)
continue
current = current.replace(snippet_old, snippet_new)
elif current == "": # brand-new file
current = snippet_new
workspace[file_path] = current

# Set the file content to None for files that failed to be patched
workspace = {k: None if k in failed_file_paths else v for k, v in workspace.items()}
return workspace

def get_diff(self, before, after):
diff = cydifflib.unified_diff(before.splitlines(), after.splitlines(), lineterm="")
lines = list(diff)[2:] # Keep relevant parts of the diff
return "\n".join(lines)

def score_patching(self, verification_info, json_output):
"""
Score how well the model's patches match the expected patches.

Args:
verification_info (dict): Contains the original files to modify and golden patches
in verification_info["input"]["files to be modified"] and
verification_info["output"]["edited code"] respectively
json_output (str): The model's output as a JSON string containing patches to apply

Returns:
dict: Contains:
- score (float): 1.0 if patches match exactly, 0.0 if syntax error or localization failure
- verification_result_info (dict): Additional info about verification failures
"""
try:
model_patches = json.loads(json_output)
except Exception as e:
return dict(score=0.0, verification_result_info={"failure_reason": f"Error in parsing JSON output: {e}"})

try:
original_files = verification_info["input"]["files to be modified"]
golden_patches = verification_info["output"]["edited code"]

original_workspace = self.apply_patches(original_files, [])
golden_workspace = self.apply_patches(original_files, golden_patches)
predicted_workspace = self.apply_patches(original_files, model_patches)
# predicted_workspace = self.apply_patches(original_files, golden_patches)

scores = []
for file_path in golden_workspace:
if predicted_workspace[file_path] is None:
scores.append(0.0) # model failed to localize edit location
continue
golden_file_content = golden_workspace[file_path]
predicted_file_content = predicted_workspace.get(file_path, "")

if predicted_file_content == golden_file_content:
scores.append(1.0)
continue

syntax_ok = check_syntax(predicted_file_content)
# syntax_ok_golden = check_syntax(golden_file_content)
if not syntax_ok:
return dict(score=0.0, verification_result_info={"failure_reason": "Syntax error"})
# if not syntax_ok_golden:
# return dict(score=0.0, verification_result_info={
# "failure_reason": "Syntax error in golden patch"
# })

golden_diff = self.get_diff(before=original_workspace[file_path], after=golden_file_content)
model_diff = self.get_diff(before=original_workspace[file_path], after=predicted_file_content)
# print("Line counts expected:")
# print(len(expected_file_content.splitlines()))
# print(len(golden_diff.splitlines()))
# print("Delta: ", len(expected_file_content.splitlines()) - len(golden_diff.splitlines()))
# print("Line counts predicted:")
# print(len(predicted_file_content.splitlines()))
# print(len(model_diff.splitlines()))
# print("Delta: ", len(predicted_file_content.splitlines()) - len(model_diff.splitlines()))
# print("MODEL DIFF:\n", model_diff)
# print("GOLDEN DIFF:\n", golden_diff)

score = cydifflib.SequenceMatcher(
None,
a=model_diff,
b=golden_diff,
autojunk=False,
).ratio()
scores.append(score)

return dict(score=sum(scores) / len(scores), verification_result_info=dict())

except Exception as e:
return dict(
score=0,
verification_result_info={"failure_reason": f"Error in scoring patches: {e}"},
)

def verify(self, result: Response):
"""
Evaluates the code patches by comparing the model's patches against golden patches.

The score lies between 0 and 1, representing the similarity between the model's patches and the golden patches.
"""
print("Processing example: ", result["problem_id"])
verification_info = result["verification_info"]
json_output = parse_json_codeblock_from_model_output(result["llm_response"])
return self.score_patching(verification_info, json_output)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Verify SWE-Fixer patches")
parser.add_argument("--file", type=str, required=True, help="Path to the input file containing patches to verify")
args = parser.parse_args()

to_verify = []
with open(args.file, "r") as f:
for line in f:
d = json.loads(line)
d["verification_info"] = ast.literal_eval(d["verification_info"])
d["metadata"] = ast.literal_eval(d["metadata"])
to_verify.append(d)

verifier = SweFixerVerifier()
for item in to_verify:
result = verifier.verify(item)
print(result)
136 changes: 136 additions & 0 deletions tests/verifiers/test_swe_fixer_verifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import json
from genesys.verifiers.registry import SweFixerVerifier


def test_swe_fixer_end_to_end_complex():
"""
End-to-end check of SweFixerVerifier with two files and wrong indentation.
The first patch fixes an indentation bug inside an `if`; the second renames
a variable and tweaks the greeting text.
"""

files_to_modify = [
{
"file": "utils/math_ops.py",
"file content": (
"1 def add(a, b):\n"
"2 return a + b\n"
"3 \n"
"4 def divide(a, b):\n"
"5 if b == 0:\n"
"6 return None\n" # missing indent needs to be fixed
"7 return a / b\n"
),
},
{
"file": "app/main.py",
"file content": (
"1 from utils.math_ops import add\n"
"2 \n"
"3 def greet(name):\n"
'4 message = f"Hello, {name}"\n'
"5 print(message)\n"
"6 \n"
'7 if __name__ == "__main__":\n'
"8 print(add(2,3))\n"
),
},
]

correct_patches = [
{
"file": "utils/math_ops.py",
"code snippet to be modified": (
"4 def divide(a, b):\n" "5 if b == 0:\n" "6 return None\n" "7 return a / b"
),
"edited code snippet": (
"def divide(a, b):\n"
" if b == 0:\n"
" return None\n" # fixed missing indent
" return a / b"
),
},
{
"file": "app/main.py",
"code snippet to be modified": (
"3 def greet(name):\n" '4 message = f"Hello, {name}"\n' "5 print(message)"
),
"edited code snippet": (
"def greet(name):\n"
' msg = f"Hello, {name}!"\n' # fixed variable name and greeting text
" print(msg)"
),
},
]

wrong_patch = [
{
"file": "utils/math_ops.py",
"code snippet to be modified": (
"4 def divide(a, b):\n" "5 if b == 0:\n" "6 return None\n" "7 return a / b"
),
"edited code snippet": (
"def divide(a, b):\n"
" if b == 0:\n"
" return None\n" # missing indent will cause syntax error
" return a / b"
),
}
]

slightly_wrong_patches = [
{
"file": "utils/math_ops.py",
"code snippet to be modified": (
"4 def divide(a, b):\n" "5 if b == 0:\n" "6 return None\n" "7 return a / b"
),
"edited code snippet": (
"def divide_two_numbers(a, b):\n" # wrong function name
" if b == 0:\n"
" return 0\n" # wrong return value
" return a / b"
),
},
{
"file": "app/main.py",
"code snippet to be modified": (
"3 def greet(name):\n" '4 message = f"Hello, {name}"\n' "5 print(message)"
),
"edited code snippet": (
"def greet_by_name(name):\n" # wrong function name
' message = f"Hello, {name}!"\n' # wrong variable name
" print(message)"
),
},
]

verification_info = {
"input": {"files to be modified": files_to_modify},
"output": {"edited code": correct_patches},
}

verifier = SweFixerVerifier()
result_correct = {
"problem_id": "test_swe_fixer_correct",
"verification_info": verification_info,
"llm_response": json.dumps(correct_patches),
}
result_wrong = {
"problem_id": "test_swe_fixer_wrong",
"verification_info": verification_info,
"llm_response": json.dumps(wrong_patch),
}
result_slightly_wrong = {
"problem_id": "test_swe_fixer_slightly_wrong",
"verification_info": verification_info,
"llm_response": json.dumps(slightly_wrong_patches),
}
score_dict_correct = verifier.verify(result_correct)
score_dict_wrong = verifier.verify(result_wrong)
score_dict_slightly_wrong = verifier.verify(result_slightly_wrong)

assert score_dict_correct["score"] == 1.0, f"unexpected verifier score: {score_dict_correct}"
assert score_dict_wrong["score"] == 0.0, f"unexpected verifier score: {score_dict_wrong}"
assert (
0.5 < score_dict_slightly_wrong["score"] < 1.0
), f"unexpected verifier score: {score_dict_slightly_wrong}" # 0.9334
Loading