|
| 1 | +# Copyright 2023 IBM Corporation |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
1 | 15 | from typing import Dict, Set
|
2 | 16 |
|
3 | 17 | import numpy as np
|
4 |
| -from mystic.coupler import and_ |
5 |
| -from mystic.penalty import quadratic_equality |
6 |
| -from mystic.solvers import diffev2 |
| 18 | + |
| 19 | +try: |
| 20 | + from mystic.coupler import and_ |
| 21 | + from mystic.penalty import quadratic_equality |
| 22 | + from mystic.solvers import diffev2 |
| 23 | + |
| 24 | + mystic_installed = True |
| 25 | +except ModuleNotFoundError: |
| 26 | + mystic_installed = False |
| 27 | + |
| 28 | + |
| 29 | +def _assert_mystic_installed(): |
| 30 | + assert mystic_installed, """Your Python environment does not have mystic installed. You can install it with |
| 31 | + pip install mystic |
| 32 | +or with |
| 33 | + pip install 'lale[fairness]'""" |
7 | 34 |
|
8 | 35 |
|
9 | 36 | def parse_solver_soln(n_flat, group_mapping):
|
@@ -122,6 +149,8 @@ def obtain_solver_info(
|
122 | 149 |
|
123 | 150 |
|
124 | 151 | def construct_ci_penalty(A, C, n_ci, i):
|
| 152 | + _assert_mystic_installed() |
| 153 | + |
125 | 154 | def condition(x):
|
126 | 155 | reshape_list = []
|
127 | 156 | for _ in range(A):
|
@@ -153,6 +182,8 @@ def create_ci_penalties(n_ci, n_di):
|
153 | 182 |
|
154 | 183 |
|
155 | 184 | def construct_di_penalty(A, C, n_di, F, i):
|
| 185 | + _assert_mystic_installed() |
| 186 | + |
156 | 187 | def condition(x):
|
157 | 188 | reshape_list = []
|
158 | 189 | for _ in range(A):
|
@@ -201,6 +232,7 @@ def create_di_penalties(n_ci, n_di, F):
|
201 | 232 |
|
202 | 233 |
|
203 | 234 | def calc_oversample_soln(o_flat, F, n_ci, n_di):
|
| 235 | + _assert_mystic_installed() |
204 | 236 | # integer constraint
|
205 | 237 | ints = np.round
|
206 | 238 |
|
@@ -235,6 +267,7 @@ def cost(x):
|
235 | 267 |
|
236 | 268 |
|
237 | 269 | def calc_undersample_soln(o_flat, F, n_ci, n_di):
|
| 270 | + _assert_mystic_installed() |
238 | 271 | # integer constraint
|
239 | 272 | ints = np.round
|
240 | 273 |
|
@@ -269,6 +302,7 @@ def cost(x):
|
269 | 302 |
|
270 | 303 |
|
271 | 304 | def calc_mixedsample_soln(o_flat, F, n_ci, n_di):
|
| 305 | + _assert_mystic_installed() |
272 | 306 | # integer constraint
|
273 | 307 | ints = np.round
|
274 | 308 |
|
|
0 commit comments