Skip to content

Commit 5eeaa57

Browse files
committed
modified unit tests
1 parent c553042 commit 5eeaa57

File tree

2 files changed

+86
-123
lines changed

2 files changed

+86
-123
lines changed

benches/test_bm_example1.py

Lines changed: 77 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -5,138 +5,101 @@
55

66
from ellalgo.cutting_plane import Options, cutting_plane_optim
77
from ellalgo.ell import Ell
8-
from ellalgo.ell_typing import OracleFeas, OracleOptim
8+
from ellalgo.ell_typing import OracleOptim
99

1010

11-
class MyOracleFeas(OracleFeas):
12-
idx = 0
13-
gamma = 0.0
14-
15-
# constraint 1: x + y <= 3
16-
def fn1(self, x, y):
17-
return x + y - 3
18-
19-
# constraint 2: x - y >= 1
20-
def fn2(self, x, y):
21-
return -x + y + 1
22-
23-
def fn0(self, x, y):
24-
return self.gamma - (x + y)
25-
26-
def grad1(self):
27-
return np.array([1.0, 1.0])
28-
29-
def grad2(self):
30-
return np.array([-1.0, 1.0])
31-
32-
def grad0(self):
33-
return np.array([-1.0, -1.0])
11+
class MyOracle1(OracleOptim):
12+
"""
13+
This Python class `MyOracle1` contains a method `assess_optim` that assesses optimization based on
14+
given parameters and returns specific values accordingly.
15+
"""
3416

3517
def __init__(self):
36-
self.fns = (self.fn1, self.fn2, self.fn0)
37-
self.grads = (self.grad1, self.grad2, self.grad0)
18+
"""
19+
Creates a new `MyOracle` instance with the `idx` field initialized to 0.
3820
39-
def assess_feas(self, z):
40-
"""[summary]
21+
This is the constructor for the `MyOracle` class, which is the main entry point for
22+
creating new instances of this type. It initializes the `idx` field to 0, which is the
23+
default value for this field.
4124
42-
Arguments:
43-
z ([type]): [description]
25+
Examples:
26+
>>> oracle = MyOracle()
27+
>>> assert oracle.idx == 0
28+
"""
29+
self.idx = 0
4430

45-
Returns:
46-
[type]: [description]
31+
def assess_optim(self, xc, gamma: float):
4732
"""
48-
x, y = z
33+
The function assess_optim assesses feasibility and optimality of a given point based on a specified
34+
gamma value.
35+
36+
:param xc: The `xc` parameter in the `assess_optim` method appears to represent a point in a
37+
two-dimensional space, as it is being unpacked into `x` and `y` coordinates
38+
:param gamma: Gamma is a parameter used in the `assess_optim` method. It is a float value that is
39+
compared with the sum of `x` and `y` in the objective function. The method returns different values
40+
based on the comparison of `gamma` with the sum of `x` and
41+
:type gamma: float
42+
:return: The `assess_optim` method returns a tuple containing two elements. The first element is a
43+
tuple containing an array `[-1.0, -1.0]` and either the value of `fj` (if `fj > 0.0`) or `0.0` (if
44+
`fj <= 0.0`). The second element of the tuple is
45+
"""
46+
x, y = xc
47+
f0 = x + y
4948

5049
for _ in range(3):
5150
self.idx = (self.idx + 1) % 3 # round robin
52-
if (fj := self.fns[self.idx](x, y)) > 0:
53-
return self.grads[self.idx](), fj
54-
return None
55-
56-
57-
class MyOracle(OracleOptim):
58-
helper = MyOracleFeas()
59-
60-
def assess_optim(self, z, gamma: float):
61-
"""[summary]
62-
63-
Arguments:
64-
z ([type]): [description]
65-
gamma (float): the best-so-far optimal value
66-
67-
Returns:
68-
[type]: [description]
69-
"""
70-
self.helper.gamma = gamma
71-
if cut := self.helper.assess_feas(z):
72-
return cut, None
73-
x, y = z
74-
# objective: maximize x + y
75-
return (np.array([-1.0, -1.0]), 0.0), x + y
7651

52+
if self.idx == 0:
53+
fj = f0 - 3.0
54+
elif self.idx == 1:
55+
fj = -x + y + 1.0
56+
elif self.idx == 2:
57+
fj = gamma - f0
58+
else:
59+
raise ValueError("Unexpected index value")
7760

78-
class MyOracleFeas2(OracleFeas):
79-
gamma = 0.0
61+
if fj > 0.0:
62+
if self.idx == 0:
63+
return ((np.array([1.0, 1.0]), fj), None)
64+
elif self.idx == 1:
65+
return ((np.array([-1.0, 1.0]), fj), None)
66+
elif self.idx == 2:
67+
return ((np.array([-1.0, -1.0]), fj), None)
8068

81-
# constraint 1: x + y <= 3
82-
def fn1(self, x, y):
83-
return x + y - 3
84-
85-
# constraint 2: x - y >= 1
86-
def fn2(self, x, y):
87-
return -x + y + 1
88-
89-
def fn0(self, x, y):
90-
return self.gamma - (x + y)
91-
92-
def grad1(self):
93-
return np.array([1.0, 1.0])
94-
95-
def grad2(self):
96-
return np.array([-1.0, 1.0])
97-
98-
def grad0(self):
99-
return np.array([-1.0, -1.0])
100-
101-
def __init__(self):
102-
self.fns = (self.fn1, self.fn2, self.fn0)
103-
self.grads = (self.grad1, self.grad2, self.grad0)
104-
105-
def assess_feas(self, z):
106-
"""[summary]
107-
108-
Arguments:
109-
z ([type]): [description]
110-
111-
Returns:
112-
[type]: [description]
113-
"""
114-
x, y = z
115-
for idx in range(3):
116-
if (fj := self.fns[idx](x, y)) > 0:
117-
return self.grads[idx](), fj
118-
return None
69+
gamma = f0
70+
return ((np.array([-1.0, -1.0]), 0.0), gamma)
11971

12072

12173
class MyOracle2(OracleOptim):
122-
helper = MyOracleFeas2()
74+
"""
75+
This Python class `MyOracle2` contains a method `assess_optim` that assesses optimization based on
76+
given parameters and returns specific values accordingly.
77+
"""
12378

124-
def assess_optim(self, z, gamma: float):
125-
"""[summary]
126-
127-
Arguments:
128-
z ([type]): [description]
129-
gamma (float): the best-so-far optimal value
130-
131-
Returns:
132-
[type]: [description]
79+
def assess_optim(self, xc, gamma: float):
80+
"""
81+
The function assess_optim assesses feasibility and optimality of a given point based on a specified
82+
gamma value.
83+
84+
:param xc: The `xc` parameter in the `assess_optim` method appears to represent a point in a
85+
two-dimensional space, as it is being unpacked into `x` and `y` coordinates
86+
:param gamma: Gamma is a parameter used in the `assess_optim` method. It is a float value that is
87+
compared with the sum of `x` and `y` in the objective function. The method returns different values
88+
based on the comparison of `gamma` with the sum of `x` and
89+
:type gamma: float
90+
:return: The `assess_optim` method returns a tuple containing two elements. The first element is a
91+
tuple containing an array `[-1.0, -1.0]` and either the value of `fj` (if `fj > 0.0`) or `0.0` (if
92+
`fj <= 0.0`). The second element of the tuple is
13393
"""
134-
self.helper.gamma = gamma
135-
if cut := self.helper.assess_feas(z):
136-
return cut, None
137-
x, y = z
138-
# objective: maximize x + y
139-
return (np.array([-1.0, -1.0]), 0.0), x + y
94+
x, y = xc
95+
f0 = x + y
96+
if (fj := f0 - 3.0) > 0.0:
97+
return ((np.array([1.0, 1.0]), fj), None)
98+
if (fj := -x + y + 1.0) > 0.0:
99+
return ((np.array([-1.0, 1.0]), fj), None)
100+
if (fj := gamma - f0) > 0.0:
101+
return ((np.array([-1.0, -1.0]), fj), None)
102+
return ((np.array([-1.0, -1.0]), 0.0), f0)
140103

