diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py index 470f5a95..a350a8b3 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset.py +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -8,7 +8,8 @@ import json import logging import os -from dataclasses import dataclass +import urllib +from dataclasses import dataclass, Field, field from typing import ( Any, ClassVar, @@ -29,9 +30,9 @@ import torch from pytorch3d.implicitron.dataset.dataset_base import DatasetBase -from pytorch3d.implicitron.dataset.frame_data import ( # noqa +from pytorch3d.implicitron.dataset.frame_data import ( FrameData, - FrameDataBuilder, + FrameDataBuilder, # noqa FrameDataBuilderBase, ) from pytorch3d.implicitron.tools.config import ( @@ -51,7 +52,7 @@ @registry.register -class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore +class SqlIndexDataset(DatasetBase, ReplaceableBase): """ A dataset with annotations stored as SQLite tables. This is an index-based dataset. The length is returned after all sequence and frame filters are applied (see param @@ -125,9 +126,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore seed: int = 0 remove_empty_masks_poll_whole_table_threshold: int = 300_000 # we set it manually in the constructor - # _index: pd.DataFrame = field(init=False) - - frame_data_builder: FrameDataBuilderBase + _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True}) + _sql_engine: sa.engine.Engine = field( + init=False, metadata={"omegaconf_ignore": True} + ) + eval_batches: Optional[List[Any]] = field( + init=False, metadata={"omegaconf_ignore": True} + ) + + frame_data_builder: FrameDataBuilderBase # pyre-ignore[13] frame_data_builder_class_type: str = "FrameDataBuilder" def __post_init__(self) -> None: @@ -138,17 +145,23 @@ def __post_init__(self) -> None: raise ValueError("sqlite_metadata_file must be set") if self.dataset_root: - frame_builder_type = self.frame_data_builder_class_type - getattr(self, f"frame_data_builder_{frame_builder_type}_args")[ - "dataset_root" - ] = self.dataset_root + frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args" + getattr(self, frame_args)["dataset_root"] = self.dataset_root + getattr(self, frame_args)["path_manager"] = self.path_manager run_auto_creation(self) - self.frame_data_builder.path_manager = self.path_manager - # pyre-ignore # NOTE: sqlite-specific args (read-only mode). + if self.path_manager is not None: + self.sqlite_metadata_file = self.path_manager.get_local_path( + self.sqlite_metadata_file + ) + self.subset_lists_file = self.path_manager.get_local_path( + self.subset_lists_file + ) + + # NOTE: sqlite-specific args (read-only mode). self._sql_engine = sa.create_engine( - f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true" + f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true" ) sequences = self._get_filtered_sequences_if_any() @@ -166,16 +179,15 @@ def __post_init__(self) -> None: if len(index) == 0: raise ValueError(f"There are no frames in the subsets: {self.subsets}!") - self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore + self._index = index.set_index(["sequence_name", "frame_number"]) - self.eval_batches = None # pyre-ignore + self.eval_batches = None if self.eval_batches_file: self.eval_batches = self._load_filter_eval_batches() logger.info(str(self)) def __len__(self) -> int: - # pyre-ignore[16] return len(self._index) def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData: @@ -250,7 +262,6 @@ def _get_item( return frame_data def __str__(self) -> str: - # pyre-ignore[16] return f"SqlIndexDataset #frames={len(self._index)}" def sequence_names(self) -> Iterable[str]: @@ -335,12 +346,12 @@ def sequence_frames_in_order( rows = self._index.index.get_loc(seq_name) if isinstance(rows, slice): assert rows.stop is not None, "Unexpected result from pandas" - rows = range(rows.start or 0, rows.stop, rows.step or 1) + rows_seq = range(rows.start or 0, rows.stop, rows.step or 1) else: - rows = np.where(rows)[0] + rows_seq = list(np.where(rows)[0]) index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices( - rows, seq_name, subset_filter + rows_seq, seq_name, subset_filter ) index_slice["idx"] = idx @@ -461,14 +472,15 @@ def _get_exclude_filters(self) -> List[sa.ColumnOperators]: return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)] def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame: - assert self.subsets is not None + subsets = self.subsets + assert subsets is not None with open(subset_lists_path, "r") as f: subset_to_seq_frame = json.load(f) seq_frame_list = sum( ( [(*row, subset) for row in subset_to_seq_frame[subset]] - for subset in self.subsets + for subset in subsets ), [], ) @@ -522,7 +534,7 @@ def _build_index_from_subset_lists( stmt = sa.select( self.frame_annotations_type.sequence_name, self.frame_annotations_type.frame_number, - ).where(self.frame_annotations_type._mask_mass == 0) + ).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16] with Session(self._sql_engine) as session: to_remove = session.execute(stmt).all() @@ -586,7 +598,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]): stmt = sa.select( self.frame_annotations_type.sequence_name, self.frame_annotations_type.frame_number, - self.frame_annotations_type._image_path, + self.frame_annotations_type._image_path, # pyre-ignore[16] sa.null().label("subset"), ) where_conditions = [] @@ -600,7 +612,7 @@ def _build_index_from_db(self, sequences: Optional[pd.Series]): logger.info(" excluding samples with empty masks") where_conditions.append( sa.or_( - self.frame_annotations_type._mask_mass.is_(None), + self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16] self.frame_annotations_type._mask_mass != 0, ) ) @@ -634,7 +646,9 @@ def _load_filter_eval_batches(self): assert self.eval_batches_file logger.info(f"Loading eval batches from {self.eval_batches_file}") - if not os.path.isfile(self.eval_batches_file): + if ( + self.path_manager and not self.path_manager.isfile(self.eval_batches_file) + ) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)): # The batch indices file does not exist. # Most probably the user has not specified the root folder. raise ValueError( @@ -642,7 +656,8 @@ def _load_filter_eval_batches(self): + "Please specify a correct dataset_root folder." ) - with open(self.eval_batches_file, "r") as f: + eval_batches_file = self._local_path(self.eval_batches_file) + with open(eval_batches_file, "r") as f: eval_batches = json.load(f) # limit the dataset to sequences to allow multiple evaluations in one file @@ -758,11 +773,18 @@ def _get_temp_index_table_instance(self, table_name: str = "__index"): prefixes=["TEMP"], # NOTE SQLite specific! ) + @classmethod + def pre_expand(cls) -> None: + # remove dataclass annotations that are not meant to be init params + # because they cause troubles for OmegaConf + for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate + if isinstance(attr_value, Field) and attr_value.metadata.get( + "omegaconf_ignore", False + ): + delattr(cls, attr) + del cls.__annotations__[attr] + def _seq_name_to_seed(seq_name) -> int: """Generates numbers in [0, 2 ** 28)""" return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16) - - -def _safe_as_tensor(data, dtype): - return torch.tensor(data, dtype=dtype) if data is not None else None diff --git a/pytorch3d/implicitron/dataset/sql_dataset_provider.py b/pytorch3d/implicitron/dataset/sql_dataset_provider.py index ab161e8d..08aa5781 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset_provider.py +++ b/pytorch3d/implicitron/dataset/sql_dataset_provider.py @@ -43,7 +43,7 @@ @registry.register -class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] +class SqlIndexDatasetMapProvider(DatasetMapProviderBase): """ Generates the training, validation, and testing dataset objects for a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base. @@ -193,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] # this is a mould that is never constructed, used to build self._dataset_map values dataset_class_type: str = "SqlIndexDataset" - dataset: SqlIndexDataset + dataset: SqlIndexDataset # pyre-ignore [13] - path_manager_factory: PathManagerFactory + path_manager_factory: PathManagerFactory # pyre-ignore [13] path_manager_factory_class_type: str = "PathManagerFactory" def __post_init__(self):