diff --git a/configs/inference/formatask.toml b/configs/inference/formatask.toml new file mode 100644 index 00000000..7b324e89 --- /dev/null +++ b/configs/inference/formatask.toml @@ -0,0 +1,12 @@ +model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" +dataset = "kalomaze/general-formatask-it2" +batch_size = 256 +dp = 4 +rollout_path = "outputs" +clean_output_path = true +output_path = "data_rollout" +max_model_len = 2048 + +[sampling] +temperature = 1.0 +n = 16 \ No newline at end of file diff --git a/configs/training/formatask.toml b/configs/training/formatask.toml new file mode 100644 index 00000000..45e33001 --- /dev/null +++ b/configs/training/formatask.toml @@ -0,0 +1,24 @@ +model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" +project = "formatask-debug" + +[train] +micro_bs = 2 # change to 8 for H200 +reshard_after_forward = true + +[optim] +batch_size = 64 +warmup_steps = 20 +total_steps = 100000000000000 +step_per_rollout = 4 +grad_norm_clip = 0.0001 + +[optim.optim] +lr = 5e-6 + +[data] +path = "data_rollout" +seq_length = 2048 + +[ckpt] +rollout_path = "outputs" +clean_rollout_path = true \ No newline at end of file diff --git a/src/zeroband/inference/genesys/__init__.py b/src/zeroband/inference/genesys/__init__.py index ee7e955a..6aff91d5 100644 --- a/src/zeroband/inference/genesys/__init__.py +++ b/src/zeroband/inference/genesys/__init__.py @@ -4,6 +4,7 @@ from zeroband.inference.genesys.code import evaluate_code from zeroband.inference.genesys.code_output_prediction import verify_code_output_prediction from zeroband.inference.genesys.complex_json_output import verify_complex_json_formatting +from zeroband.inference.genesys.formatask import compute_reward as compute_formatask_reward from zeroband.inference.genesys.math import compute_math_reward from zeroband.inference.genesys.pydantic_json_adherance import validate_pydantic_json from zeroband.inference.genesys.reasoning_gym import verify_reasoning_gym @@ -20,6 +21,7 @@ "ascii_tree_formatting", 
"""Reward computation for the ``formatask`` task type.

The genesys package registers :func:`compute_reward` under the task type
``"formatask"``.  A completion is expected to place its answer between
``<extracted_formatted>`` tags (or ``<extracted_formatted1>`` and
``<extracted_formatted2>`` when two ground truths are supplied).  The reward
is the textual similarity between the tagged answer and the ground truth,
scaled down when the answer's formatting wrappers differ from the ground
truth's.
"""

import difflib
import re
from typing import Dict, Tuple

# (marker, label) pairs, longest marker first so "***" is tested before "**"
# and "*"; only the first matching marker is stripped.
_EMPHASIS_MARKERS = (
    ("***", "triple_asterisk"),
    ("**", "bold"),
    ("*", "italic"),
)

# (opening char, closing char, label); only the first matching pair is stripped.
_WRAPPER_PAIRS = (
    ('"', '"', "quotes"),
    ("[", "]", "brackets"),
    ("(", ")", "parentheses"),
    ("-", "-", "dashes"),
    ("<", ">", "angle_brackets"),
)

_CODE_FENCE_OPEN = "```\n"
_CODE_FENCE_CLOSE = "\n```"

# End-of-reasoning delimiter; everything up to and including it is ignored.
_THINK_END = "</think>"

# Multiplier applied to the similarity when the formatting signature is wrong.
_WRONG_FORMAT_PENALTY = 0.1


