Skip to content

Commit 5a1d61b

Browse files
authored
Merge pull request #5251 from janezd/table-join
Table: Add methods 'join' and 'with_column'
2 parents b7d5d48 + 5429e1c commit 5a1d61b

File tree

3 files changed

+235
-15
lines changed

3 files changed

+235
-15
lines changed

Orange/data/table.py

+95-14
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,49 @@ def __repr__(self):
897897

898898
@classmethod
def concatenate(cls, tables, axis=0):
    """
    Concatenate tables into a new table, either vertically or horizontally.

    If axis=0 (vertical concatenate), all tables must have the same domain.

    If axis=1 (horizontal),
    - all variable names must be unique.
    - ids are copied from the first table.
    - weights are copied from the first table in which they are defined.

    For both axes,
    - the dictionaries of tables' attributes are merged. If the same
      attribute appears in multiple dictionaries, the value from the
      earliest table is used.
    - the result is named after the first table whose name is not
      "untitled", if there is any.

    Args:
        tables (iterable of Table): tables to be joined
        axis (int): 0 to stack tables vertically (default),
            1 to stack them horizontally

    Returns:
        table (Table): the concatenated table; if only one table is given,
            a copy of that table

    Raises:
        ValueError: if `axis` is neither 0 nor 1, or if no tables are given
    """
    if axis not in (0, 1):
        raise ValueError("invalid axis")

    # Materialize the argument: generators and other one-shot iterables
    # support neither `len` nor the repeated iteration done below.
    tables = list(tables)
    if not tables:
        raise ValueError('need at least one table to concatenate')

    if len(tables) == 1:
        return tables[0].copy()

    if axis == 0:
        conc = cls._concatenate_vertical(tables)
    else:
        conc = cls._concatenate_horizontal(tables)

    # TODO: Add attributes = {} to __init__
    conc.attributes = getattr(conc, "attributes", {})
    # Update in reverse order so that values from earlier tables are
    # written last and thus take precedence.
    for table in reversed(tables):
        conc.attributes.update(table.attributes)

    names = [table.name for table in tables if table.name != "untitled"]
    if names:
        conc.name = names[0]
    return conc
940+
941+
@classmethod
942+
def _concatenate_vertical(cls, tables):
901943
def vstack(arrs):
902944
return [np, sp][any(sp.issparse(arr) for arr in arrs)].vstack(arrs)
903945

@@ -915,12 +957,6 @@ def merge1d(arrs):
915957
def collect(attr):
916958
return [getattr(arr, attr) for arr in tables]
917959

918-
if axis == 1:
919-
raise ValueError("concatenate no longer supports axis 1")
920-
if not tables:
921-
raise ValueError('need at least one table to concatenate')
922-
if len(tables) == 1:
923-
return tables[0].copy()
924960
domain = tables[0].domain
925961
if any(table.domain != domain for table in tables):
926962
raise ValueError('concatenated tables must have the same domain')
@@ -933,15 +969,60 @@ def collect(attr):
933969
merge1d(collect("W"))
934970
)
935971
conc.ids = np.hstack([t.ids for t in tables])
936-
names = [table.name for table in tables if table.name != "untitled"]
937-
if names:
938-
conc.name = names[0]
939-
# TODO: Add attributes = {} to __init__
940-
conc.attributes = getattr(conc, "attributes", {})
941-
for table in reversed(tables):
942-
conc.attributes.update(table.attributes)
943972
return conc
944973

974+
@classmethod
def _concatenate_horizontal(cls, tables):
    """
    Join tables column-wise into a new table.

    Attributes, class variables and metas of the given tables are chained,
    in the order of tables, into a new domain, and the corresponding
    non-empty data arrays are stacked horizontally. Ids are taken from the
    first table. Weights are taken from the first table whose `W` is
    non-empty; if there is none, the (empty) `W` of the last table is used.

    Args:
        tables (sequence of Table): tables to be joined; must not be empty

    Returns:
        table (Table): a new table constructed with `from_numpy`

    Raises:
        ValueError: if tables do not have the same number of rows
    """
    def all_of(objs, names):
        # For each name, yield a tuple with that attribute of every object
        return (tuple(getattr(obj, name) for obj in objs)
                for name in names)

    def is_empty(arr):
        # `size` of a sparse matrix is the number of stored values, not
        # the number of cells, so sparse arrays are judged by their shape
        if sp.issparse(arr):
            return 0 in arr.shape
        return arr.size == 0

    def stack(arrs):
        # 1-d arrays (e.g. Y with a single class) are turned into columns
        non_empty = tuple(arr if arr.ndim == 2 else arr[:, np.newaxis]
                          for arr in arrs
                          if arr is not None and not is_empty(arr))
        if not non_empty:
            return None
        # np.hstack cannot join sparse matrices; mirror the dispatch
        # between np and sp that is used for vertical concatenation
        if any(sp.issparse(arr) for arr in non_empty):
            return sp.hstack(non_empty, format="csr")
        return np.hstack(non_empty)

    n_rows = len(tables[0])
    if any(len(table) != n_rows for table in tables):
        raise ValueError(
            "concatenated tables must have the same number of rows")

    doms, Ws = all_of(tables, ("domain", "W"))
    Xs, Ys, Ms = map(stack, all_of(tables, ("X", "Y", "metas")))
    # First non-empty weights; fall back to the last table's weights
    W = next((weights for weights in Ws if weights.size), Ws[-1])

    parts = all_of(doms, ("attributes", "class_vars", "metas"))
    domain = Domain(*(tuple(chain.from_iterable(lst)) for lst in parts))
    return cls.from_numpy(domain, Xs, Ys, Ms, W, ids=tables[0].ids)
