Skip to content

Commit 77f24d7

Browse files
amyoshinoseuabeia
andauthored
adding icdf function for Cauchy and Logistic with tests (#6747)
Co-authored-by: seuabeia <[email protected]>
1 parent 69514ac commit 77f24d7

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pymc/distributions/continuous.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -1714,7 +1714,6 @@ def logcdf(value, mu, sigma):
17141714
-np.inf,
17151715
normal_lcdf(mu, sigma, pt.log(value)),
17161716
)
1717-
17181717
return check_parameters(
17191718
res,
17201719
sigma > 0,
@@ -2039,6 +2038,15 @@ def logcdf(value, alpha, beta):
20392038
msg="beta > 0",
20402039
)
20412040

2041+
def icdf(value, alpha, beta):
2042+
res = alpha + beta * pt.tan(np.pi * (value - 0.5))
2043+
res = check_icdf_value(res, value)
2044+
return check_parameters(
2045+
res,
2046+
beta > 0,
2047+
msg="beta > 0",
2048+
)
2049+
20422050

20432051
class HalfCauchy(PositiveContinuous):
20442052
r"""
@@ -3357,6 +3365,15 @@ def logcdf(value, mu, s):
33573365
msg="s > 0",
33583366
)
33593367

3368+
def icdf(value, mu, s):
3369+
res = mu + s * pt.log(value / (1 - value))
3370+
res = check_icdf_value(res, value)
3371+
return check_parameters(
3372+
res,
3373+
s > 0,
3374+
msg="s > 0",
3375+
)
3376+
33603377

33613378
class LogitNormalRV(RandomVariable):
33623379
name = "logit_normal"

tests/distributions/test_continuous.py

+11
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,11 @@ def test_cauchy(self):
548548
{"alpha": R, "beta": Rplusbig},
549549
lambda value, alpha, beta: st.cauchy.logcdf(value, alpha, beta),
550550
)
551+
check_icdf(
552+
pm.Cauchy,
553+
{"alpha": R, "beta": Rplusbig},
554+
lambda q, alpha, beta: st.cauchy.ppf(q, alpha, beta),
555+
)
551556

552557
def test_half_cauchy(self):
553558
check_logp(
@@ -768,6 +773,12 @@ def test_logistic(self):
768773
lambda value, mu, s: st.logistic.logcdf(value, mu, s),
769774
decimal=select_by_precision(float64=6, float32=1),
770775
)
776+
check_icdf(
777+
pm.Logistic,
778+
{"mu": R, "s": Rplus},
779+
lambda q, mu, s: st.logistic.ppf(q, mu, s),
780+
decimal=select_by_precision(float64=6, float32=1),
781+
)
771782

772783
def test_logitnormal(self):
773784
check_logp(

0 commit comments

Comments
 (0)