Skip to content

Commit

Permalink
add quantile as threshold for regression sdp + monte carlo search for…
Browse files Browse the repository at this point in the history
… max rules + counterfactual explanations + fixed minor issues (python 3.10, lximp)
  • Loading branch information
salimamoukou committed Aug 31, 2022
1 parent e210f98 commit a2228ca
Show file tree
Hide file tree
Showing 14 changed files with 45,360 additions and 8,471 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/build_publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@ jobs:
os: [ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v2

- name: Build wheels
uses: pypa/[email protected]
env:
CIBW_BUILD: "cp37-* cp38-* cp39-*"

CIBW_BUILD: "cp37-* cp38-* cp39-* cp310-*"
- uses: actions/upload-artifact@v2
with:
path: ./wheelhouse/*.whl
Expand Down
13 changes: 4 additions & 9 deletions acv_app/sdp_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,14 @@ def write_pg(x_train, x_test, y_train, y_test, acvtree):
@st.cache(allow_output_mutation=True)
def compute_sdp(nb, x_train, y_train, x_test, y_test, pi_level, t):
sufficient_coal, sdp_coal, sdp_global = acvtree.sufficient_expl_rf(x_test[:nb], y_test[:nb], x_train, y_train,
stop=False, pi_level=pi_level,
t=t)
for i in range(len(sufficient_coal)):
sufficient_coal[i].pop(0)
sdp_coal[i].pop(0)
stop=False, pi_level=pi_level)

return sufficient_coal, sdp_coal, sdp_global

@st.cache(allow_output_mutation=True)
def compute_sdp_rule(obs, x_train_np, y_train_np, x_test_np, y_test_np, t, S):
sdp, rules = acvtree.compute_sdp_rule(x_test_np[obs:obs+1], y_test_np[obs:obs+1],
x_train_np, y_train_np, S=[S], t=t)
x_train_np, y_train_np, S=[S])
rule = rules[0]
columns = [x_train.columns[i] for i in range(x_train.shape[1])]
rule_string = ['{} <= {} <= {}'.format(rule[i, 0] if rule[i, 0] > -1e+10 else -np.inf, columns[i],
Expand All @@ -67,7 +64,7 @@ def compute_sdp_rule(obs, x_train_np, y_train_np, x_test_np, y_test_np, t, S):
@st.cache(allow_output_mutation=True)
def compute_sdp_maxrule(obs, x_train_np, y_train_np, x_test_np, y_test_np, t, S, pi):
sdp, rules, sdp_all, rules_data, w = acvtree.compute_sdp_maxrules(x_test_np[obs:obs + 1], y_test_np[obs:obs + 1],
x_train_np, y_train_np, S=[S], t=t, pi_level=pi)
x_train_np, y_train_np, S=[S], pi_level=pi)

acvtree.fit_global_rules(x_train_np, y_train_np, rules, [S])

Expand Down Expand Up @@ -180,7 +177,6 @@ def bar_legacy(shap_values, features=None, feature_names=None, max_display=None,
sufficient_coal, sdp_coal, sdp_global = compute_sdp(nb, x_train.values.astype(np.double),
y_train.astype(np.double), x_test.values.astype(np.double),
y_test.astype(np.double), pi_level=pi_level, t=t)

sufficient_coal_names = transform_scoal_to_col(sufficient_coal, x_train.columns)
# explantions_load_state.text("SDP explanation Done!")

Expand All @@ -197,7 +193,6 @@ def bar_legacy(shap_values, features=None, feature_names=None, max_display=None,
'SDP': sdp_coal[idx]}

sufficient_coal_df = pd.DataFrame(sufficient_coal_df)
# print(sufficient_coal_df.head())
st.dataframe(sufficient_coal_df, 6000, 6000)

with col2:
Expand Down
1,465 changes: 1,268 additions & 197 deletions acv_explainers/acv_agnosticX.py

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions acv_explainers/base_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,9 @@ def __init__(self, model, data=None, data_missing=None, cache=False, cache_norma
# data_leaves_trees = -np.ones(shape=(len(self.leaves_nb), np.max(self.leaves_nb), self.data.shape[0], self.data.shape[1]), dtype=np.int)
for i in range(len(self.leaves_nb)):
leaf_idx_trees[i, :self.leaves_nb[i]] = np.array(self.leaf_idx_trees[i], dtype=np.int)

if self.data.shape[1] == 1:
self.partition_leaves_trees[i] = np.expand_dims(self.partition_leaves_trees[i], axis=1)
partition_leaves_trees[i, :self.leaves_nb[i]] = np.array(self.partition_leaves_trees[i])
# data_leaves_trees[i, :self.leaves_nb[i]] = np.array(self.data_leaves_trees[i], dtype=np.int)

Expand Down
359 changes: 359 additions & 0 deletions acv_explainers/counterfactual_rules.py

Large diffs are not rendered by default.

Loading

0 comments on commit a2228ca

Please sign in to comment.