Skip to content

Commit

Permalink
fixed lint errors, changed required argument to default in TV2DNorm, …
Browse files Browse the repository at this point in the history
…fixed inconsistent signature for prox function, added more comments to the helper functions
  • Loading branch information
Salman Naqvi committed Oct 5, 2023
1 parent 9d1d73a commit 877df4c
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions scico/functional/_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ class TV2DNorm(Functional):
has_eval = True
has_prox = True

def __init__(self, dims: Tuple[int, int], tau: float = 1.0):
def __init__(self, dims: Tuple[int, int] = (1,1), tau: float = 1.0):
r"""
Args:
tau: Parameter :math:`\tau` in the norm definition.
Expand All @@ -520,7 +520,7 @@ def __call__(self, x: Union[Array, BlockArray]) -> float:
return self.tau * snp.sum(y)

def prox(
self, x: Union[Array, BlockArray], lam: float = 1.0, **kwargs
self, v: Union[Array, BlockArray], lam: float = 1.0, **kwargs
) -> Union[Array, BlockArray]:
r"""Proximal operator of the :math:`\ell_{TV}` norm.
Expand All @@ -533,20 +533,21 @@ def prox(
kwargs: Additional arguments that may be used by derived
classes.
"""
assert x.shape == self.dims
assert v.shape == self.dims
D = 2
K = 2*D
thresh = snp.sqrt(2) * K * self.tau * lam

y = snp.zeros_like(x)
y = snp.zeros_like(v)
for ax in range(2):
y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=False), thresh), axis=ax, shift=False))
y = y.at[:].add(self.iht2(self.shrink(self.ht2(x, axis=ax, shift=True), thresh), axis=ax, shift=True))
y = y.at[:].add(self.iht2(self.shrink(self.ht2(v, axis=ax, shift=False), thresh), axis=ax, shift=False))
y = y.at[:].add(self.iht2(self.shrink(self.ht2(v, axis=ax, shift=True), thresh), axis=ax, shift=True))
y = y.at[:].divide(K)

return y

def ht2(self, x, axis, shift):
r"""Forward Discrete Haar Wavelet transform in 2D"""
s = x.shape
w = snp.zeros(s)
C = 1 / snp.sqrt(2)
Expand All @@ -563,6 +564,7 @@ def ht2(self, x, axis, shift):
return w

def iht2(self, w, axis, shift):
r"""Inverse Discrete Haar Wavelet transform in 2D"""
s = snp.shape(w)
y = snp.zeros(s)
C = 1 / snp.sqrt(2)
Expand All @@ -580,6 +582,7 @@ def iht2(self, w, axis, shift):
return y

def shrink(self, x, tau):
r"""Wavelet shrinkage operator"""
threshed = snp.maximum(snp.abs(x)-tau, 0)
threshed = threshed.at[:].multiply(snp.sign(x))
return threshed

0 comments on commit 877df4c

Please sign in to comment.