Skip to content

Add git diff reward #376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: sami/add-git-diff-reward
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"reasoning-gym @ git+https://github.com/open-thought/reasoning-gym.git",
"tomli>=2.2.1",
"pydantic[email]>=2.11.5",
"cydifflib>=1.2.0",
]

[project.optional-dependencies]
Expand Down
3 changes: 3 additions & 0 deletions src/zeroband/inference/genesys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from zeroband.inference.genesys.code import evaluate_code
from zeroband.inference.genesys.code_output_prediction import verify_code_output_prediction
from zeroband.inference.genesys.complex_json_output import verify_complex_json_formatting
from zeroband.inference.genesys.git_diff import compute_git_diff_reward
from zeroband.inference.genesys.ifeval import verify_ifeval
from zeroband.inference.genesys.math import compute_math_reward
from zeroband.inference.genesys.pydantic_json_adherance import validate_pydantic_json
Expand All @@ -22,6 +23,7 @@
"pydantic_adherance",
"ifeval",
"complex_json_output",
"git_diff",
]


Expand All @@ -43,4 +45,5 @@ def get_reward_function(task_type: TaskType) -> Callable[[str, dict], float]:
"pydantic_adherance": validate_pydantic_json,
"ifeval": verify_ifeval,
"complex_json_output": verify_complex_json_formatting,
"git_diff": compute_git_diff_reward,
}
101 changes: 101 additions & 0 deletions src/zeroband/inference/genesys/git_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Acknowledgements:
# SWE-Fixer: Training Open-Source LLMs for Effective and Efficient GitHub Issue Resolution
# Xie, Chengxing et al., 2025
#
# Agentless: Demystifying LLM-based Software Engineering Agents
# Xia, Chunqiu Steven et al., 2024
#
# SWE-RL: Advancing LLM Reasoning via Reinforcement Learning on Open Software Evolution
# Yuxiang Wei et al., 2025

import re
from typing import Dict

import cydifflib

DIFF_BLOCK_REGEX = re.compile(r"```diff\s*(.*?)\s*```", re.DOTALL)
INDEX_LINE_REGEX = re.compile(r"^index [^\n]*\n")
FUNC_CONTEXT_REGEX = re.compile(r"(?m)^(@@[^@]*@@).*")


def parse_last_diff_codeblock(markdown_str: str) -> str:
"""Extract the last ```diff``` code block from markdown text."""
matches = DIFF_BLOCK_REGEX.findall(markdown_str)
if matches:
return matches[-1].strip()
else:
return ""


def normalize_diff(diff_text: str) -> str:
"""
Normalize diff text by removing lines starting with 'index ...' and stripping function context after @@.
The function context/section header can differ between diffs, because language specific parsing might not be enabled.

Example:
```diff
diff --git a/file.py b/file.py
index 1234567890..1234567890
--- a/file.py
+++ b/file.py
@@ -15,1 +15,1 @@ def some_func():
- pass
+ return
```

becomes:
```diff
diff --git a/file.py b/file.py
--- a/file.py
+++ b/file.py
@@ -15,1 +15,1 @@
- pass
+ return
```
"""
diff_text = INDEX_LINE_REGEX.sub("", diff_text)
diff_text = FUNC_CONTEXT_REGEX.sub(r"\1", diff_text)
diff_text = diff_text.strip() + "\n"
return diff_text


def compute_git_diff_reward(completion: str, verification_info: Dict) -> float:
"""
Compute reward for git diff generation tasks using LCS (Longest Common Subsequence) ratio.

Args:
completion: Model's response string
verification_info: Dict containing golden_diff

Returns:
Float score (0.0 to 1.0) representing diff similarity
"""
# Extract the response after thinking (if present)
if "</think>" in completion:
response = completion.split("</think>")[1].strip()
else:
response = completion.strip()

if not response:
return 0.0

# Get expected answer from verification_info
golden_diff = verification_info.get("golden_diff", "")
if not golden_diff:
return 0.0

try:
# Extract diff from response
response_diff = parse_last_diff_codeblock(response)
response_diff = normalize_diff(response_diff)

if not response_diff.strip():
return 0.0

# Calculate LCS ratio
similarity = cydifflib.SequenceMatcher(None, response_diff, golden_diff, autojunk=False).ratio()

return similarity

except Exception:
return 0.0
Loading
Loading