998+
999+
def add_column(self, variable, data, to_metas=None):
    """
    Return a new table that has one more column than this one.

    The name of the new column must be unique.

    Args:
        variable (Variable): descriptor of the column to add
        data (np.ndarray): values to store in the new column
        to_metas (bool, optional): if `True`, the column is appended to
            metas. Otherwise a primitive variable is appended to
            attributes and a non-primitive one to metas.

    Returns:
        table (Table): a new table with the additional column
    """
    old = self.domain
    attributes, class_vars, meta_vars = \
        old.attributes, old.class_vars, old.metas

    as_meta = to_metas or not variable.is_primitive()
    if as_meta:
        meta_vars = meta_vars + (variable, )
    else:
        attributes = attributes + (variable, )

    extended = self.transform(Domain(attributes, class_vars, meta_vars))
    # Fill the freshly created column through the view onto it.
    # NOTE(review): presumably this relies on get_column_view returning a
    # writable view rather than a copy -- confirm for sparse tables.
    column = extended.get_column_view(variable)[0]
    column[:] = data
    return extended
1025+
9451026
def is_view(self):
9461027
"""
9471028
Return `True` if all arrays represent a view referring to another table

Orange/data/tests/test_table.py

+140
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,146 @@ def test_from_numpy_sparse(self):
108108
t = Table.from_numpy(domain, sp.bsr_matrix(x))
109109
self.assertTrue(sp.isspmatrix_csr(t.X))
110110

111+
@staticmethod
def _new_table(attrs, classes, metas, s):
    """
    Build a five-row table whose cells hold consecutive integers.

    X starts counting at `s`, Y at `100 + s` and metas at `200 + s`.
    Y is kept one-dimensional unless there are multiple class variables.
    Parts without any values are replaced by empty arrays of shape (5, 0).
    """
    def counting(start, n_vars):
        return np.arange(start, start + n_vars * 5)

    def or_blank(arr):  # pylint: disable=invalid-name
        if arr.size:
            return arr
        return np.empty((5, 0))

    domain = Domain(attrs, classes, metas)
    x_part = counting(s, len(attrs)).reshape(5, -1)
    y_part = counting(100 + s, len(classes))
    if len(classes) > 1:
        y_part = y_part.reshape(5, -1)
    meta_part = counting(200 + s, len(metas)).reshape(5, -1)
    return Table.from_numpy(
        domain, or_blank(x_part), or_blank(y_part), or_blank(meta_part))
123+
124+
def test_concatenate_horizontal(self):
    """Horizontal concatenation merges domains, data, weights, attributes."""
    a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")
    same = np.testing.assert_equal

    def check_domain(table, attributes, class_vars, metas):
        dom = table.domain
        self.assertEqual(dom.attributes, attributes)
        self.assertEqual(dom.class_vars, class_vars)
        self.assertEqual(dom.metas, metas)

    # Common case; one class, no empty's
    left = self._new_table((a, b), (c, ), (d, ), 0)
    right = self._new_table((e, ), (), (f, g), 1000)
    merged = Table.concatenate((left, right), axis=1)
    check_domain(merged, (a, b, e), (c, ), (d, f, g))
    same(merged.X, np.hstack((left.X, right.X)))
    same(merged.Y, left.Y)
    same(merged.metas, np.hstack((left.metas, right.metas)))

    # One part of one table is empty
    left = self._new_table((a, b), (), (), 0)
    right = self._new_table((), (), (c, ), 1000)
    merged = Table.concatenate((left, right), axis=1)
    check_domain(merged, (a, b), (), (c, ))
    same(merged.X, np.hstack((left.X, right.X)))
    same(merged.metas, np.hstack((left.metas, right.metas)))

    # Multiple classes, two empty parts are merged
    left = self._new_table((a, b), (c, ), (), 0)
    right = self._new_table((), (d, ), (), 1000)
    merged = Table.concatenate((left, right), axis=1)
    check_domain(merged, (a, b), (c, d), ())
    same(merged.X, np.hstack((left.X, right.X)))
    same(merged.Y, np.vstack((left.Y, right.Y)).T)

    # Merging of attributes and selection of weights
    first = self._new_table((a, b), (c, ), (), 0)
    first.attributes = {"a": 5, "b": 7}
    second = self._new_table((d, ), (e, ), (), 1000)
    second.W = np.arange(5)
    third = self._new_table((f, g), (), (), 2000)
    third.attributes = {"a": 1, "c": 4}
    third.W = np.arange(5, 10)
    merged = Table.concatenate((first, second, third), axis=1)
    check_domain(merged, (a, b, d, f, g), (c, e), ())
    same(merged.X, np.hstack((first.X, second.X, third.X)))
    same(merged.Y, np.vstack((first.Y, second.Y)).T)
    self.assertEqual(merged.attributes, {"a": 5, "b": 7, "c": 4})
    same(merged.ids, first.ids)
    same(merged.W, second.W)

    # Raise an exception when no tables are given
    with self.assertRaises(ValueError):
        Table.concatenate((), axis=1)
182+
183+
def test_concatenate_invalid_axis(self):
    """Axes other than 0 and 1 are rejected with ValueError."""
    a, b = map(ContinuousVariable, "ab")
    tab1 = self._new_table((a, ), (), (), 0)
    tab2 = self._new_table((b, ), (), (), 1000)

    # Use otherwise valid input: an empty sequence of tables raises
    # ValueError by itself and would hide a missing check of the axis.
    for axis in (2, -1):
        self.assertRaises(
            ValueError, Table.concatenate, (tab1, tab2), axis=axis)

    # An invalid axis is also rejected when no tables are given
    self.assertRaises(ValueError, Table.concatenate, (), axis=2)
185+
186+
def test_concatenate_names(self):
    """The result is named after the first table with a non-default name."""
    a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")

    unnamed = self._new_table((a, ), (c, ), (d, ), 0)
    second = self._new_table((e, ), (), (f, g), 1000)
    second.name = "tab2"
    third = self._new_table((b, ), (), (), 1000)
    third.name = "tab3"

    result = Table.concatenate((unnamed, second, third), axis=1)
    self.assertEqual(result.name, "tab2")
197+
198+
def test_with_column(self):
    """`add_column` appends a column to attributes or to metas."""
    a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")
    same = np.testing.assert_equal
    values = np.arange(9, 14)
    as_column = values.reshape(5, -1)
    base = self._new_table((a, b, c), (d, ), (e, f), 0)

    # Add to attributes
    extended = base.add_column(g, np.arange(9, 14))
    self.assertEqual(extended.domain.attributes, (a, b, c, g))
    same(extended.X, np.hstack((base.X, as_column)))
    same(extended.Y, base.Y)
    same(extended.metas, base.metas)

    # Add to metas
    extended = base.add_column(g, np.arange(9, 14), to_metas=True)
    self.assertEqual(extended.domain.metas, (e, f, g))
    same(extended.X, base.X)
    same(extended.Y, base.Y)
    same(extended.metas, np.hstack((base.metas, as_column)))

    # Add to empty attributes
    base = self._new_table((), (d, ), (e, f), 0)
    extended = base.add_column(g, np.arange(9, 14))
    self.assertEqual(extended.domain.attributes, (g, ))
    same(extended.X, as_column)
    same(extended.Y, base.Y)
    same(extended.metas, base.metas)

    # Add to empty metas
    base = self._new_table((a, b, c), (d, ), (), 0)
    extended = base.add_column(g, np.arange(9, 14), to_metas=True)
    self.assertEqual(extended.domain.metas, (g, ))
    same(extended.X, base.X)
    same(extended.Y, base.Y)
    same(extended.metas, as_column)

    # Pass values as a list
    base = self._new_table((a, ), (d, ), (e, f), 0)
    extended = base.add_column(g, [4, 2, -1, 2, 5])
    self.assertEqual(extended.domain.attributes, (a, g))
    expected_x = np.array([[0, 1, 2, 3, 4], [4, 2, -1, 2, 5]]).T
    same(extended.X, expected_x)

    # Add non-primitives as metas; join `float` and `object` to `object`
    base = self._new_table((a, ), (d, ), (e, f), 0)
    t = StringVariable("t")
    extended = base.add_column(t, list("abcde"))
    self.assertEqual(extended.domain.attributes, (a, ))
    self.assertEqual(extended.domain.metas, (e, f, t))
    letters = np.array(list("abcde")).reshape(5, -1)
    same(extended.metas, np.hstack((base.metas, letters)))
250+
111251

112252
class TestTableFilters(unittest.TestCase):
113253
def setUp(self):

Orange/tests/test_table.py

-1
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ def test_concatenate_exceptions(self):
541541
iris = data.Table("iris")
542542

543543
self.assertRaises(ValueError, data.Table.concatenate, [])
544-
self.assertRaises(ValueError, data.Table.concatenate, [zoo], axis=1)
545544
self.assertRaises(ValueError, data.Table.concatenate, [zoo, iris])
546545

547546
def test_concatenate_sparse(self):

0 commit comments

Comments
 (0)