Skip to content

Add general difflib reward function + format extraction task #363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions configs/inference/formatask.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
dataset = "kalomaze/general-formatask-it2"
batch_size = 256
dp = 4
rollout_path = "outputs"
clean_output_path = true
output_path = "data_rollout"
max_model_len = 2048

[sampling]
temperature = 1.0
n = 16
24 changes: 24 additions & 0 deletions configs/training/formatask.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
project = "formatask-debug"

[train]
micro_bs = 2 # change to 8 for H200
reshard_after_forward = true

[optim]
batch_size = 64
warmup_steps = 20
total_steps = 100000000000000
step_per_rollout = 4
grad_norm_clip = 0.0001

[optim.optim]
lr = 5e-6

[data]
path = "data_rollout"
seq_length = 2048

[ckpt]
rollout_path = "outputs"
clean_rollout_path = true
3 changes: 3 additions & 0 deletions src/zeroband/inference/genesys/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from zeroband.inference.genesys.code import evaluate_code
from zeroband.inference.genesys.code_output_prediction import verify_code_output_prediction
from zeroband.inference.genesys.complex_json_output import verify_complex_json_formatting
from zeroband.inference.genesys.formatask import compute_reward as compute_formatask_reward
from zeroband.inference.genesys.math import compute_math_reward
from zeroband.inference.genesys.pydantic_json_adherance import validate_pydantic_json
from zeroband.inference.genesys.reasoning_gym import verify_reasoning_gym
Expand All @@ -20,6 +21,7 @@
"ascii_tree_formatting",
"pydantic_adherance",
"complex_json_output",
"formatask",
]


Expand All @@ -40,4 +42,5 @@ def get_reward_function(task_type: TaskType) -> Callable[[str, dict], float]:
"ascii_tree_formatting": compute_ascii_tree_reward,
"pydantic_adherance": validate_pydantic_json,
"complex_json_output": verify_complex_json_formatting,
"formatask": compute_formatask_reward,
}
99 changes: 99 additions & 0 deletions src/zeroband/inference/genesys/formatask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import difflib
import re
from typing import Dict


def detect_format_signature(text: str) -> tuple:
signature = []
current_text = text

# Check for markdown emphasis (changed to if statements for layering)
if current_text.startswith("***") and current_text.endswith("***") and len(current_text) > 6:
signature.append("triple_asterisk")
current_text = current_text[3:-3]
elif current_text.startswith("**") and current_text.endswith("**") and len(current_text) > 4:
signature.append("bold")
current_text = current_text[2:-2]
elif current_text.startswith("*") and current_text.endswith("*") and len(current_text) > 2:
signature.append("italic")
current_text = current_text[1:-1]

# Check for code blocks
if current_text.startswith("```\n") and current_text.endswith("\n```"):
signature.append("code_block")
current_text = current_text[4:-4]

# Check for wrapper characters (changed to if statements for layering)
if current_text.startswith('"') and current_text.endswith('"') and len(current_text) > 2:
signature.append("quotes")
current_text = current_text[1:-1]
elif current_text.startswith("[") and current_text.endswith("]") and len(current_text) > 2:
signature.append("brackets")
current_text = current_text[1:-1]
elif current_text.startswith("(") and current_text.endswith(")") and len(current_text) > 2:
signature.append("parentheses")
current_text = current_text[1:-1]
elif current_text.startswith("-") and current_text.endswith("-") and len(current_text) > 2:
signature.append("dashes")
current_text = current_text[1:-1]
elif current_text.startswith("<") and current_text.endswith(">") and len(current_text) > 2:
signature.append("angle_brackets")
current_text = current_text[1:-1]

# Check for uppercase (this can layer with other formats)
if current_text.isupper() and len(current_text) > 0:
signature.append("uppercase")

return tuple(signature) if signature else ("none",)


def extract_and_score(text: str, tag_name: str, ground_truth: str) -> float:
pattern = f"<{tag_name}>(.*?)</{tag_name}>"
match = re.search(pattern, text, re.DOTALL)

if not match:
return 0

extracted = match.group(1)

# Check format compliance
gt_sig = detect_format_signature(ground_truth)
ext_sig = detect_format_signature(extracted)

# Calculate similarity
similarity = difflib.SequenceMatcher(None, extracted, ground_truth).ratio()

# If format is wrong but content is extractable, give 0.1x reward
if gt_sig != ("none",) and gt_sig != ext_sig:
return similarity * 0.1

return similarity


def compute_reward(completion: str, verification_info: Dict, tag_name: str = "extracted_formatted") -> float:
try:
# Skip thinking section
text = completion
think_end = completion.find("</think>")
if think_end != -1:
text = completion[think_end + len("</think>") :]

# Check for dual case first
if "ground_truth1" in verification_info and "ground_truth2" in verification_info:
score1 = extract_and_score(text, "extracted_formatted1", verification_info["ground_truth1"])
score2 = extract_and_score(text, "extracted_formatted2", verification_info["ground_truth2"])

if score1 == 0 or score2 == 0:
return 0

return (score1 + score2) / 2.0

# Single case
ground_truth = verification_info.get("ground_truth")
if not ground_truth:
return 0

return extract_and_score(text, tag_name, ground_truth)

except Exception:
return 0