Skip to content
1 change: 1 addition & 0 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Trainer, Dataset and Datamodule
Trainer <trainer.rst>
Dataset <data/dataset.rst>
DataModule <data/data_module.rst>
Dataloader <data/dataloader.rst>

Data Types
------------
Expand Down
8 changes: 0 additions & 8 deletions docs/source/_rst/data/data_module.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,6 @@ DataModule
======================
.. currentmodule:: pina.data.data_module

.. autoclass:: Collator
:members:
:show-inheritance:

.. autoclass:: PinaDataModule
:members:
:show-inheritance:

.. autoclass:: PinaSampler
:members:
:show-inheritance:
11 changes: 11 additions & 0 deletions docs/source/_rst/data/dataloader.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Dataloader
======================
.. currentmodule:: pina.data.dataloader

.. autoclass:: PinaSampler
:members:
:show-inheritance:

.. autoclass:: PinaDataLoader
:members:
:show-inheritance:
8 changes: 0 additions & 8 deletions docs/source/_rst/data/dataset.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,5 @@ Dataset
:show-inheritance:

.. autoclass:: PinaDatasetFactory
:members:
:show-inheritance:

.. autoclass:: PinaGraphDataset
:members:
:show-inheritance:

.. autoclass:: PinaTensorDataset
:members:
:show-inheritance:
31 changes: 14 additions & 17 deletions pina/callback/normalizer_data_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from ..label_tensor import LabelTensor
from ..utils import check_consistency, is_function
from ..condition import InputTargetCondition
from ..data.dataset import PinaGraphDataset


class NormalizerDataCallback(Callback):
Expand Down Expand Up @@ -122,7 +121,10 @@ def setup(self, trainer, pl_module, stage):
"""

# Ensure datsets are not graph-based
if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset):
if any(
ds.is_graph_dataset
for ds in trainer.datamodule.train_dataset.values()
):
raise NotImplementedError(
"NormalizerDataCallback is not compatible with "
"graph-based datasets."
Expand Down Expand Up @@ -164,8 +166,8 @@ def _compute_scale_shift(self, conditions, dataset):
:param dataset: The `~pina.data.dataset.PinaDataset` dataset.
"""
for cond in conditions:
if cond in dataset.conditions_dict:
data = dataset.conditions_dict[cond][self.apply_to]
if cond in dataset:
data = dataset[cond].data[self.apply_to]
shift = self.shift_fn(data)
scale = self.scale_fn(data)
self._normalizer[cond] = {
Expand Down Expand Up @@ -197,25 +199,20 @@ def normalize_dataset(self, dataset):

:param PinaDataset dataset: The dataset to be normalized.
"""
# Initialize update dictionary
update_dataset_dict = {}

# Iterate over conditions and apply normalization
for cond, norm_params in self.normalizer.items():
points = dataset.conditions_dict[cond][self.apply_to]
update_dataset_dict = {}
points = dataset[cond].data[self.apply_to]
scale = norm_params["scale"]
shift = norm_params["shift"]
normalized_points = self._norm_fn(points, scale, shift)
update_dataset_dict[cond] = {
self.apply_to: (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
}

# Update the dataset in-place
dataset.update_data(update_dataset_dict)
update_dataset_dict[self.apply_to] = (
LabelTensor(normalized_points, points.labels)
if isinstance(points, LabelTensor)
else normalized_points
)
dataset[cond].data.update(update_dataset_dict)

@property
def normalizer(self):
Expand Down
14 changes: 6 additions & 8 deletions pina/callback/refinement/refinement_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,12 @@ def _update_points(self, solver):

:param PINNInterface solver: The solver object.
"""
new_points = {}
for name in self._condition_to_update:
current_points = self.dataset.conditions_dict[name]["input"]
new_points[name] = {
"input": self.sample(current_points, name, solver)
}
self.dataset.update_data(new_points)
new_points = {}
current_points = self.dataset[name].data["input"]
new_points["input"] = self.sample(current_points, name, solver)

self.dataset[name].update_data(new_points)

def _compute_population_size(self, conditions):
"""
Expand All @@ -150,6 +149,5 @@ def _compute_population_size(self, conditions):
:rtype: dict
"""
return {
cond: len(self.dataset.conditions_dict[cond]["input"])
for cond in conditions
cond: len(self.dataset[cond].data["input"]) for cond in conditions
}
1 change: 0 additions & 1 deletion pina/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@


from .data_module import PinaDataModule
from .dataset import PinaDataset
Loading