From 8dbd2226c4d4fdab8de8c3e795818da5733a8209 Mon Sep 17 00:00:00 2001 From: Szubie Date: Fri, 30 Aug 2019 22:10:56 +0100 Subject: [PATCH 1/2] Add multilabel supervision support --- ivis/ivis.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/ivis/ivis.py b/ivis/ivis.py index 82e4438..304042e 100644 --- a/ivis/ivis.py +++ b/ivis/ivis.py @@ -173,11 +173,19 @@ def _fit(self, X, Y=None, shuffle_mode=True): if not is_multiclass(self.supervision_metric): if not is_hinge(self.supervision_metric): # Binary logistic classifier - supervised_output = Dense(1, activation='sigmoid', + if len(Y.shape) > 1: + n_classes = Y.shape[-1] + else: + n_classes = 1 + supervised_output = Dense(n_classes, activation='sigmoid', name='supervised')(anchor_embedding) else: # Binary Linear SVM output - supervised_output = Dense(1, activation='linear', + if len(Y.shape) > 1: + n_classes = Y.shape[-1] + else: + n_classes = 1 + supervised_output = Dense(n_classes, activation='linear', name='supervised', kernel_regularizer=regularizers.l2())(anchor_embedding) else: @@ -193,7 +201,11 @@ def _fit(self, X, Y=None, shuffle_mode=True): kernel_regularizer=regularizers.l2())(anchor_embedding) else: # Regression - supervised_output = Dense(1, activation='linear', + if len(Y.shape) > 1: + n_classes = Y.shape[-1] + else: + n_classes = 1 + supervised_output = Dense(n_classes, activation='linear', name='supervised')(anchor_embedding) final_network = Model(inputs=self.model_.inputs, From 4b475de2249b9f4c0fd9bc4d4c6ef29d5acaf8f6 Mon Sep 17 00:00:00 2001 From: idroz Date: Thu, 5 Sep 2019 13:37:35 +0100 Subject: [PATCH 2/2] version bump --- R-package/DESCRIPTION | 2 +- ivis/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index faddc0d..40b8737 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -1,6 +1,6 @@ Package: ivis Title: Artificial neural network-driven visualization of high-dimensional data using triplets. -Version: 1.4.0 +Version: 1.4.1 Authors@R: c(person("Benjamin", "Szubert", email = "bszubert@beringresearch.com", role = c("aut", "cre")), person("Ignat", "Drozdov", email = "idrozdov@beringresearch.com", role = c("aut")), person("Kevin", "Rue-Albrecht", role = "ctb", email = "kevin.rue-albrecht@kennedy.ox.ac.uk", comment = c(ORCID = "0000-0003-3899-3872"))) diff --git a/ivis/version.py b/ivis/version.py index b2e8177..e5909ef 100644 --- a/ivis/version.py +++ b/ivis/version.py @@ -1 +1 @@ -VERSION = '1.4.0' +VERSION = '1.4.1'