From 53448dd9e6c8868e6349b06dca29da80ae6b7187 Mon Sep 17 00:00:00 2001 From: carlos Date: Sun, 2 Feb 2025 19:28:02 -0500 Subject: [PATCH] Refactor/cache create text dataset (#309) * Refactor create_text_dataset to cache intermediate results, easing memory reqs * Bump version to v3.0.6 --- swebench/__init__.py | 2 +- .../make_datasets/create_instance.py | 203 ++++++++++-------- .../make_datasets/create_text_dataset.py | 192 ++++++++++------- 3 files changed, 229 insertions(+), 168 deletions(-) diff --git a/swebench/__init__.py b/swebench/__init__.py index 1ed61a0c..595f4a53 100644 --- a/swebench/__init__.py +++ b/swebench/__init__.py @@ -1,4 +1,4 @@ -__version__ = "3.0.5" +__version__ = "3.0.6" from swebench.collect.build_dataset import main as build_dataset from swebench.collect.get_tasks_pipeline import main as get_tasks_pipeline diff --git a/swebench/inference/make_datasets/create_instance.py b/swebench/inference/make_datasets/create_instance.py index b5109481..130cc5c6 100644 --- a/swebench/inference/make_datasets/create_instance.py +++ b/swebench/inference/make_datasets/create_instance.py @@ -335,7 +335,7 @@ def get_oracle_filenames(instance): def add_text_inputs( - input_instances, + instances, retrieval_file, k, prompt_style, @@ -343,93 +343,122 @@ def add_text_inputs( max_context_len=None, tokenizer_name=None, verbose=False, -): - """Adds text inputs context for prediction in-place. - + progress_file=None, +) -> None: + """Process instances and save results to progress file. + Args: - - input_instances: dictionary with unprocessed input instances. - - retrieval_file: if using retrieval method for file_contents, specify retrieval_file to add retrieval results - - k: if using retrieval, specifies the maximum number of files to included within context - - prompt_style: specify the function to generate instructions and prompt provided an instance (from PROMPT_FUNCTIONS) + - instances: dictionary with unprocessed input instances + - retrieval_file: if using retrieval method for file_contents, specify retrieval_file + - k: if using retrieval, specifies the maximum number of files to include + - prompt_style: specify the function to generate instructions and prompt - file_source: where to collect file_contents (e.g. 
oracle or bm25) - verbose: set ContextManager verbose to True + - progress_file: required, path to save processed instances """ - if max_context_len is not None: - assert ( - tokenizer_name is not None - ), "Must specify tokenizer_name if using max_context_len" - tokenizer, tokenizer_func = TOKENIZER_FUNCS[tokenizer_name] - input_instances_copy = deepcopy(input_instances) - if file_source in {"bm25"}: - add_retrieval_results(input_instances_copy, retrieval_file, k, file_source) - orig_dir = os.getcwd() - with TemporaryDirectory( - dir="/scratch" if os.path.exists("/scratch") else "/tmp" - ) as root_dir: - for instance_id, instance in tqdm( - input_instances_copy.items(), - total=len(input_instances_copy), - desc="Adding text inputs", - ): - try: - with AutoContextManager( - instance, root_dir, verbose=verbose - ) as cm: - readmes = cm.get_readme_files() - instance["readmes"] = ingest_files(readmes) - if max_context_len is not None: - instance["file_contents"] = dict() - base_text_inputs = PROMPT_FUNCTIONS[prompt_style](instance) - base_text_input_length = len( - tokenizer_func(base_text_inputs, tokenizer) - ) - if file_source in {"oracle"}: - instance["file_contents"] = ingest_files( - get_oracle_filenames(instance) - ) - elif file_source in {"bm25"}: - instance["file_contents"] = ingest_files( - [x["docid"] for x in instance["hits"]] - ) - elif file_source in {"all"}: - instance["file_contents"] = ingest_directory_contents( - cm.repo_path - ) - elif file_source in {"none"}: - instance["file_contents"] = dict() - else: - raise ValueError(f"Invalid file source {file_source}") - if max_context_len is not None: - cur_input_len = base_text_input_length - include_files = list() - for filename in [x["docid"] for x in instance["hits"]]: - content = make_code_text( - {filename: instance["file_contents"][filename]} - ) - if tokenizer_name in {"llama"}: - tokens = tokenizer_func("\n" + content, tokenizer) - idx = tokens.index(13) - assert ( - idx <= 2 - ), "Expected newline token id (13) to be one of the first three tokens" - tokens = tokens[idx + 1 :] # remove newline tokens - else: - tokens = tokenizer_func(content, tokenizer) - if cur_input_len + len(tokens) < max_context_len: - include_files.append(filename) - cur_input_len += len(tokens) - instance["file_contents"] = { - filename: instance["file_contents"][filename] - for filename in include_files - } - input_instances[instance_id]["text_inputs"] = PROMPT_FUNCTIONS[ - prompt_style - ](instance) - except Exception as e: - print(f"Failed on instance {instance_id}", e) - traceback.print_exc() - input_instances[instance_id]["text_inputs"] = None - finally: - # if AutoContextManager fails to exit properly future exits will return the wrong directory - os.chdir(orig_dir) - os.chdir(orig_dir) + assert progress_file is not None, "progress_file is required" + + # Create progress file directory if it doesn't exist + progress_path = Path(progress_file) + progress_path.parent.mkdir(parents=True, exist_ok=True) + + # Load already processed instances + processed_ids = set() + file_exists = os.path.exists(progress_file) + + if file_exists: + with open(progress_file) as f: + for line in f: + instance = json.loads(line) + processed_ids.add(instance['instance_id']) + logger.info(f"Found {len(processed_ids)} already processed instances") + progress_file_handle = open(progress_file, 'a') + else: + progress_file_handle = open(progress_file, 'w') + + try: + if max_context_len is not None: + assert tokenizer_name is not None, "Must specify tokenizer_name if using 
max_context_len" + tokenizer, tokenizer_func = TOKENIZER_FUNCS[tokenizer_name] + + # Add retrieval results if needed + if file_source in {"bm25"}: + instances = deepcopy(instances) + add_retrieval_results(instances, retrieval_file, k, file_source) + + # Filter out already processed instances + instances_to_process = {k: v for k, v in instances.items() if k not in processed_ids} + logger.info(f"Processing {len(instances_to_process)} instances") + + orig_dir = os.getcwd() + with TemporaryDirectory(dir="/scratch" if os.path.exists("/scratch") else "/tmp") as root_dir: + for instance_id, instance in tqdm( + instances_to_process.items(), + total=len(instances_to_process), + desc="Processing instances" + ): + try: + with AutoContextManager(instance, root_dir, verbose=verbose) as cm: + # Process instance + processed_instance = deepcopy(instance) + + # Add readmes + readmes = cm.get_readme_files() + processed_instance["readmes"] = ingest_files(readmes) + + # Handle file contents based on configuration + if max_context_len is not None: + processed_instance["file_contents"] = dict() + base_text_inputs = PROMPT_FUNCTIONS[prompt_style](processed_instance) + base_text_input_length = len(tokenizer_func(base_text_inputs, tokenizer)) + + if file_source == "oracle": + processed_instance["file_contents"] = ingest_files(get_oracle_filenames(processed_instance)) + elif file_source == "bm25": + processed_instance["file_contents"] = ingest_files([x["docid"] for x in processed_instance["hits"]]) + elif file_source == "all": + processed_instance["file_contents"] = ingest_directory_contents(cm.repo_path) + elif file_source == "none": + processed_instance["file_contents"] = dict() + else: + raise ValueError(f"Invalid file source {file_source}") + + # Handle context length limits + if max_context_len is not None: + cur_input_len = base_text_input_length + include_files = [] + for filename in [x["docid"] for x in processed_instance["hits"]]: + content = make_code_text({filename: processed_instance["file_contents"][filename]}) + if tokenizer_name == "llama": + tokens = tokenizer_func("\n" + content, tokenizer) + idx = tokens.index(13) + tokens = tokens[idx + 1:] + else: + tokens = tokenizer_func(content, tokenizer) + if cur_input_len + len(tokens) < max_context_len: + include_files.append(filename) + cur_input_len += len(tokens) + processed_instance["file_contents"] = { + filename: processed_instance["file_contents"][filename] + for filename in include_files + } + + # Generate final text inputs + processed_instance["text_inputs"] = PROMPT_FUNCTIONS[prompt_style](processed_instance) + + # Save to progress file + progress_file_handle.write(json.dumps(processed_instance) + '\n') + progress_file_handle.flush() + + except Exception as e: + print(f"Failed on instance {instance_id}", e) + traceback.print_exc() + # Save failed instance + failed_instance = {**instance, 'text_inputs': None} + progress_file_handle.write(json.dumps(failed_instance) + '\n') + progress_file_handle.flush() + finally: + os.chdir(orig_dir) + os.chdir(orig_dir) + finally: + progress_file_handle.close() diff --git a/swebench/inference/make_datasets/create_text_dataset.py b/swebench/inference/make_datasets/create_text_dataset.py index a49c48ea..2ccd8d16 100755 --- a/swebench/inference/make_datasets/create_text_dataset.py +++ b/swebench/inference/make_datasets/create_text_dataset.py @@ -66,6 +66,37 @@ def extract_fields(instance): return {**instance, "text": text_inputs, "patch": patch} +def validate_arguments(push_to_hub_user, output_dir, 
max_context_len, tokenizer_name, file_source, k): + """Validate command line arguments and environment setup.""" + if push_to_hub_user is not None: + hub_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", None) + assert hub_token is not None, "Must provide HUGGING_FACE_HUB_TOKEN to push to the Hub" + assert output_dir is None, "Cannot provide output_dir if pushing to the Hub" + if max_context_len is not None: + assert tokenizer_name is not None + if push_to_hub_user is None and not Path(output_dir).exists(): + Path(output_dir).mkdir(parents=True) + if max_context_len is not None: + assert file_source not in {"all", "oracle"}, "Cannot use max_context_len with oracle or all file sources" + assert tokenizer_name is not None, "Must provide tokenizer_name if max_context_len is not None" + if k is not None: + assert file_source not in {"all", "oracle"}, "Cannot use max_context_len with oracle or all file sources" + return hub_token if push_to_hub_user is not None else None + + +def construct_output_filename(dataset_name, prompt_style, file_source, k, max_context_len, tokenizer_name): + """Construct the output filename based on parameters.""" + if dataset_name.startswith("princeton-nlp"): + dataset_name = dataset_name.split("/")[-1] + dataset_name = dataset_name.replace("/", "__") + output_file = f"{dataset_name}__{prompt_style}__fs-{file_source}" + if k is not None: + output_file += f"__k-{k}" + if max_context_len is not None: + output_file += f"__mcc-{max_context_len}-{tokenizer_name}" + return output_file + + def main( dataset_name_or_path, splits, @@ -79,99 +110,100 @@ def main( tokenizer_name, push_to_hub_user, ): - if push_to_hub_user is not None: - hub_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", None) - assert hub_token is not None, "Must provide HUGGING_FACE_HUB_TOKEN to push to the Hub" - assert output_dir is None, "Cannot provide output_dir if pushing to the Hub" - if max_context_len is not None: - assert tokenizer_name is not None - if push_to_hub_user is None and not Path(output_dir).exists(): - Path(output_dir).mkdir(parents=True) - output_file = f"SWE-bench__{prompt_style}__fs-{file_source}" - if k is not None: - assert file_source not in { - "all", - "oracle", - }, "Cannot use max_context_len with oracle or all file sources" - output_file += f"__k-{k}" - if max_context_len is not None: - assert file_source not in { - "all", - "oracle", - }, "Cannot use max_context_len with oracle or all file sources" - assert ( - tokenizer_name is not None - ), "Must provide tokenizer_name if max_context_len is not None" - output_file += f"__mcc-{max_context_len}-{tokenizer_name}" + # Validate arguments and setup + hub_token = validate_arguments(push_to_hub_user, output_dir, max_context_len, tokenizer_name, file_source, k) + output_file = construct_output_filename(dataset_name_or_path, prompt_style, file_source, k, max_context_len, tokenizer_name) + output_file = Path(output_dir, output_file) if push_to_hub_user is None: - output_file = Path(output_dir, output_file) if output_file.exists(): - logger.info(f"{output_file.absolute().as_posix()} already exists. Aborting") - return - output_file = str(output_file) - if Path(dataset_name_or_path).exists(): - dataset = load_from_disk(dataset_name_or_path) - else: - dataset = load_dataset(dataset_name_or_path) + existing_dataset = load_from_disk(output_file) + # if requested splits are in existing dataset, abort + for split in splits: + if split in existing_dataset: + logger.info(f"{output_file.absolute().as_posix()} already exists for split {split}. 
Aborting") + return + del existing_dataset # don't store in memory - split_instances = dict() + # Load dataset + dataset = load_from_disk(dataset_name_or_path) if Path(dataset_name_or_path).exists() else load_dataset(dataset_name_or_path) logger.info(f'Found {set(dataset.keys())} splits') if set(splits) - set(dataset.keys()) != set(): raise ValueError(f"Unknown splits {set(splits) - set(dataset.keys())}") + + # Define columns for final dataset + columns = [ + "instance_id", "text", "repo", "base_commit", "problem_statement", + "hints_text", "created_at", "patch", "test_patch", "version", + "FAIL_TO_PASS", "PASS_TO_PASS", "environment_setup_commit", + ] + + # Process each split + split_data = {} + progress_files = {} for split in splits: - split_instances[split] = {x["instance_id"]: x for x in dataset[split]} + logger.info(f"Processing {split} split") + split_instances = {x["instance_id"]: x for x in dataset[split]} + progress_file = f"{output_file}.{split}.progress.jsonl" + progress_files[split] = progress_file + # Process instances and save to progress file add_text_inputs( - split_instances[split], - retrieval_file, - k, - prompt_style, - file_source, + split_instances, + retrieval_file=retrieval_file, + k=k, + prompt_style=prompt_style, + file_source=file_source, max_context_len=max_context_len, tokenizer_name=tokenizer_name, + progress_file=progress_file ) - columns = [ - "instance_id", - "text", - "repo", - "base_commit", - "problem_statement", - "hints_text", - "created_at", - "patch", - "test_patch", - "version", - "FAIL_TO_PASS", - "PASS_TO_PASS", - "environment_setup_commit", - ] - split_data = dict() - for split in split_instances: - split_data[split] = {key: list() for key in columns} - for instance in tqdm( - split_instances[split].values(), total=len(split_instances[split]), desc=f'Processing {split} instances', - ): - datum = extract_fields(instance) - if datum is None: - continue - for key in columns: - split_data[split][key].append(datum[key] if key in datum else "") - logger.info(f"Found {len(split_data[split]['instance_id'])} {split} ids") - split_data[split] = Dataset.from_dict(split_data[split]) - dataset = DatasetDict(split_data) - if validation_ratio > 0 and "train" in dataset: - train_val = dataset["train"].train_test_split( - test_size=validation_ratio, - seed=42, - ) - dataset["train"] = train_val["train"] - dataset["validation"] = train_val["test"] - for split in dataset: - logger.info(f"Found {len(dataset[split])} {split} instances") + + logger.info("Creating final dataset") + # Create final dataset + if output_file.exists(): + final_dataset = load_from_disk(output_file) + else: + final_dataset = DatasetDict() + for split in splits: + split_data = {key: [] for key in columns} + valid_instance_ids = set(dataset[split]["instance_id"]) + invalid_instances = [] + + with open(progress_files[split]) as f: + for line in f: + datum = json.loads(line) + if datum["instance_id"] not in valid_instance_ids: + invalid_instances.append(datum["instance_id"]) + continue + for key in columns: + split_data[key].append(datum.get(key, "")) + + if invalid_instances: + logger.warning(f"Found {len(invalid_instances)} instances in progress file that are not in the {split} dataset: {invalid_instances}. 
These will be removed from the final dataset.") + + final_dataset[split] = Dataset.from_dict(split_data) + + # Handle validation split + if validation_ratio > 0 and "train" in final_dataset: + train_val = final_dataset["train"].train_test_split(test_size=validation_ratio, seed=42) + final_dataset["train"] = train_val["train"] + final_dataset["validation"] = train_val["test"] + + # Log final dataset sizes + for split in final_dataset: + logger.info(f"Found {len(final_dataset[split])} {split} instances") + + # Save dataset if push_to_hub_user is not None: - dataset.push_to_hub(f'{push_to_hub_user}/{output_file}', use_auth_token=hub_token) + final_dataset.push_to_hub(f'{push_to_hub_user}/{output_file.name}', use_auth_token=hub_token) else: - dataset.save_to_disk(output_file) - logger.info(f"Finsihed saving to {output_file}") + final_dataset.save_to_disk(output_file) + + # Cleanup progress files + for progress_file in progress_files.values(): + if os.path.exists(progress_file): + os.remove(progress_file) + + logger.info(f"Finished saving to {output_file}") if __name__ == "__main__":
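
Note for reviewers (not part of the patch): below is a minimal, illustrative sketch of the caching pattern this change introduces in add_text_inputs — processed instances are appended to a JSONL progress file, and instance IDs already present in that file are skipped on resume. The helper name process_one and the resumable_process wrapper are hypothetical placeholders, not code from the patch.

    import json
    from pathlib import Path

    def resumable_process(instances: dict, progress_file: str, process_one) -> None:
        """Append each processed instance to a JSONL progress file, skipping
        instance IDs already recorded so an interrupted run can resume.
        process_one is a hypothetical callback returning a JSON-serializable dict
        that includes an "instance_id" key."""
        path = Path(progress_file)
        path.parent.mkdir(parents=True, exist_ok=True)
        done = set()
        if path.exists():
            with path.open() as f:
                done = {json.loads(line)["instance_id"] for line in f if line.strip()}
        with path.open("a") as out:
            for instance_id, instance in instances.items():
                if instance_id in done:
                    continue  # already cached from a previous run
                record = process_one(instance)
                out.write(json.dumps(record) + "\n")
                out.flush()  # keep the cache durable if the run dies mid-way

Writing and flushing one record at a time keeps memory flat regardless of dataset size and makes reruns cheap, which is the same trade-off add_text_inputs now makes by streaming results to progress_file instead of holding every processed instance in memory.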