Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/homology.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Undirected simplicial homology
homology.SparseRipsPersistence
homology.WeakAlphaPersistence
homology.EuclideanCechPersistence
homology.LowerStarFlagPersistence

Directed simplicial homology
----------------------------
Expand Down
1 change: 1 addition & 0 deletions doc/modules/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
plotting.plot_point_cloud
plotting.plot_heatmap
plotting.plot_diagram
plotting.plot_extended_diagram
plotting.plot_betti_curves
plotting.plot_betti_surfaces
3 changes: 2 additions & 1 deletion gtda/homology/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from .simplicial import VietorisRipsPersistence, WeightedRipsPersistence, \
SparseRipsPersistence, WeakAlphaPersistence, EuclideanCechPersistence, \
FlagserPersistence
FlagserPersistence, LowerStarFlagPersistence
from .cubical import CubicalPersistence

__all__ = [
Expand All @@ -14,5 +14,6 @@
'WeakAlphaPersistence',
'EuclideanCechPersistence',
'FlagserPersistence',
'LowerStarFlagPersistence',
'CubicalPersistence',
]
28 changes: 18 additions & 10 deletions gtda/homology/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def replace_infinity_values(subdiagram):
for dim in homology_dimensions}
Xt = [{dim: replace_infinity_values(diagram[dim][slices[dim]])
for dim in homology_dimensions} for diagram in Xt]
feature_vect_len = 3
elif format == "gudhi": # Input is list of list of [dim, (birth, death)]
# In H0, remove one infinite bar placed at the beginning by GUDHI only
# if `reduce` is True
Expand All @@ -31,23 +32,28 @@ def replace_infinity_values(subdiagram):
if pers_info[0] == dim]).reshape(-1, 2)[slices[dim]]
)
for dim in homology_dimensions} for diagram in Xt]
feature_vect_len = 3
elif format == "extended": # Input is list of list of subdiagrams
Xt = [{dim: diagram[dim]
for dim in homology_dimensions} for diagram in Xt]
feature_vect_len = 4
else:
raise ValueError(
f"Unknown input format {format} for collection of diagrams."
)
raise ValueError(f"Unknown input format {format} for collection of "
f"diagrams.")

# Conversion to array of triples with padding triples
# Conversion to array of triples/quadruples with padding
start_idx_per_dim = np.cumsum(
[0] + [np.max([len(diagram[dim]) for diagram in Xt] + [1])
for dim in homology_dimensions]
)
[0] + [np.max([len(diagram[dim]) for diagram in Xt] + [1])
for dim in homology_dimensions]
)
min_values = [min([np.min(diagram[dim][:, 0]) if diagram[dim].size
else np.inf for diagram in Xt])
for dim in homology_dimensions]
min_values = [min_value if min_value != np.inf else 0
for min_value in min_values]
n_features = start_idx_per_dim[-1]
Xt_padded = np.empty((len(Xt), n_features, 3), dtype=float)
Xt_padded = np.empty((len(Xt), n_features, feature_vect_len), dtype=float)
Xt_padded[:, :, 3:] = 1. # Only applies to extended persistence

for i, dim in enumerate(homology_dimensions):
start_idx, end_idx = start_idx_per_dim[i:i + 2]
Expand All @@ -58,8 +64,10 @@ def replace_infinity_values(subdiagram):
subdiagram = diagram[dim]
end_idx_nontrivial = start_idx + len(subdiagram)
# Populate nontrivial part of the subdiagram
Xt_padded[j, start_idx:end_idx_nontrivial, :2] = subdiagram
# Insert padding triples
Xt_padded[j, start_idx:end_idx_nontrivial, :2] = subdiagram[:, :2]
Xt_padded[j, start_idx:end_idx_nontrivial, 3:] = \
subdiagram[:, 2:] # Only applies to extended persistence
# Insert padding triples/quadruples
Xt_padded[j, end_idx_nontrivial:end_idx, :2] = [padding_value] * 2

return Xt_padded
179 changes: 178 additions & 1 deletion gtda/homology/simplicial.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ._utils import _postprocess_diagrams
from ..base import PlotterMixin
from ..externals.python import ripser, SparseRipsComplex, CechComplex
from ..plotting import plot_diagram
from ..plotting import plot_diagram, plot_extended_diagram
from ..utils._docs import adapt_fit_transform_docs
from ..utils.intervals import Interval
from ..utils.validation import validate_params, check_point_clouds
Expand Down Expand Up @@ -1772,3 +1772,180 @@ def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None):
Xt[sample], homology_dimensions=homology_dimensions,
plotly_params=plotly_params
)


# @adapt_fit_transform_docs
class LowerStarFlagPersistence(BaseEstimator, TransformerMixin, PlotterMixin):
"""TODO"""
_hyperparameters = {
"homology_dimensions": {
"type": (list, tuple),
"of": {"type": int, "in": Interval(0, np.inf, closed="left")}
},
"coeff": {"type": int, "in": Interval(2, np.inf, closed="left")},
"extended": {"type": bool},
"infinity_values": {"type": (Real, type(None))},
"reduced_homology": {"type": bool},
"collapse_edges": {"type": bool}
}

def __init__(self, homology_dimensions=(0, 1), coeff=2, extended=True,
infinity_values=np.inf, reduced_homology=False,
collapse_edges=False, n_jobs=None):
self.homology_dimensions = homology_dimensions
self.coeff = coeff
self.extended = extended
self.infinity_values = infinity_values
self.reduced_homology = reduced_homology
self.collapse_edges = collapse_edges
self.n_jobs = n_jobs

