Skip to content

Commit

Permalink
Merge pull request #45 from beringresearch/multi-label-supervision
Browse files Browse the repository at this point in the history
Multi label supervision
  • Loading branch information
idroz authored Sep 5, 2019
2 parents 8aa67db + 4b475de commit 8a9886f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: ivis
Title: Artificial neural network-driven visualization of high-dimensional data using triplets.
Version: 1.4.0
Version: 1.4.1
Authors@R: c(person("Benjamin", "Szubert", email = "[email protected]", role = c("aut", "cre")),
person("Ignat", "Drozdov", email = "[email protected]", role = c("aut")),
person("Kevin", "Rue-Albrecht", role = "ctb", email = "[email protected]", comment = c(ORCID = "0000-0003-3899-3872")))
Expand Down
18 changes: 15 additions & 3 deletions ivis/ivis.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,11 +173,19 @@ def _fit(self, X, Y=None, shuffle_mode=True):
if not is_multiclass(self.supervision_metric):
if not is_hinge(self.supervision_metric):
# Binary logistic classifier
supervised_output = Dense(1, activation='sigmoid',
if len(Y.shape) > 1:
n_classes = Y.shape[-1]
else:
n_classes = 1
supervised_output = Dense(n_classes, activation='sigmoid',
name='supervised')(anchor_embedding)
else:
# Binary Linear SVM output
supervised_output = Dense(1, activation='linear',
if len(Y.shape) > 1:
n_classes = Y.shape[-1]
else:
n_classes = 1
supervised_output = Dense(n_classes, activation='linear',
name='supervised',
kernel_regularizer=regularizers.l2())(anchor_embedding)
else:
Expand All @@ -193,7 +201,11 @@ def _fit(self, X, Y=None, shuffle_mode=True):
kernel_regularizer=regularizers.l2())(anchor_embedding)
else:
# Regression
supervised_output = Dense(1, activation='linear',
if len(Y.shape) > 1:
n_classes = Y.shape[-1]
else:
n_classes = 1
supervised_output = Dense(n_classes, activation='linear',
name='supervised')(anchor_embedding)

final_network = Model(inputs=self.model_.inputs,
Expand Down
2 changes: 1 addition & 1 deletion ivis/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = '1.4.0'
VERSION = '1.4.1'

0 comments on commit 8a9886f

Please sign in to comment.