def detect_format_signature(text: str) -> Tuple[str, ...]:
    """Return the ordered formatting layers that wrap *text*.

    Layers are peeled from the outside in, in this fixed order: at most one
    markdown emphasis marker, then a fenced code block, then at most one pair
    of wrapper characters.  Finally ``"uppercase"`` is appended when what
    remains is upper-case.

    Args:
        text: The string to inspect.

    Returns:
        A tuple of layer labels in the order they were detected, or
        ``("none",)`` when no layer was detected.
    """
    signature = []
    current_text = text

    # Emphasis: the text must be strictly longer than the two markers, so a
    # bare "****" is not reported as bold-wrapped empty content.
    for marker, label in _EMPHASIS_MARKERS:
        width = len(marker)
        if (
            current_text.startswith(marker)
            and current_text.endswith(marker)
            and len(current_text) > 2 * width
        ):
            signature.append(label)
            current_text = current_text[width:-width]
            break

    # Fenced code block (fence on its own line at both ends).
    if current_text.startswith(_CODE_FENCE_OPEN) and current_text.endswith(_CODE_FENCE_CLOSE):
        signature.append("code_block")
        current_text = current_text[len(_CODE_FENCE_OPEN):-len(_CODE_FENCE_CLOSE)]

    # Single-character wrappers; content between them must be non-empty.
    for opener, closer, label in _WRAPPER_PAIRS:
        if (
            current_text.startswith(opener)
            and current_text.endswith(closer)
            and len(current_text) > 2
        ):
            signature.append(label)
            current_text = current_text[1:-1]
            break

    # str.isupper() is already False for an empty string.
    if current_text.isupper():
        signature.append("uppercase")

    return tuple(signature) if signature else ("none",)


def extract_and_score(text: str, tag_name: str, ground_truth: str) -> float:
    """Score the content of the first ``<tag_name>...</tag_name>`` pair in *text*.

    Args:
        text: Model output to search.
        tag_name: Name of the tag whose content is the candidate answer.
        ground_truth: Expected answer, including its formatting wrappers.

    Returns:
        ``0.0`` when no complete tag pair is found.  Otherwise the
        ``difflib.SequenceMatcher`` ratio between the tagged content and
        *ground_truth*, multiplied by 0.1 when the ground truth carries a
        formatting signature that the tagged content does not reproduce.
    """
    # The closing tag is required: with only "<tag>(.*?)" the lazy group
    # always captures the empty string and every answer would score 0.
    tag = re.escape(tag_name)
    match = re.search(rf"<{tag}>(.*?)</{tag}>", text, re.DOTALL)
    if match is None:
        return 0.0

    extracted = match.group(1)

    gt_sig = detect_format_signature(ground_truth)
    ext_sig = detect_format_signature(extracted)

    similarity = difflib.SequenceMatcher(None, extracted, ground_truth).ratio()

    # Wrong format but extractable content still earns a small reward.
    if gt_sig != ("none",) and gt_sig != ext_sig:
        return similarity * _WRONG_FORMAT_PENALTY

    return similarity


def compute_reward(completion: str, verification_info: Dict, tag_name: str = "extracted_formatted") -> float:
    """Compute the formatask reward for one completion.

    Text up to and including the first ``</think>`` is discarded, so tags
    written during reasoning are not scored.  If *verification_info* holds
    both ``"ground_truth1"`` and ``"ground_truth2"``, the fixed tags
    ``extracted_formatted1`` and ``extracted_formatted2`` are scored (the
    *tag_name* argument is not used in that case) and the mean is returned,
    or ``0.0`` if either score is zero.  Otherwise the content of *tag_name*
    is scored against ``"ground_truth"``.

    Args:
        completion: Raw model output.
        verification_info: Mapping holding the ground truth string(s).
        tag_name: Tag to extract in the single-ground-truth case.

    Returns:
        A reward between 0.0 and 1.0.  Any exception raised while scoring
        (for example a malformed *verification_info*) yields ``0.0`` rather
        than propagating.
    """
    try:
        # Skip the thinking section, if the completion has one.
        text = completion
        think_end = completion.find(_THINK_END)
        if think_end != -1:
            text = completion[think_end + len(_THINK_END):]

        # Dual case takes precedence over the single case.
        if "ground_truth1" in verification_info and "ground_truth2" in verification_info:
            score1 = extract_and_score(text, "extracted_formatted1", verification_info["ground_truth1"])
            score2 = extract_and_score(text, "extracted_formatted2", verification_info["ground_truth2"])

            # Both answers must earn something for the pair to count.
            if score1 == 0 or score2 == 0:
                return 0.0

            return (score1 + score2) / 2.0

        # Single case.
        ground_truth = verification_info.get("ground_truth")
        if not ground_truth:
            return 0.0

        return extract_and_score(text, tag_name, ground_truth)

    except Exception:
        # Deliberate best-effort boundary: a malformed sample is scored as
        # zero instead of aborting the caller.
        return 0.0