141104

142105
def run_example1(omega):
@@ -150,7 +113,7 @@ def run_example1(omega):
150113

151114

152115
def test_bm_with_round_robin(benchmark):
153-
num_iters = benchmark(run_example1, MyOracle)
116+
num_iters = benchmark(run_example1, MyOracle1)
154117
assert num_iters == 25
155118

156119

tests/test_quasicvx2.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class MyQuasicvxOracle(OracleOptim):
1818
optimality of a given point based on constraints and an objective function.
1919
"""
2020

21-
idx = 0
21+
idx = 0 # for round robin
2222

2323
def assess_optim(self, xc, gamma: float):
2424
"""
@@ -37,9 +37,9 @@ def assess_optim(self, xc, gamma: float):
3737
"""
3838
x, y = xc
3939

40-
for _ in range(4):
40+
for _ in range(3):
4141
self.idx += 1
42-
if self.idx == 4:
42+
if self.idx == 3:
4343
self.idx = 0
4444

4545
if self.idx == 0:
@@ -55,14 +55,14 @@ def assess_optim(self, xc, gamma: float):
5555
# constraint 3: x > 0
5656
if x <= 0.0:
5757
return (np.array([-1.0, 0.0]), -x), None
58-
elif self.idx == 3:
59-
# objective: minimize -sqrt(x) / y
60-
tmp2 = math.sqrt(x)
61-
if (fj := -tmp2 - gamma * y) > 0.0: # infeasible
62-
return (np.array([-0.5 / tmp2, -gamma]), fj), None
58+
59+
# objective: minimize -sqrt(x) / y
60+
tmp2 = math.sqrt(x)
61+
if (fj := -tmp2 - gamma * y) > 0.0: # infeasible
62+
return (np.array([-0.5 / tmp2, -gamma]), fj), None
6363

6464
gamma = -tmp2 / y
65-
return (np.array([-0.5 / tmp2, -gamma]), 0), gamma
65+
return (np.array([-0.5 / tmp2, -gamma]), 0.0), -tmp2 / y
6666

6767

6868
def test_case_feasible():

0 commit comments

Comments
 (0)