Refactor/cache create text dataset (#309)
* Refactor create_text_dataset to cache intermediate results, easing memory reqs

* Bump version to v3.0.6
carlosejimenez authored Feb 3, 2025
1 parent e5c2dc9 commit 53448dd
Showing 3 changed files with 229 additions and 168 deletions.
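For orientation before the diff itself, here is a minimal sketch of how the refactored entry point might be invoked after this change. Only the parameter names and the JSONL progress-file behavior come from the diff below; the dataset, prompt style, and file paths are illustrative assumptions, not part of the commit.

from datasets import load_dataset
from swebench.inference.make_datasets.create_instance import add_text_inputs

# add_text_inputs expects a dict of task instances keyed by instance_id.
dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")  # assumed dataset; any SWE-bench split works
instances = {row["instance_id"]: dict(row) for row in dataset}

add_text_inputs(
    instances,
    retrieval_file=None,        # only consulted when file_source is "bm25"
    k=None,
    prompt_style="style-3",     # assumed key of PROMPT_FUNCTIONS
    file_source="oracle",
    progress_file="text_dataset_progress.jsonl",  # required; one processed instance is appended per line
)

Instead of mutating the input dict in place, the function now streams each processed instance to the progress file, which is how the commit eases memory requirements and makes interrupted runs resumable.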
swebench/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
-__version__ = "3.0.5"
+__version__ = "3.0.6"

from swebench.collect.build_dataset import main as build_dataset
from swebench.collect.get_tasks_pipeline import main as get_tasks_pipeline
swebench/inference/make_datasets/create_instance.py (203 changes: 116 additions & 87 deletions)
@@ -335,101 +335,130 @@ def get_oracle_filenames(instance):


def add_text_inputs(
-    input_instances,
+    instances,
    retrieval_file,
    k,
    prompt_style,
    file_source,
    max_context_len=None,
    tokenizer_name=None,
    verbose=False,
-):
-    """Adds text inputs context for prediction in-place.
+    progress_file=None,
+) -> None:
+    """Process instances and save results to progress file.

    Args:
-    - input_instances: dictionary with unprocessed input instances.
-    - retrieval_file: if using retrieval method for file_contents, specify retrieval_file to add retrieval results
-    - k: if using retrieval, specifies the maximum number of files to included within context
-    - prompt_style: specify the function to generate instructions and prompt provided an instance (from PROMPT_FUNCTIONS)
+    - instances: dictionary with unprocessed input instances
+    - retrieval_file: if using retrieval method for file_contents, specify retrieval_file
+    - k: if using retrieval, specifies the maximum number of files to include
+    - prompt_style: specify the function to generate instructions and prompt
    - file_source: where to collect file_contents (e.g. oracle or bm25)
    - verbose: set ContextManager verbose to True
+    - progress_file: required, path to save processed instances
    """
-    if max_context_len is not None:
-        assert (
-            tokenizer_name is not None
-        ), "Must specify tokenizer_name if using max_context_len"
-        tokenizer, tokenizer_func = TOKENIZER_FUNCS[tokenizer_name]
-    input_instances_copy = deepcopy(input_instances)
-    if file_source in {"bm25"}:
-        add_retrieval_results(input_instances_copy, retrieval_file, k, file_source)
-    orig_dir = os.getcwd()
-    with TemporaryDirectory(
-        dir="/scratch" if os.path.exists("/scratch") else "/tmp"
-    ) as root_dir:
-        for instance_id, instance in tqdm(
-            input_instances_copy.items(),
-            total=len(input_instances_copy),
-            desc="Adding text inputs",
-        ):
-            try:
-                with AutoContextManager(
-                    instance, root_dir, verbose=verbose
-                ) as cm:
-                    readmes = cm.get_readme_files()
-                    instance["readmes"] = ingest_files(readmes)
-                    if max_context_len is not None:
-                        instance["file_contents"] = dict()
-                        base_text_inputs = PROMPT_FUNCTIONS[prompt_style](instance)
-                        base_text_input_length = len(
-                            tokenizer_func(base_text_inputs, tokenizer)
-                        )
-                    if file_source in {"oracle"}:
-                        instance["file_contents"] = ingest_files(
-                            get_oracle_filenames(instance)
-                        )
-                    elif file_source in {"bm25"}:
-                        instance["file_contents"] = ingest_files(
-                            [x["docid"] for x in instance["hits"]]
-                        )
-                    elif file_source in {"all"}:
-                        instance["file_contents"] = ingest_directory_contents(
-                            cm.repo_path
-                        )
-                    elif file_source in {"none"}:
-                        instance["file_contents"] = dict()
-                    else:
-                        raise ValueError(f"Invalid file source {file_source}")
-                    if max_context_len is not None:
-                        cur_input_len = base_text_input_length
-                        include_files = list()
-                        for filename in [x["docid"] for x in instance["hits"]]:
-                            content = make_code_text(
-                                {filename: instance["file_contents"][filename]}
-                            )
-                            if tokenizer_name in {"llama"}:
-                                tokens = tokenizer_func("\n" + content, tokenizer)
-                                idx = tokens.index(13)
-                                assert (
-                                    idx <= 2
-                                ), "Expected newline token id (13) to be one of the first three tokens"
-                                tokens = tokens[idx + 1 :]  # remove newline tokens
-                            else:
-                                tokens = tokenizer_func(content, tokenizer)
-                            if cur_input_len + len(tokens) < max_context_len:
-                                include_files.append(filename)
-                                cur_input_len += len(tokens)
-                        instance["file_contents"] = {
-                            filename: instance["file_contents"][filename]
-                            for filename in include_files
-                        }
-                    input_instances[instance_id]["text_inputs"] = PROMPT_FUNCTIONS[
-                        prompt_style
-                    ](instance)
-            except Exception as e:
-                print(f"Failed on instance {instance_id}", e)
-                traceback.print_exc()
-                input_instances[instance_id]["text_inputs"] = None
-            finally:
-                # if AutoContextManager fails to exit properly future exits will return the wrong directory
-                os.chdir(orig_dir)
-    os.chdir(orig_dir)
+    assert progress_file is not None, "progress_file is required"
+
+    # Create progress file directory if it doesn't exist
+    progress_path = Path(progress_file)
+    progress_path.parent.mkdir(parents=True, exist_ok=True)
+
+    # Load already processed instances
+    processed_ids = set()
+    file_exists = os.path.exists(progress_file)
+
+    if file_exists:
+        with open(progress_file) as f:
+            for line in f:
+                instance = json.loads(line)
+                processed_ids.add(instance['instance_id'])
+        logger.info(f"Found {len(processed_ids)} already processed instances")
+        progress_file_handle = open(progress_file, 'a')
+    else:
+        progress_file_handle = open(progress_file, 'w')
+
+    try:
+        if max_context_len is not None:
+            assert tokenizer_name is not None, "Must specify tokenizer_name if using max_context_len"
+            tokenizer, tokenizer_func = TOKENIZER_FUNCS[tokenizer_name]
+
+        # Add retrieval results if needed
+        if file_source in {"bm25"}:
+            instances = deepcopy(instances)
+            add_retrieval_results(instances, retrieval_file, k, file_source)
+
+        # Filter out already processed instances
+        instances_to_process = {k: v for k, v in instances.items() if k not in processed_ids}
+        logger.info(f"Processing {len(instances_to_process)} instances")
+
+        orig_dir = os.getcwd()
+        with TemporaryDirectory(dir="/scratch" if os.path.exists("/scratch") else "/tmp") as root_dir:
+            for instance_id, instance in tqdm(
+                instances_to_process.items(),
+                total=len(instances_to_process),
+                desc="Processing instances"
+            ):
+                try:
+                    with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
+                        # Process instance
+                        processed_instance = deepcopy(instance)
+
+                        # Add readmes
+                        readmes = cm.get_readme_files()
+                        processed_instance["readmes"] = ingest_files(readmes)
+
+                        # Handle file contents based on configuration
+                        if max_context_len is not None:
+                            processed_instance["file_contents"] = dict()
+                            base_text_inputs = PROMPT_FUNCTIONS[prompt_style](processed_instance)
+                            base_text_input_length = len(tokenizer_func(base_text_inputs, tokenizer))
+
+                        if file_source == "oracle":
+                            processed_instance["file_contents"] = ingest_files(get_oracle_filenames(processed_instance))
+                        elif file_source == "bm25":
+                            processed_instance["file_contents"] = ingest_files([x["docid"] for x in processed_instance["hits"]])
+                        elif file_source == "all":
+                            processed_instance["file_contents"] = ingest_directory_contents(cm.repo_path)
+                        elif file_source == "none":
+                            processed_instance["file_contents"] = dict()
+                        else:
+                            raise ValueError(f"Invalid file source {file_source}")
+
+                        # Handle context length limits
+                        if max_context_len is not None:
+                            cur_input_len = base_text_input_length
+                            include_files = []
+                            for filename in [x["docid"] for x in processed_instance["hits"]]:
+                                content = make_code_text({filename: processed_instance["file_contents"][filename]})
+                                if tokenizer_name == "llama":
+                                    tokens = tokenizer_func("\n" + content, tokenizer)
+                                    idx = tokens.index(13)
+                                    tokens = tokens[idx + 1:]
+                                else:
+                                    tokens = tokenizer_func(content, tokenizer)
+                                if cur_input_len + len(tokens) < max_context_len:
+                                    include_files.append(filename)
+                                    cur_input_len += len(tokens)
+                            processed_instance["file_contents"] = {
+                                filename: processed_instance["file_contents"][filename]
+                                for filename in include_files
+                            }
+
+                        # Generate final text inputs
+                        processed_instance["text_inputs"] = PROMPT_FUNCTIONS[prompt_style](processed_instance)
+
+                        # Save to progress file
+                        progress_file_handle.write(json.dumps(processed_instance) + '\n')
+                        progress_file_handle.flush()
+
+                except Exception as e:
+                    print(f"Failed on instance {instance_id}", e)
+                    traceback.print_exc()
+                    # Save failed instance
+                    failed_instance = {**instance, 'text_inputs': None}
+                    progress_file_handle.write(json.dumps(failed_instance) + '\n')
+                    progress_file_handle.flush()
+                finally:
+                    os.chdir(orig_dir)
+        os.chdir(orig_dir)
+    finally:
+        progress_file_handle.close()
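Since processed instances now live on disk as JSON lines rather than in the caller's dict, a downstream consumer (for example the create_text_dataset script this commit refactors) only needs to read the progress file back. Below is a minimal sketch of that read-back, under the assumption that the consumer just wants a dict keyed by instance_id; the helper name and file path are hypothetical.

import json

def load_progress(progress_file: str) -> dict:
    """Rebuild the instance_id -> processed-instance mapping from the JSONL progress file."""
    processed = {}
    with open(progress_file) as f:
        for line in f:
            instance = json.loads(line)
            processed[instance["instance_id"]] = instance
    return processed

instances = load_progress("text_dataset_progress.jsonl")
# Failed instances were written with text_inputs set to None, so they can be filtered out or inspected.
failed_ids = [iid for iid, inst in instances.items() if inst.get("text_inputs") is None]

Re-running add_text_inputs with the same progress_file skips every instance_id already present in the file and appends only new results, which is what lets an interrupted run pick up where it left off.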
