From 112fbfe1c74e50f8581fdce995406d6d0a86141b Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Wed, 4 Jun 2025 08:06:55 -0700 Subject: [PATCH 01/10] formatask init --- src/zeroband/inference/genesys/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/zeroband/inference/genesys/__init__.py b/src/zeroband/inference/genesys/__init__.py index ee7e955a..f40a3c58 100644 --- a/src/zeroband/inference/genesys/__init__.py +++ b/src/zeroband/inference/genesys/__init__.py @@ -9,6 +9,8 @@ from zeroband.inference.genesys.reasoning_gym import verify_reasoning_gym from zeroband.inference.genesys.reverse_text import reverse_text from zeroband.inference.genesys.unscramble_sentence import compute_reward as compute_unscramble_reward +from zeroband.inference.genesys.formatask import compute_reward as compute_formatask_reward + TaskType = Literal[ "verifiable_math", @@ -20,6 +22,7 @@ "ascii_tree_formatting", "pydantic_adherance", "complex_json_output", + "formatask", ] @@ -40,4 +43,6 @@ def get_reward_function(task_type: TaskType) -> Callable[[str, dict], float]: "ascii_tree_formatting": compute_ascii_tree_reward, "pydantic_adherance": validate_pydantic_json, "complex_json_output": verify_complex_json_formatting, + "formatask": compute_formatask_reward, + } From 3d5c36a1f680e568f929f81c52af673520df78b1 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Wed, 4 Jun 2025 08:07:47 -0700 Subject: [PATCH 02/10] add task def --- src/zeroband/inference/genesys/formatask.py | 50 +++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/zeroband/inference/genesys/formatask.py diff --git a/src/zeroband/inference/genesys/formatask.py b/src/zeroband/inference/genesys/formatask.py new file mode 100644 index 00000000..9da55651 --- /dev/null +++ b/src/zeroband/inference/genesys/formatask.py @@ -0,0 +1,50 @@ +import difflib +import re +from typing import Dict + + +def compute_reward( 
+ completion: str, + verification_info: Dict, + tag_name: str = "extracted_formatted" +): + """ + Generic difflib-based reward computation for tasks expecting extracted content in XML tags. + + Args: + completion: The model's completion text + verification_info: Dictionary containing ground truth + tag_name: XML tag name to extract content from + + Returns: + Float reward between 0 and 1 + """ + # Extract answer from specified tag + tag_pattern = f"<{tag_name}>(.*?)" + if f"<{tag_name}>" not in completion: + return 0 + + answer_match = re.search(tag_pattern, completion, re.DOTALL) + if not answer_match: + return 0 + + # Get ground truth + ground_truth = verification_info.get("ground_truth") + if not ground_truth: + return 0 + + try: + # Clean and split both into lines + answer_lines = answer_match.group(1).strip().split("\n") + truth_lines = ground_truth.strip().split("\n") + + # Use difflib to compare line sequences + matcher = difflib.SequenceMatcher(None, answer_lines, truth_lines) + + # Calculate similarity ratio + reward = matcher.ratio() + + return reward + + except Exception: + return 0 \ No newline at end of file From 3fe5758887fb45cc30584004e54c9040f4f61177 Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Wed, 4 Jun 2025 08:10:49 -0700 Subject: [PATCH 03/10] formatask configs --- configs/inference/formatask.toml | 11 +++++++++++ configs/training/formatask.toml | 23 +++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 configs/inference/formatask.toml create mode 100644 configs/training/formatask.toml diff --git a/configs/inference/formatask.toml b/configs/inference/formatask.toml new file mode 100644 index 00000000..730840d1 --- /dev/null +++ b/configs/inference/formatask.toml @@ -0,0 +1,11 @@ +model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" +dataset = "kalomaze/general-formatask-it1-7k" +batch_size = 256 +dp = 4 +rollout_path = "outputs" +output_path = "data_rollout" +max_model_len = 
2048 + +[sampling] +temperature = 1.0 +n = 16 \ No newline at end of file diff --git a/configs/training/formatask.toml b/configs/training/formatask.toml new file mode 100644 index 00000000..3761dc88 --- /dev/null +++ b/configs/training/formatask.toml @@ -0,0 +1,23 @@ +model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" +project = "formatask-debug" + +[train] +micro_bs = 2 # change to 8 for H200 +reshard_after_forward = true + +[optim] +batch_size = 64 +warmup_steps = 20 +total_steps = 100000000000000 +step_per_rollout = 4 +grad_norm_clip = 0.0001 + +[optim.optim] +lr = 5e-6 + +[data] +path = "data_rollout" +seq_length = 2048 + +[ckpt] +rollout_path = "outputs" \ No newline at end of file From 88230e84e0eadc0a24b679c49acae96b28f09d4b Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Wed, 4 Jun 2025 08:14:22 -0700 Subject: [PATCH 04/10] fix formatting?? --- src/zeroband/inference/genesys/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/zeroband/inference/genesys/__init__.py b/src/zeroband/inference/genesys/__init__.py index f40a3c58..c741ec36 100644 --- a/src/zeroband/inference/genesys/__init__.py +++ b/src/zeroband/inference/genesys/__init__.py @@ -11,7 +11,6 @@ from zeroband.inference.genesys.unscramble_sentence import compute_reward as compute_unscramble_reward from zeroband.inference.genesys.formatask import compute_reward as compute_formatask_reward - TaskType = Literal[ "verifiable_math", "prime_rl_code", @@ -44,5 +43,4 @@ def get_reward_function(task_type: TaskType) -> Callable[[str, dict], float]: "pydantic_adherance": validate_pydantic_json, "complex_json_output": verify_complex_json_formatting, "formatask": compute_formatask_reward, - } From 5bd313cece9c40fcf6d5d8e94d7c47f961cf488a Mon Sep 17 00:00:00 2001 From: kalomaze Date: Fri, 6 Jun 2025 05:57:56 +0000 Subject: [PATCH 05/10] fix + overhaul difflib reward --- src/zeroband/inference/genesys/formatask.py | 92 ++++++++++++++++----- 1 file 
changed, 73 insertions(+), 19 deletions(-) diff --git a/src/zeroband/inference/genesys/formatask.py b/src/zeroband/inference/genesys/formatask.py index 9da55651..8fce4f1a 100644 --- a/src/zeroband/inference/genesys/formatask.py +++ b/src/zeroband/inference/genesys/formatask.py @@ -3,48 +3,102 @@ from typing import Dict -def compute_reward( - completion: str, - verification_info: Dict, - tag_name: str = "extracted_formatted" -): +def detect_format_type(text: str) -> str: + """ + Detect the formatting type used in the text. + Returns the format type or 'none' if no specific format detected. + """ + # Check for various formatting patterns + if text.startswith("**") and text.endswith("**") and len(text) > 4: + if text.startswith("***") and text.endswith("***"): + return "triple_asterisk" + return "bold" + elif text.startswith("*") and text.endswith("*") and len(text) > 2: + return "italic" + elif text.isupper() and len(text) > 1: + return "uppercase" + elif text.startswith("```\n") and text.endswith("\n```"): + return "code_block" + elif text.startswith('"') and text.endswith('"') and len(text) > 2: + return "quotes" + elif text.startswith("[") and text.endswith("]") and len(text) > 2: + return "brackets" + elif text.startswith("(") and text.endswith(")") and len(text) > 2: + return "parentheses" + elif text.startswith("-") and text.endswith("-") and len(text) > 2: + return "dashes" + elif text.startswith("<") and text.endswith(">") and len(text) > 2: + return "angle_brackets" + + return "none" + + +def compute_reward(completion: str, verification_info: Dict, tag_name: str = "extracted_formatted"): """ Generic difflib-based reward computation for tasks expecting extracted content in XML tags. - + Uses normalized text comparison to avoid harsh penalties for formatting differences. + Only looks for XML tags AFTER if exists. + Returns 0 if format specification is not followed exactly. 
+ Args: completion: The model's completion text verification_info: Dictionary containing ground truth tag_name: XML tag name to extract content from - + Returns: Float reward between 0 and 1 """ - # Extract answer from specified tag + # First, check if exists and skip thinking section if it does + search_text = completion + think_end = completion.find("") + if think_end != -1: + # If found, only search AFTER the thinking section + search_text = completion[think_end + len("") :] + + # Extract answer from specified tag in the search text tag_pattern = f"<{tag_name}>(.*?)" - if f"<{tag_name}>" not in completion: + if f"<{tag_name}>" not in search_text: return 0 - - answer_match = re.search(tag_pattern, completion, re.DOTALL) + + answer_match = re.search(tag_pattern, search_text, re.DOTALL) if not answer_match: return 0 + # Check if there's exactly one XML tag with the right name (for bonus) + xml_bonus = 0 + all_tags = re.findall(f"<{tag_name}>.*?", search_text, re.DOTALL) + if len(all_tags) == 1: + xml_bonus = 0.01 + # Get ground truth ground_truth = verification_info.get("ground_truth") if not ground_truth: - return 0 + return xml_bonus # Return bonus if no ground truth to compare against try: - # Clean and split both into lines - answer_lines = answer_match.group(1).strip().split("\n") - truth_lines = ground_truth.strip().split("\n") + # Extract content from both + extracted_text = answer_match.group(1) + ground_truth_text = ground_truth - # Use difflib to compare line sequences - matcher = difflib.SequenceMatcher(None, answer_lines, truth_lines) + # CHECK FORMAT COMPLIANCE FIRST + ground_truth_format = detect_format_type(ground_truth_text) + extracted_format = detect_format_type(extracted_text) + + # If ground truth has a specific format, extracted text must match that format + if ground_truth_format != "none" and ground_truth_format != extracted_format: + return 0 # Format specification not followed - return 0 + + # If format check passes, proceed with difflib 
comparison + # Use difflib on raw text for sequence comparison + matcher = difflib.SequenceMatcher(None, extracted_text, ground_truth_text) # Calculate similarity ratio reward = matcher.ratio() - return reward + # Add the XML tag bonus + reward += xml_bonus + + return min(reward, 1.0) # Cap at 1.0 to prevent bonus from pushing over 1 except Exception: - return 0 \ No newline at end of file + return xml_bonus # Return bonus even if comparison fails From a329bcf4822489510350b7fbdfeca053dc952ede Mon Sep 17 00:00:00 2001 From: kalomaze Date: Thu, 12 Jun 2025 02:35:11 +0000 Subject: [PATCH 06/10] complete reward function definition --- src/zeroband/inference/genesys/formatask.py | 167 ++++++++++---------- 1 file changed, 81 insertions(+), 86 deletions(-) diff --git a/src/zeroband/inference/genesys/formatask.py b/src/zeroband/inference/genesys/formatask.py index 8fce4f1a..cb667de8 100644 --- a/src/zeroband/inference/genesys/formatask.py +++ b/src/zeroband/inference/genesys/formatask.py @@ -3,102 +3,97 @@ from typing import Dict -def detect_format_type(text: str) -> str: - """ - Detect the formatting type used in the text. - Returns the format type or 'none' if no specific format detected. 
- """ - # Check for various formatting patterns - if text.startswith("**") and text.endswith("**") and len(text) > 4: - if text.startswith("***") and text.endswith("***"): - return "triple_asterisk" - return "bold" - elif text.startswith("*") and text.endswith("*") and len(text) > 2: - return "italic" - elif text.isupper() and len(text) > 1: - return "uppercase" - elif text.startswith("```\n") and text.endswith("\n```"): - return "code_block" - elif text.startswith('"') and text.endswith('"') and len(text) > 2: - return "quotes" - elif text.startswith("[") and text.endswith("]") and len(text) > 2: - return "brackets" - elif text.startswith("(") and text.endswith(")") and len(text) > 2: - return "parentheses" - elif text.startswith("-") and text.endswith("-") and len(text) > 2: - return "dashes" - elif text.startswith("<") and text.endswith(">") and len(text) > 2: - return "angle_brackets" - - return "none" - - -def compute_reward(completion: str, verification_info: Dict, tag_name: str = "extracted_formatted"): - """ - Generic difflib-based reward computation for tasks expecting extracted content in XML tags. - Uses normalized text comparison to avoid harsh penalties for formatting differences. - Only looks for XML tags AFTER if exists. - Returns 0 if format specification is not followed exactly. 
- - Args: - completion: The model's completion text - verification_info: Dictionary containing ground truth - tag_name: XML tag name to extract content from - - Returns: - Float reward between 0 and 1 - """ - # First, check if exists and skip thinking section if it does - search_text = completion - think_end = completion.find("") - if think_end != -1: - # If found, only search AFTER the thinking section - search_text = completion[think_end + len("") :] - - # Extract answer from specified tag in the search text - tag_pattern = f"<{tag_name}>(.*?)" - if f"<{tag_name}>" not in search_text: +def detect_format_signature(text: str) -> tuple: + signature = [] + current_text = text + + # Check for markdown emphasis (changed to if statements for layering) + if current_text.startswith("***") and current_text.endswith("***") and len(current_text) > 6: + signature.append("triple_asterisk") + current_text = current_text[3:-3] + elif current_text.startswith("**") and current_text.endswith("**") and len(current_text) > 4: + signature.append("bold") + current_text = current_text[2:-2] + elif current_text.startswith("*") and current_text.endswith("*") and len(current_text) > 2: + signature.append("italic") + current_text = current_text[1:-1] + + # Check for code blocks + if current_text.startswith("```\n") and current_text.endswith("\n```"): + signature.append("code_block") + current_text = current_text[4:-4] + + # Check for wrapper characters (changed to if statements for layering) + if current_text.startswith('"') and current_text.endswith('"') and len(current_text) > 2: + signature.append("quotes") + current_text = current_text[1:-1] + elif current_text.startswith("[") and current_text.endswith("]") and len(current_text) > 2: + signature.append("brackets") + current_text = current_text[1:-1] + elif current_text.startswith("(") and current_text.endswith(")") and len(current_text) > 2: + signature.append("parentheses") + current_text = current_text[1:-1] + elif 
current_text.startswith("-") and current_text.endswith("-") and len(current_text) > 2: + signature.append("dashes") + current_text = current_text[1:-1] + elif current_text.startswith("<") and current_text.endswith(">") and len(current_text) > 2: + signature.append("angle_brackets") + current_text = current_text[1:-1] + + # Check for uppercase (this can layer with other formats) + if current_text.isupper() and len(current_text) > 0: + signature.append("uppercase") + + return tuple(signature) if signature else ("none",) + + +def extract_and_score(text: str, tag_name: str, ground_truth: str) -> float: + pattern = f"<{tag_name}>(.*?)</{tag_name}>" + match = re.search(pattern, text, re.DOTALL) + + if not match: return 0 - answer_match = re.search(tag_pattern, search_text, re.DOTALL) - if not answer_match: - return 0 + extracted = match.group(1) - # Check if there's exactly one XML tag with the right name (for bonus) - xml_bonus = 0 - all_tags = re.findall(f"<{tag_name}>.*?", search_text, re.DOTALL) - if len(all_tags) == 1: - xml_bonus = 0.01 + # Check format compliance + gt_sig = detect_format_signature(ground_truth) + ext_sig = detect_format_signature(extracted) - # Get ground truth - ground_truth = verification_info.get("ground_truth") - if not ground_truth: - return xml_bonus # Return bonus if no ground truth to compare against + # Calculate similarity + similarity = difflib.SequenceMatcher(None, extracted, ground_truth).ratio() - try: - # Extract content from both - extracted_text = answer_match.group(1) - ground_truth_text = ground_truth + # If format is wrong but content is extractable, give 0.1x reward + if gt_sig != ("none",) and gt_sig != ext_sig: + return similarity * 0.1 + + return similarity - # CHECK FORMAT COMPLIANCE FIRST - ground_truth_format = detect_format_type(ground_truth_text) - extracted_format = detect_format_type(extracted_text) - # If ground truth has a specific format, extracted text must match that format - if ground_truth_format != "none" and 
ground_truth_format != extracted_format: - return 0 # Format specification not followed - return 0 +def compute_reward(completion: str, verification_info: Dict, tag_name: str = "extracted_formatted") -> float: + try: + # Skip thinking section + text = completion + think_end = completion.find("") + if think_end != -1: + text = completion[think_end + len("") :] + + # Check for dual case first + if "ground_truth1" in verification_info and "ground_truth2" in verification_info: + score1 = extract_and_score(text, "extracted_formatted1", verification_info["ground_truth1"]) + score2 = extract_and_score(text, "extracted_formatted2", verification_info["ground_truth2"]) - # If format check passes, proceed with difflib comparison - # Use difflib on raw text for sequence comparison - matcher = difflib.SequenceMatcher(None, extracted_text, ground_truth_text) + if score1 == 0 or score2 == 0: + return 0 - # Calculate similarity ratio - reward = matcher.ratio() + return (score1 + score2) / 2.0 - # Add the XML tag bonus - reward += xml_bonus + # Single case + ground_truth = verification_info.get("ground_truth") + if not ground_truth: + return 0 - return min(reward, 1.0) # Cap at 1.0 to prevent bonus from pushing over 1 + return extract_and_score(text, tag_name, ground_truth) except Exception: - return xml_bonus # Return bonus even if comparison fails + return 0 From f1bdcc6a3721ebc5faf41beb401e8fcc95d59b1f Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Thu, 12 Jun 2025 00:47:23 -0500 Subject: [PATCH 07/10] autoclean rollouts #1 Co-authored-by: samsja <55492238+samsja@users.noreply.github.com> --- configs/training/formatask.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/training/formatask.toml b/configs/training/formatask.toml index 3761dc88..45e33001 100644 --- a/configs/training/formatask.toml +++ b/configs/training/formatask.toml @@ -20,4 +20,5 @@ path = "data_rollout" seq_length = 2048 [ckpt] 
-rollout_path = "outputs" \ No newline at end of file +rollout_path = "outputs" +clean_rollout_path = true \ No newline at end of file From c9668973a4ab207d94d10ea87384845ab2cc31af Mon Sep 17 00:00:00 2001 From: kalomaze <66376113+kalomaze@users.noreply.github.com> Date: Thu, 12 Jun 2025 00:47:35 -0500 Subject: [PATCH 08/10] autoclean rollouts #2 Co-authored-by: samsja <55492238+samsja@users.noreply.github.com> --- configs/inference/formatask.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/inference/formatask.toml b/configs/inference/formatask.toml index 730840d1..45d34834 100644 --- a/configs/inference/formatask.toml +++ b/configs/inference/formatask.toml @@ -3,6 +3,7 @@ dataset = "kalomaze/general-formatask-it1-7k" batch_size = 256 dp = 4 rollout_path = "outputs" +clean_output_path = true output_path = "data_rollout" max_model_len = 2048 From 1072f92a1b342257d8ff67486aea93608c50afb4 Mon Sep 17 00:00:00 2001 From: kalomaze Date: Thu, 12 Jun 2025 06:13:22 +0000 Subject: [PATCH 09/10] expanded multiepoch dataset --- configs/inference/formatask.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/inference/formatask.toml b/configs/inference/formatask.toml index 45d34834..7b324e89 100644 --- a/configs/inference/formatask.toml +++ b/configs/inference/formatask.toml @@ -1,5 +1,5 @@ model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" -dataset = "kalomaze/general-formatask-it1-7k" +dataset = "kalomaze/general-formatask-it2" batch_size = 256 dp = 4 rollout_path = "outputs" From 7c2fe3036a24c9914973df485069d55c3a4c5a57 Mon Sep 17 00:00:00 2001 From: kalomaze Date: Thu, 12 Jun 2025 06:21:15 +0000 Subject: [PATCH 10/10] ruff fix --- src/zeroband/inference/genesys/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zeroband/inference/genesys/__init__.py b/src/zeroband/inference/genesys/__init__.py index c741ec36..6aff91d5 100644 --- a/src/zeroband/inference/genesys/__init__.py +++ 
b/src/zeroband/inference/genesys/__init__.py @@ -4,12 +4,12 @@ from zeroband.inference.genesys.code import evaluate_code from zeroband.inference.genesys.code_output_prediction import verify_code_output_prediction from zeroband.inference.genesys.complex_json_output import verify_complex_json_formatting +from zeroband.inference.genesys.formatask import compute_reward as compute_formatask_reward from zeroband.inference.genesys.math import compute_math_reward from zeroband.inference.genesys.pydantic_json_adherance import validate_pydantic_json from zeroband.inference.genesys.reasoning_gym import verify_reasoning_gym from zeroband.inference.genesys.reverse_text import reverse_text from zeroband.inference.genesys.unscramble_sentence import compute_reward as compute_unscramble_reward -from zeroband.inference.genesys.formatask import compute_reward as compute_formatask_reward TaskType = Literal[ "verifiable_math",