def _lower_star_extended_diagram(self, X):
n_points = max(X.shape)
n_points_cone = n_points + 1
n_diag = min(X.shape)
X = X.tocoo()

data_diag = np.zeros(n_points_cone, dtype=X.dtype)
data_diag[:n_diag] = X.diagonal()
max_value = data_diag[:n_diag].max()
min_value = data_diag[:n_diag].min()
data_diag[-1] = min_value - 1

off_diag = X.row != X.col
row_off_diag, col_off_diag = X.row[off_diag], X.col[off_diag]

row = np.concatenate([row_off_diag, np.arange(n_points)])
col = np.concatenate([col_off_diag, np.full(n_points, n_points)])
data = np.concatenate([
np.maximum(data_diag[row_off_diag], data_diag[col_off_diag]),
2 * max_value + 1 - data_diag[:n_points]
])

X = coo_matrix((data, (row, col)),
shape=(n_points_cone, n_points_cone))
X.setdiag(data_diag)

Xdgms = ripser(
X, metric="precomputed", maxdim=self._max_homology_dimension,
coeff=self.coeff, collapse_edges=self.collapse_edges
)["dgms"]

for i in range(len(Xdgms)):
mask_down_sweep = Xdgms[i] > max_value
sgn = 2 * np.logical_not(
np.logical_xor.reduce(mask_down_sweep, axis=1,
keepdims=True)).astype(int) - 1
Xdgms[i][mask_down_sweep] = \
2 * max_value + 1 - Xdgms[i][mask_down_sweep]
if not i:
Xdgms[i] = np.hstack([Xdgms[i][:-1, :], sgn[:-1, :]])
else:
Xdgms[i] = np.hstack([Xdgms[i], sgn])

return Xdgms

def _lower_star_diagram(self, X):
n_points = max(X.shape)
n_diag = min(X.shape)
X = X.tocoo()

data_diag = np.zeros(n_points, dtype=X.dtype)
data_diag[:n_diag] = X.diagonal()

off_diag = X.row != X.col
row_off_diag, col_off_diag = X.row[off_diag], X.col[off_diag]

row = np.concatenate([row_off_diag, np.arange(n_points)])
col = np.concatenate([col_off_diag, np.arange(n_points)])
data = np.concatenate([
np.maximum(data_diag[row_off_diag], data_diag[col_off_diag]),
data_diag]
)

X = coo_matrix((data, (row, col)))

Xdgms = ripser(
X, metric="precomputed", maxdim=self._max_homology_dimension,
coeff=self.coeff, collapse_edges=self.collapse_edges
)["dgms"]

return Xdgms

def fit(self, X, y=None):
"""TODO"""
check_point_clouds(X, accept_sparse=True, distance_matrices=True)
validate_params(
self.get_params(), self._hyperparameters, exclude=["n_jobs"])

self._homology_dimensions = sorted(self.homology_dimensions)
self._max_homology_dimension = self._homology_dimensions[-1]

if self.extended:
self._diagram_computer = self._lower_star_extended_diagram
self._format = "extended"
else:
self._diagram_computer = self._lower_star_diagram
self._format = "ripser"

self._is_fitted = True

return self

def transform(self, X, y=None):
"""TODO"""
check_is_fitted(self, "_is_fitted")
X = check_point_clouds(X, accept_sparse=True, distance_matrices=True)

Xt = Parallel(n_jobs=self.n_jobs)(
delayed(self._diagram_computer)(x) for x in X)

format = "extended" if self.extended else "ripser"
Xt = _postprocess_diagrams(
Xt, format, self._homology_dimensions, self.infinity_values,
self.reduced_homology
)
return Xt

@staticmethod
def plot(Xt, sample=0, homology_dimensions=None, plotly_params=None):
"""Plot a sample from a collection of persistence diagrams, with
homology in multiple dimensions.

Parameters
----------
Xt : ndarray of shape (n_samples, n_features, 3)
Collection of persistence diagrams, such as returned by
:meth:`transform`.

sample : int, optional, default: ``0``
Index of the sample in `Xt` to be plotted.

homology_dimensions : list, tuple or None, optional, default: ``None``
Which homology dimensions to include in the plot. ``None`` means
plotting all dimensions present in ``Xt[sample]``.

plotly_params : dict or None, optional, default: ``None``
Custom parameters to configure the plotly figure. Allowed keys are
``"traces"`` and ``"layout"``, and the corresponding values should
be dictionaries containing keyword arguments as would be fed to the
:meth:`update_traces` and :meth:`update_layout` methods of
:class:`plotly.graph_objects.Figure`.

Returns
-------
fig : :class:`plotly.graph_objects.Figure` object
Plotly figure.

"""
Xt_sample = Xt[sample]
if Xt_sample.shape[1] == 4:
return plot_extended_diagram(
Xt_sample, homology_dimensions=homology_dimensions,
plotly_params=plotly_params
)

return plot_diagram(
Xt_sample, homology_dimensions=homology_dimensions,
plotly_params=plotly_params
)
3 changes: 2 additions & 1 deletion gtda/plotting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@
giotto-tda transformers."""

from .point_clouds import plot_point_cloud
from .persistence_diagrams import plot_diagram
from .persistence_diagrams import plot_diagram, plot_extended_diagram
from .diagram_representations import plot_betti_curves, plot_betti_surfaces
from .images import plot_heatmap

__all__ = [
'plot_point_cloud',
'plot_diagram',
'plot_extended_diagram',
'plot_heatmap',
'plot_betti_curves',
'plot_betti_surfaces'
Expand Down
Loading