diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index 15500d3cc54..db10c8fc6b5 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -9,6 +9,7 @@ from .cache import cache from .csv import csv from .eval import eval +from .genbank import genbank from .hdf5 import hdf5 from .imagefolder import imagefolder from .json import json @@ -54,6 +55,7 @@ def _hash_python_lines(lines: list[str]) -> str: "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())), "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), "eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())), + "genbank": (genbank.__name__, _hash_python_lines(inspect.getsource(genbank).splitlines())), "lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())), } @@ -87,6 +89,9 @@ def _hash_python_lines(lines: list[str]) -> str: ".hdf5": ("hdf5", {}), ".h5": ("hdf5", {}), ".eval": ("eval", {}), + ".gb": ("genbank", {}), + ".gbk": ("genbank", {}), + ".genbank": ("genbank", {}), ".lance": ("lance", {}), } _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS}) diff --git a/src/datasets/packaged_modules/genbank/__init__.py b/src/datasets/packaged_modules/genbank/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/genbank/genbank.py b/src/datasets/packaged_modules/genbank/genbank.py new file mode 100644 index 00000000000..e8337d85dc4 --- /dev/null +++ b/src/datasets/packaged_modules/genbank/genbank.py @@ -0,0 +1,480 @@ +"""GenBank file loader for biological sequence data with annotations. + +GenBank is a text-based format for storing nucleotide or protein sequences together with +their annotations and metadata, widely used in bioinformatics and maintained by NCBI. 
+ +This implementation uses a lightweight pure Python state machine parser, +requiring zero external dependencies. +""" + +import bz2 +import gzip +import itertools +import json +import lzma +import re +from dataclasses import dataclass +from typing import Optional + +import pyarrow as pa + +import datasets +from datasets.builder import Key +from datasets.features.features import require_storage_cast +from datasets.table import table_cast + + +logger = datasets.utils.logging.get_logger(__name__) + + +# Conservative limit to stay well under Parquet's i32::MAX page limit (~2GB) +# Using 256MB as default since Parquet compresses data and we want headroom +DEFAULT_MAX_BATCH_BYTES = 256 * 1024 * 1024 # 256 MB + + +# Parser states for the GenBank state machine +class ParserState: + HEADER = "HEADER" + FEATURES = "FEATURES" + ORIGIN = "ORIGIN" + COMPLETE = "COMPLETE" + + +@dataclass +class GenBankConfig(datasets.BuilderConfig): + """BuilderConfig for GenBank files. + + Args: + features: Dataset features (optional, will be inferred if not provided). + batch_size: Maximum number of records per batch. Works in conjunction with + max_batch_bytes - a batch is flushed when either limit is reached. + max_batch_bytes: Maximum cumulative bytes per batch. This prevents Parquet + page size errors when dealing with very large sequences. Set to None + to disable byte-based batching. + columns: Subset of columns to include. Options: ["locus_name", "accession", + "version", "definition", "organism", "taxonomy", "keywords", "sequence", + "features", "length", "molecule_type"]. + parse_features: Whether to parse FEATURES section into structured JSON. + If False, stores raw text. 
+ """ + + features: Optional[datasets.Features] = None + batch_size: int = 10000 + max_batch_bytes: Optional[int] = DEFAULT_MAX_BATCH_BYTES + columns: Optional[list[str]] = None + parse_features: bool = True + + def __post_init__(self): + super().__post_init__() + + +class GenBank(datasets.ArrowBasedBuilder): + """Dataset builder for GenBank files.""" + + BUILDER_CONFIG_CLASS = GenBankConfig + + # All supported GenBank extensions + EXTENSIONS: list[str] = [".gb", ".gbk", ".genbank"] + + # All available columns + ALL_COLUMNS: list[str] = [ + "locus_name", + "accession", + "version", + "definition", + "organism", + "taxonomy", + "keywords", + "sequence", + "features", + "length", + "molecule_type", + ] + + def _info(self): + return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + """Generate splits from data files. + + The `data_files` kwarg in load_dataset() can be a str, List[str], + Dict[str,str], or Dict[str,List[str]]. + + If str or List[str], then the dataset returns only the 'train' split. + If dict, then keys should be from the `datasets.Split` enum. 
+ """ + if not self.config.data_files: + raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}") + dl_manager.download_config.extract_on_the_fly = True + data_files = dl_manager.download_and_extract(self.config.data_files) + splits = [] + for split_name, files in data_files.items(): + if isinstance(files, str): + files = [files] + files = [dl_manager.iter_files(file) for file in files] + splits.append(datasets.SplitGenerator(name=split_name, gen_kwargs={"files": files})) + return splits + + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + """Cast Arrow table to configured features schema.""" + if self.config.features is not None: + schema = self.config.features.arrow_schema + if all(not require_storage_cast(feature) for feature in self.config.features.values()): + pa_table = pa_table.cast(schema) + else: + pa_table = table_cast(pa_table, schema) + return pa_table + return pa_table + + def _open_file(self, filepath: str): + """Open file with automatic compression detection based on magic bytes. + + Supports gzip, bzip2, and xz/lzma compression formats. + """ + with open(filepath, "rb") as f: + magic = f.read(6) + + if magic[:2] == b"\x1f\x8b": # gzip magic number + return gzip.open(filepath, "rt", encoding="utf-8") + elif magic[:3] == b"BZh": # bzip2 magic number + return bz2.open(filepath, "rt", encoding="utf-8") + elif magic[:6] == b"\xfd7zXZ\x00": # xz magic number + return lzma.open(filepath, "rt", encoding="utf-8") + else: + return open(filepath, "r", encoding="utf-8") + + def _parse_feature_location(self, location_str: str) -> dict: + """Parse a GenBank feature location string into a structured dict. 
+ + Examples: + "100..200" -> {"start": 100, "end": 200, "strand": 1} + "complement(100..200)" -> {"start": 100, "end": 200, "strand": -1} + "join(1..100,200..300)" -> {"start": 1, "end": 300, "strand": 1, "parts": [[1,100],[200,300]]} + """ + location = {"strand": 1} + + # Check for complement + if location_str.startswith("complement("): + location["strand"] = -1 + location_str = location_str[11:-1] # Remove "complement(" and ")" + + # Check for join + if location_str.startswith("join("): + location_str = location_str[5:-1] # Remove "join(" and ")" + parts = [] + for part in location_str.split(","): + part = part.strip() + if ".." in part: + start, end = part.split("..") + # Handle < and > symbols for partial sequences + start = int(start.lstrip("<>")) + end = int(end.lstrip("<>")) + parts.append([start, end]) + if parts: + location["parts"] = parts + location["start"] = parts[0][0] + location["end"] = parts[-1][1] + return location + + # Simple location + if ".." in location_str: + start, end = location_str.split("..") + location["start"] = int(start.lstrip("<>")) + location["end"] = int(end.lstrip("<>")) + elif location_str.isdigit(): + location["start"] = int(location_str) + location["end"] = int(location_str) + + return location + + def _parse_genbank(self, fp): + """State machine parser for GenBank format. + + GenBank format has several sections: + - LOCUS: Contains name, length, molecule type, etc. + - DEFINITION: Description of the sequence + - ACCESSION: Database accession number + - VERSION: Version with GI number + - KEYWORDS: Associated keywords + - SOURCE/ORGANISM: Organism information and taxonomy + - FEATURES: Detailed annotations + - ORIGIN: The actual sequence data + - // : Record terminator + + Args: + fp: File-like object opened in text mode. + + Yields: + Dict with parsed record fields for each GenBank record. 
+ """ + state = ParserState.HEADER + record = self._new_record() + current_feature = None + current_qualifier_key = None + current_qualifier_value = [] + features_list = [] + + for line in fp: + # Record terminator + if line.startswith("//"): + # Finalize any pending feature + if current_feature is not None: + if current_qualifier_key is not None: + current_feature["qualifiers"][current_qualifier_key] = "".join(current_qualifier_value) + features_list.append(current_feature) + + # Store features + if self.config.parse_features: + record["features"] = json.dumps(features_list) + else: + record["features"] = "" + + yield record + + # Reset for next record + state = ParserState.HEADER + record = self._new_record() + current_feature = None + current_qualifier_key = None + current_qualifier_value = [] + features_list = [] + continue + + # State: HEADER - Parse metadata fields + if state == ParserState.HEADER: + if line.startswith("LOCUS"): + self._parse_locus_line(line, record) + elif line.startswith("DEFINITION"): + record["definition"] = line[12:].strip() + elif line.startswith("ACCESSION"): + record["accession"] = line[12:].strip().split()[0] + elif line.startswith("VERSION"): + record["version"] = line[12:].strip() + elif line.startswith("KEYWORDS"): + keywords = line[12:].strip() + if keywords != ".": + record["keywords"] = keywords + elif line.startswith("SOURCE"): + pass # SOURCE line itself is less useful than ORGANISM + elif line.startswith(" ORGANISM"): + record["organism"] = line[12:].strip() + elif line.startswith(" ") and record["organism"]: + # Continuation of taxonomy + taxonomy_part = line.strip() + if taxonomy_part and not taxonomy_part.endswith("."): + taxonomy_part += ";" + if record["taxonomy"]: + record["taxonomy"] += " " + taxonomy_part + else: + record["taxonomy"] = taxonomy_part + elif line.startswith("FEATURES"): + state = ParserState.FEATURES + elif line.startswith("ORIGIN"): + state = ParserState.ORIGIN + + # State: FEATURES - Parse feature 
annotations + elif state == ParserState.FEATURES: + if line.startswith("ORIGIN"): + # Finalize any pending feature before transitioning + if current_feature is not None: + if current_qualifier_key is not None: + current_feature["qualifiers"][current_qualifier_key] = "".join(current_qualifier_value) + features_list.append(current_feature) + current_feature = None + current_qualifier_key = None + current_qualifier_value = [] + state = ParserState.ORIGIN + continue + + # Feature line starts at column 5 with feature type + if len(line) > 5 and line[5] != " " and not line.startswith("FEATURES"): + # Finalize previous feature + if current_feature is not None: + if current_qualifier_key is not None: + current_feature["qualifiers"][current_qualifier_key] = "".join(current_qualifier_value) + features_list.append(current_feature) + + # Parse new feature + parts = line[5:].split() + if len(parts) >= 2: + feature_type = parts[0] + location_str = parts[1] + current_feature = { + "type": feature_type, + "location": self._parse_feature_location(location_str), + "qualifiers": {}, + } + current_qualifier_key = None + current_qualifier_value = [] + + # Qualifier line starts at column 21 + elif len(line) > 21 and line[21] == "/": + # Save previous qualifier + if current_qualifier_key is not None and current_feature is not None: + current_feature["qualifiers"][current_qualifier_key] = "".join(current_qualifier_value) + + # Parse new qualifier + qualifier_line = line[21:].strip() + if "=" in qualifier_line: + key, value = qualifier_line.split("=", 1) + current_qualifier_key = key[1:] # Remove leading / + # Remove surrounding quotes if present + value = value.strip('"') + current_qualifier_value = [value] + else: + # Boolean qualifier like /pseudo + current_qualifier_key = qualifier_line[1:] + current_qualifier_value = ["true"] + + # Continuation line for qualifier + elif len(line) > 21 and line[21] != "/" and current_qualifier_key is not None: + continuation = 
line[21:].strip().strip('"') + current_qualifier_value.append(continuation) + + # State: ORIGIN - Parse sequence data + elif state == ParserState.ORIGIN: + if line.startswith("//"): + continue # Will be handled at top of loop + + # Sequence lines have format: " 123 atcgatcg atcgatcg ..." + # Remove numbers and spaces, keep only sequence characters + seq_chars = re.sub(r"[\s\d]", "", line) + if seq_chars: + record["sequence"] += seq_chars.upper() + + def _new_record(self) -> dict: + """Create a new empty record with default values.""" + return { + "locus_name": "", + "accession": "", + "version": "", + "definition": "", + "organism": "", + "taxonomy": "", + "keywords": "", + "sequence": "", + "features": "", + "length": 0, + "molecule_type": "", + } + + def _parse_locus_line(self, line: str, record: dict) -> None: + """Parse the LOCUS line which contains key metadata. + + LOCUS format (fixed width columns): + LOCUS name length bp type topology division date + + Example: + LOCUS SCU49845 5028 bp DNA PLN 21-JUN-1999 + """ + # Split by whitespace and extract fields + parts = line.split() + if len(parts) >= 2: + record["locus_name"] = parts[1] + + # Find length (number followed by 'bp' or 'aa') + for i, part in enumerate(parts): + if part in ("bp", "aa") and i > 0: + try: + record["length"] = int(parts[i - 1]) + except ValueError: + pass + break + + # Find molecule type (DNA, RNA, mRNA, etc.) + molecule_types = {"DNA", "RNA", "mRNA", "rRNA", "tRNA", "protein", "AA"} + for part in parts: + if part in molecule_types: + record["molecule_type"] = part + break + + def _get_columns(self) -> list[str]: + """Get the list of columns to include in output.""" + if self.config.columns is not None: + # Validate columns + for col in self.config.columns: + if col not in self.ALL_COLUMNS: + raise ValueError(f"Invalid column '{col}'. 
Valid columns are: {self.ALL_COLUMNS}") + return self.config.columns + return self.ALL_COLUMNS + + def _get_schema(self, columns: list[str]) -> pa.Schema: + """Return Arrow schema with appropriate types for each column. + + Uses large_string for sequence and features columns to handle very long data + that can exceed the 2GB limit of regular string type. + """ + fields = [] + for col in columns: + if col in ("sequence", "features"): + # Use large_string for potentially very long data + fields.append(pa.field(col, pa.large_string())) + elif col == "length": + fields.append(pa.field(col, pa.int64())) + else: + fields.append(pa.field(col, pa.string())) + return pa.schema(fields) + + def _generate_tables(self, files): + """Generate Arrow tables from GenBank files. + + Yields batches of records as Arrow tables for memory-efficient processing + of large sequence files. Uses dual-threshold batching: flushes when either + batch_size (record count) or max_batch_bytes (cumulative size) is reached. + + Args: + files: Iterable of file iterables from _split_generators. + + Yields: + Tuple of (Key, pa.Table) for each batch. 
+ """ + columns = self._get_columns() + schema = self._get_schema(columns) + max_batch_bytes = self.config.max_batch_bytes + + for file_idx, file in enumerate(itertools.chain.from_iterable(files)): + batch_idx = 0 + batch = {col: [] for col in columns} + batch_bytes = 0 + + with self._open_file(file) as fp: + for record in self._parse_genbank(fp): + # Update length from actual sequence if not set + if record["length"] == 0 and record["sequence"]: + record["length"] = len(record["sequence"]) + + # Calculate record size (approximate UTF-8 byte size) + record_bytes = sum( + len(str(record.get(col, ""))) for col in columns if col != "length" + ) + 8 # 8 bytes for int64 length + + # Check if adding this record would exceed byte limit + # Flush current batch first if needed (but only if batch is non-empty) + if ( + max_batch_bytes is not None + and batch_bytes > 0 + and batch_bytes + record_bytes > max_batch_bytes + ): + pa_table = pa.Table.from_pydict(batch, schema=schema) + yield Key(file_idx, batch_idx), self._cast_table(pa_table) + batch = {col: [] for col in columns} + batch_bytes = 0 + batch_idx += 1 + + # Add record to batch + for col in columns: + batch[col].append(record.get(col, "" if col != "length" else 0)) + batch_bytes += record_bytes + + # Yield batch when it reaches batch_size (record count limit) + if len(batch[columns[0]]) >= self.config.batch_size: + pa_table = pa.Table.from_pydict(batch, schema=schema) + yield Key(file_idx, batch_idx), self._cast_table(pa_table) + batch = {col: [] for col in columns} + batch_bytes = 0 + batch_idx += 1 + + # Yield remaining records in final batch + if batch[columns[0]]: + pa_table = pa.Table.from_pydict(batch, schema=schema) + yield Key(file_idx, batch_idx), self._cast_table(pa_table) diff --git a/tests/packaged_modules/test_genbank.py b/tests/packaged_modules/test_genbank.py new file mode 100644 index 00000000000..92c64b97647 --- /dev/null +++ b/tests/packaged_modules/test_genbank.py @@ -0,0 +1,697 @@ +"""Tests for 
GenBank file loader.""" + +import bz2 +import gzip +import json +import lzma +import textwrap + +import pyarrow as pa +import pytest + +from datasets import Features, Value +from datasets.builder import InvalidConfigName +from datasets.data_files import DataFilesList +from datasets.packaged_modules.genbank.genbank import GenBank, GenBankConfig + + +@pytest.fixture +def genbank_file(tmp_path): + """Create a simple GenBank file with a single record.""" + filename = tmp_path / "sequence.gb" + data = textwrap.dedent( + """\ + LOCUS SCU49845 5028 bp DNA PLN 21-JUN-1999 + DEFINITION Saccharomyces cerevisiae TCP1-beta gene, partial cds. + ACCESSION U49845 + VERSION U49845.1 + KEYWORDS . + SOURCE Saccharomyces cerevisiae (baker's yeast) + ORGANISM Saccharomyces cerevisiae + Eukaryota; Fungi; Dikarya; Ascomycota; Saccharomycotina; + Saccharomycetes. + FEATURES Location/Qualifiers + source 1..5028 + /organism="Saccharomyces cerevisiae" + /mol_type="genomic DNA" + CDS 687..3158 + /gene="TCP1-beta" + /product="TCP1-beta" + /protein_id="AAA98665.1" + ORIGIN + 1 gatcgatcga tcgatcgatc gatcgatcga tcgatcgatc gatcgatcga tcgatcgatc + 61 gatcgatcga tcgatcgatc + // + """ + ) + with open(filename, "w", encoding="utf-8") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def genbank_file_multi_record(tmp_path): + """Create a GenBank file with multiple records.""" + filename = tmp_path / "multi_sequence.gb" + data = textwrap.dedent( + """\ + LOCUS SEQ001 100 bp DNA BCT 01-JAN-2024 + DEFINITION Test sequence 1. + ACCESSION SEQ001 + VERSION SEQ001.1 + KEYWORDS test. + SOURCE Escherichia coli + ORGANISM Escherichia coli + Bacteria; Proteobacteria; Gammaproteobacteria. + FEATURES Location/Qualifiers + source 1..100 + /organism="Escherichia coli" + gene 10..90 + /gene="testA" + ORIGIN + 1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + 61 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + // + LOCUS SEQ002 50 bp RNA VRL 01-JAN-2024 + DEFINITION Test sequence 2. 
+ ACCESSION SEQ002 + VERSION SEQ002.1 + KEYWORDS . + SOURCE Test virus + ORGANISM Test virus + Viruses; RNA viruses. + FEATURES Location/Qualifiers + source 1..50 + /organism="Test virus" + ORIGIN + 1 augcaugcau gcaugcaugc augcaugcau gcaugcaugc augcaugcau + // + """ + ) + with open(filename, "w", encoding="utf-8") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def genbank_file_gzipped(tmp_path): + """Create a gzipped GenBank file.""" + filename = tmp_path / "sequence.gb.gz" + data = textwrap.dedent( + """\ + LOCUS GZSEQ 80 bp DNA PLN 01-JAN-2024 + DEFINITION Gzipped test sequence. + ACCESSION GZSEQ + VERSION GZSEQ.1 + KEYWORDS gzip; test. + SOURCE Test organism + ORGANISM Test organism + Eukaryota; Testaceae. + FEATURES Location/Qualifiers + source 1..80 + /organism="Test organism" + ORIGIN + 1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + 61 atcgatcgat cgatcgatcg + // + """ + ) + with gzip.open(filename, "wt", encoding="utf-8") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def genbank_file_bz2(tmp_path): + """Create a bzip2 compressed GenBank file.""" + filename = tmp_path / "sequence.gb.bz2" + data = textwrap.dedent( + """\ + LOCUS BZ2SEQ 60 bp DNA PLN 01-JAN-2024 + DEFINITION Bzip2 test sequence. + ACCESSION BZ2SEQ + VERSION BZ2SEQ.1 + KEYWORDS bzip2. + SOURCE Test organism + ORGANISM Test organism + Eukaryota; Testaceae. + FEATURES Location/Qualifiers + source 1..60 + /organism="Test organism" + ORIGIN + 1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + // + """ + ) + with bz2.open(filename, "wt", encoding="utf-8") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def genbank_file_xz(tmp_path): + """Create an xz/lzma compressed GenBank file.""" + filename = tmp_path / "sequence.gb.xz" + data = textwrap.dedent( + """\ + LOCUS XZSEQ 40 bp DNA PLN 01-JAN-2024 + DEFINITION XZ test sequence. + ACCESSION XZSEQ + VERSION XZSEQ.1 + KEYWORDS . 
+ SOURCE Test organism + ORGANISM Test organism + Eukaryota; Testaceae. + FEATURES Location/Qualifiers + source 1..40 + /organism="Test organism" + ORIGIN + 1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + // + """ + ) + with lzma.open(filename, "wt", encoding="utf-8") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def genbank_file_complex_features(tmp_path): + """Create a GenBank file with complex feature locations.""" + filename = tmp_path / "complex_features.gb" + data = textwrap.dedent( + """\ + LOCUS COMPLEX 300 bp DNA PLN 01-JAN-2024 + DEFINITION Sequence with complex feature locations. + ACCESSION COMPLEX + VERSION COMPLEX.1 + KEYWORDS complex; features. + SOURCE Test organism + ORGANISM Test organism + Eukaryota; Testaceae. + FEATURES Location/Qualifiers + source 1..300 + /organism="Test organism" + gene complement(10..100) + /gene="revGene" + CDS join(1..50,100..150,200..250) + /gene="splitGene" + /product="split protein" + misc_feature <1..>300 + /note="partial feature" + ORIGIN + 1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + 61 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + 121 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + 181 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + 241 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg + // + """ + ) + with open(filename, "w", encoding="utf-8") as f: + f.write(data) + return str(filename) + + +@pytest.fixture +def genbank_file_large_sequences(tmp_path): + """Create a GenBank file with large sequences to test batching.""" + filename = tmp_path / "large_sequences.gb" + records = [] + for i in range(5): + seq_len = 1000 * (i + 1) # 1K, 2K, 3K, 4K, 5K bases + seq = "ACGT" * (seq_len // 4) + # Format sequence with GenBank-style line breaks + formatted_seq = "" + for j in range(0, len(seq), 60): + line_num = j + 1 + line_seq = seq[j : j + 60] + # Add spaces every 10 bases + spaced = " 
".join(line_seq[k : k + 10] for k in range(0, len(line_seq), 10)) + formatted_seq += f"{line_num:>9} {spaced}\n" + + record = f"""LOCUS LARGE{i:03d} {seq_len} bp DNA PLN 01-JAN-2024 +DEFINITION Large sequence {i}. +ACCESSION LARGE{i:03d} +VERSION LARGE{i:03d}.1 +KEYWORDS large. +SOURCE Test organism + ORGANISM Test organism + Eukaryota; Testaceae. +FEATURES Location/Qualifiers + source 1..{seq_len} + /organism="Test organism" +ORIGIN +{formatted_seq}// +""" + records.append(record) + + with open(filename, "w", encoding="utf-8") as f: + f.write("\n".join(records)) + return str(filename) + + +def test_config_raises_when_invalid_name() -> None: + with pytest.raises(InvalidConfigName, match="Bad characters"): + _ = GenBankConfig(name="name-with-*-invalid-character") + + +@pytest.mark.parametrize("data_files", ["str_path", ["str_path"], DataFilesList(["str_path"], [()])]) +def test_config_raises_when_invalid_data_files(data_files) -> None: + with pytest.raises(ValueError, match="Expected a DataFilesDict"): + _ = GenBankConfig(name="name", data_files=data_files) + + +def test_genbank_basic_loading(genbank_file): + """Test basic GenBank file loading.""" + genbank = GenBank() + generator = genbank._generate_tables([[genbank_file]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + assert len(result["locus_name"]) == 1 + assert result["locus_name"][0] == "SCU49845" + assert result["accession"][0] == "U49845" + assert result["version"][0] == "U49845.1" + assert "Saccharomyces cerevisiae TCP1-beta gene" in result["definition"][0] + assert result["organism"][0] == "Saccharomyces cerevisiae" + assert "Eukaryota" in result["taxonomy"][0] + assert result["length"][0] == 5028 + assert result["molecule_type"][0] == "DNA" + + +def test_genbank_multi_record(genbank_file_multi_record): + """Test loading GenBank file with multiple records.""" + genbank = GenBank() + generator = genbank._generate_tables([[genbank_file_multi_record]]) 
+ pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + assert len(result["locus_name"]) == 2 + assert result["locus_name"] == ["SEQ001", "SEQ002"] + assert result["accession"] == ["SEQ001", "SEQ002"] + assert result["molecule_type"] == ["DNA", "RNA"] + assert result["organism"] == ["Escherichia coli", "Test virus"] + + +def test_genbank_gzipped(genbank_file_gzipped): + """Test loading gzipped GenBank files.""" + genbank = GenBank() + generator = genbank._generate_tables([[genbank_file_gzipped]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + assert len(result["locus_name"]) == 1 + assert result["locus_name"][0] == "GZSEQ" + assert result["keywords"][0] == "gzip; test." + + +def test_genbank_bz2(genbank_file_bz2): + """Test loading bzip2 compressed GenBank files.""" + genbank = GenBank() + generator = genbank._generate_tables([[genbank_file_bz2]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + assert len(result["locus_name"]) == 1 + assert result["locus_name"][0] == "BZ2SEQ" + assert result["keywords"][0] == "bzip2." 
+ + +def test_genbank_xz(genbank_file_xz): + """Test loading xz/lzma compressed GenBank files.""" + genbank = GenBank() + generator = genbank._generate_tables([[genbank_file_xz]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + assert len(result["locus_name"]) == 1 + assert result["locus_name"][0] == "XZSEQ" + + +def test_genbank_feature_parsing(genbank_file_complex_features): + """Test parsing of complex feature locations.""" + genbank = GenBank(parse_features=True) + generator = genbank._generate_tables([[genbank_file_complex_features]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + features = json.loads(result["features"][0]) + + assert len(features) >= 3 + + # Find the complement feature + rev_gene = next((f for f in features if f.get("qualifiers", {}).get("gene") == "revGene"), None) + assert rev_gene is not None + assert rev_gene["location"]["strand"] == -1 + + # Find the join feature + split_gene = next((f for f in features if f.get("qualifiers", {}).get("gene") == "splitGene"), None) + assert split_gene is not None + assert "parts" in split_gene["location"] + assert len(split_gene["location"]["parts"]) == 3 + + +def test_genbank_feature_parsing_disabled(genbank_file): + """Test that feature parsing can be disabled.""" + genbank = GenBank(parse_features=False) + generator = genbank._generate_tables([[genbank_file]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + # Features should be empty string when parsing is disabled + assert result["features"][0] == "" + + +def test_genbank_column_filtering(genbank_file): + """Test loading with column subset.""" + genbank = GenBank(columns=["locus_name", "sequence", "length"]) + generator = genbank._generate_tables([[genbank_file]]) + pa_table = pa.concat_tables([table for _, table in generator]) + + result = pa_table.to_pydict() + + assert 
list(result.keys()) == ["locus_name", "sequence", "length"]
+    assert len(result["locus_name"]) == 1
+
+
+def test_genbank_column_filtering_single(genbank_file):
+    """Test loading with single column."""
+    genbank = GenBank(columns=["sequence"])  # request a single-column projection
+    generator = genbank._generate_tables([[genbank_file]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+
+    assert list(result.keys()) == ["sequence"]  # only the requested column survives
+
+
+def test_genbank_invalid_column():
+    """Test that invalid column names raise an error."""
+    genbank = GenBank(columns=["sequence", "invalid_column"])
+    with pytest.raises(ValueError, match="Invalid column 'invalid_column'"):
+        list(genbank._generate_tables([[]]))  # consume the generator so validation executes
+
+
+def test_genbank_batch_size(genbank_file_multi_record):
+    """Test batch size configuration."""
+    genbank = GenBank(batch_size=1)  # force a flush after every record
+    generator = genbank._generate_tables([[genbank_file_multi_record]])
+    tables = [table for _, table in generator]
+
+    # Should have 2 batches (one per record)
+    assert len(tables) == 2
+
+    for table in tables:
+        assert table.num_rows == 1
+
+
+def test_genbank_max_batch_bytes(genbank_file_large_sequences):
+    """Test byte-based batching with max_batch_bytes."""
+    genbank = GenBank(batch_size=1000, max_batch_bytes=5000)  # byte cap far below record cap
+    generator = genbank._generate_tables([[genbank_file_large_sequences]])
+    tables = [table for _, table in generator]
+
+    # Should create multiple batches due to byte limit
+    assert len(tables) > 1
+
+
+def test_genbank_no_byte_limit(genbank_file_large_sequences):
+    """Test disabling byte-based batching."""
+    genbank = GenBank(batch_size=1000, max_batch_bytes=None)  # None disables the byte cap
+    generator = genbank._generate_tables([[genbank_file_large_sequences]])
+    tables = [table for _, table in generator]
+
+    # Should create single batch since batch_size is high
+    assert len(tables) == 1
+    assert tables[0].num_rows == 5  # fixture supplies 5 records, all in one batch
+
+
+def test_genbank_schema_types(genbank_file):
+    """Test that schema uses correct Arrow types."""
+    genbank = GenBank()
+    generator = genbank._generate_tables([[genbank_file]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    schema = pa_table.schema
+
+    # Regular string columns
+    assert schema.field("locus_name").type == pa.string()
+    assert schema.field("accession").type == pa.string()
+    assert schema.field("version").type == pa.string()
+    assert schema.field("definition").type == pa.string()
+    assert schema.field("organism").type == pa.string()
+    assert schema.field("taxonomy").type == pa.string()
+    assert schema.field("keywords").type == pa.string()
+    assert schema.field("molecule_type").type == pa.string()
+
+    # Large string for sequence and features
+    assert schema.field("sequence").type == pa.large_string()
+    assert schema.field("features").type == pa.large_string()
+
+    # Integer for length
+    assert schema.field("length").type == pa.int64()
+
+
+def test_genbank_feature_casting(genbank_file):
+    """Test feature casting to custom schema."""
+    features = Features(
+        {
+            "locus_name": Value("string"),
+            "accession": Value("string"),
+            "version": Value("string"),
+            "definition": Value("string"),
+            "organism": Value("string"),
+            "taxonomy": Value("string"),
+            "keywords": Value("string"),
+            "sequence": Value("large_string"),
+            "features": Value("large_string"),
+            "length": Value("int64"),
+            "molecule_type": Value("string"),
+        }
+    )
+    genbank = GenBank(features=features)  # explicit schema instead of inference
+    generator = genbank._generate_tables([[genbank_file]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    assert pa_table.schema.field("locus_name").type == pa.string()
+    assert pa_table.schema.field("sequence").type == pa.large_string()
+    assert pa_table.schema.field("length").type == pa.int64()
+
+
+def test_genbank_empty_file(tmp_path):
+    """Test handling of empty GenBank file."""
+    filename = tmp_path / "empty.gb"
+    with open(filename, "w", encoding="utf-8") as f:
+        f.write("")
+
+    genbank = GenBank()
+    generator = genbank._generate_tables([[str(filename)]])
+    tables = list(generator)
+
+    # Empty file should produce no tables
+    assert len(tables) == 0
+
+
+def test_genbank_sequence_parsing(genbank_file):
+    """Test that sequence is parsed correctly."""
+    genbank = GenBank()
+    generator = genbank._generate_tables([[genbank_file]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+
+    # Sequence should be uppercase with no whitespace or numbers
+    sequence = result["sequence"][0]
+    assert sequence.isupper()
+    assert " " not in sequence
+    assert all(c in "ACGT" for c in sequence)  # fixture is plain DNA, no ambiguity codes
+
+
+def test_genbank_multiple_files(tmp_path):
+    """Test loading multiple GenBank files."""
+    file1 = tmp_path / "seq1.gb"
+    file2 = tmp_path / "seq2.gb"
+
+    data1 = textwrap.dedent(
+        """\
+        LOCUS       FILE1SEQ 20 bp DNA PLN 01-JAN-2024
+        DEFINITION  File 1 sequence.
+        ACCESSION   FILE1
+        VERSION     FILE1.1
+        KEYWORDS    .
+        SOURCE      Test organism
+          ORGANISM  Test organism
+                    Eukaryota.
+        FEATURES             Location/Qualifiers
+             source          1..20
+                             /organism="Test organism"
+        ORIGIN
+                1 atcgatcgat cgatcgatcg
+        //
+        """
+    )
+
+    data2 = textwrap.dedent(
+        """\
+        LOCUS       FILE2SEQ 20 bp DNA PLN 01-JAN-2024
+        DEFINITION  File 2 sequence.
+        ACCESSION   FILE2
+        VERSION     FILE2.1
+        KEYWORDS    .
+        SOURCE      Test organism
+          ORGANISM  Test organism
+                    Eukaryota.
+        FEATURES             Location/Qualifiers
+             source          1..20
+                             /organism="Test organism"
+        ORIGIN
+                1 gctagctagc tagctagcta
+        //
+        """
+    )
+
+    with open(file1, "w", encoding="utf-8") as f:
+        f.write(data1)
+    with open(file2, "w", encoding="utf-8") as f:
+        f.write(data2)
+
+    genbank = GenBank()
+    generator = genbank._generate_tables([[str(file1)], [str(file2)]])  # two single-file shards
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+
+    assert len(result["accession"]) == 2
+    assert "FILE1" in result["accession"]
+    assert "FILE2" in result["accession"]
+
+
+def test_genbank_extensions():
+    """Test that correct extensions are defined."""
+    assert ".gb" in GenBank.EXTENSIONS
+    assert ".gbk" in GenBank.EXTENSIONS
+    assert ".genbank" in GenBank.EXTENSIONS
+
+
+def test_genbank_all_columns():
+    """Test that all expected columns are defined."""
+    expected_columns = [
+        "locus_name",
+        "accession",
+        "version",
+        "definition",
+        "organism",
+        "taxonomy",
+        "keywords",
+        "sequence",
+        "features",
+        "length",
+        "molecule_type",
+    ]
+    assert GenBank.ALL_COLUMNS == expected_columns  # order matters: compared as a list
+
+
+def test_genbank_locus_parsing_variations(tmp_path):
+    """Test parsing different LOCUS line formats."""
+    filename = tmp_path / "locus_variations.gb"
+    # Minimal LOCUS line
+    data = textwrap.dedent(
+        """\
+        LOCUS       MINSEQ 100 bp mRNA 01-JAN-2024
+        DEFINITION  Minimal sequence.
+        ACCESSION   MINSEQ
+        VERSION     MINSEQ.1
+        KEYWORDS    .
+        SOURCE      Test
+          ORGANISM  Test
+                    Test.
+        FEATURES             Location/Qualifiers
+             source          1..100
+        ORIGIN
+                1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg
+               61 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg
+        //
+        """
+    )
+    with open(filename, "w", encoding="utf-8") as f:
+        f.write(data)
+
+    genbank = GenBank()
+    generator = genbank._generate_tables([[str(filename)]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+
+    assert result["locus_name"][0] == "MINSEQ"
+    assert result["length"][0] == 100
+    assert result["molecule_type"][0] == "mRNA"
+
+
+def test_genbank_keywords_empty(genbank_file):
+    """Test that '.' keywords are handled correctly."""
+    genbank = GenBank()
+    generator = genbank._generate_tables([[genbank_file]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+
+    # The fixture has KEYWORDS . which should result in empty keywords
+    assert result["keywords"][0] == ""
+
+
+def test_genbank_taxonomy_continuation(genbank_file):
+    """Test multi-line taxonomy parsing."""
+    genbank = GenBank()
+    generator = genbank._generate_tables([[genbank_file]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+
+    # Taxonomy should include continuation lines
+    taxonomy = result["taxonomy"][0]
+    assert "Eukaryota" in taxonomy
+    assert "Fungi" in taxonomy
+
+
+def test_genbank_feature_boolean_qualifier(tmp_path):
+    """Test parsing of boolean qualifiers like /pseudo."""
+    filename = tmp_path / "boolean_qual.gb"
+    data = textwrap.dedent(
+        """\
+        LOCUS       BOOLSEQ 50 bp DNA PLN 01-JAN-2024
+        DEFINITION  Sequence with boolean qualifier.
+        ACCESSION   BOOLSEQ
+        VERSION     BOOLSEQ.1
+        KEYWORDS    .
+        SOURCE      Test
+          ORGANISM  Test
+                    Test.
+        FEATURES             Location/Qualifiers
+             gene            1..50
+                             /gene="testGene"
+                             /pseudo
+        ORIGIN
+                1 atcgatcgat cgatcgatcg atcgatcgat cgatcgatcg atcgatcgat
+        //
+        """
+    )
+    with open(filename, "w", encoding="utf-8") as f:
+        f.write(data)
+
+    genbank = GenBank(parse_features=True)  # structured JSON instead of raw FEATURES text
+    generator = genbank._generate_tables([[str(filename)]])
+    pa_table = pa.concat_tables([table for _, table in generator])
+
+    result = pa_table.to_pydict()
+    features = json.loads(result["features"][0])
+
+    gene_feature = next((f for f in features if f["type"] == "gene"), None)
+    assert gene_feature is not None
+    assert gene_feature["qualifiers"].get("pseudo") == "true"  # valueless qualifiers stored as the string "true"