Skip to content

Commit f5bce40

Browse files
committed
Table: Add methods 'join' and 'with_column'
1 parent 0568368 commit f5bce40

File tree

3 files changed

+216
-4
lines changed

3 files changed

+216
-4
lines changed

Orange/data/table.py

+88-3
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,33 @@ def __repr__(self):
897897

898898
@classmethod
899899
def concatenate(cls, tables, axis=0):
900-
"""Concatenate tables into a new table"""
900+
"""
901+
Concatenate tables into a new table, either horizontally or vertically.
902+
903+
If axis=0 (horizontal concatenate), all tables must have the same domain.
904+
905+
If axis=1 (vertical),
906+
- all variable names must be unique.
907+
- ids are copied from the first table.
908+
- weights are copied from the first table in which they are defined.
909+
- the dictionary of table's attributes are merged. If the same attribute
910+
appears in multiple dictionaries, the earlier are used.
911+
912+
Args:
913+
tables (Table): tables to be joined
914+
915+
Returns:
916+
table (Table)
917+
"""
918+
if axis == 0:
919+
return cls._concatenate_vertical(tables)
920+
elif axis == 1:
921+
return cls._concatenate_horizontal(tables)
922+
else:
923+
raise ValueError("invalid axis")
924+
925+
@classmethod
926+
def _concatenate_vertical(cls, tables):
901927
def vstack(arrs):
902928
return [np, sp][any(sp.issparse(arr) for arr in arrs)].vstack(arrs)
903929

@@ -915,8 +941,6 @@ def merge1d(arrs):
915941
def collect(attr):
916942
return [getattr(arr, attr) for arr in tables]
917943

918-
if axis == 1:
919-
raise ValueError("concatenate no longer supports axis 1")
920944
if not tables:
921945
raise ValueError('need at least one table to concatenate')
922946
if len(tables) == 1:
@@ -942,6 +966,67 @@ def collect(attr):
942966
conc.attributes.update(table.attributes)
943967
return conc
944968

969+
@classmethod
970+
def _concatenate_horizontal(cls, tables):
971+
"""
972+
"""
973+
if not tables:
974+
raise ValueError('need at least one table to join')
975+
976+
def all_of(objs, names):
977+
return (tuple(getattr(obj, name) for obj in objs)
978+
for name in names)
979+
980+
def stack(arrs):
981+
non_empty = tuple(arr if arr.ndim == 2 else arr[:, np.newaxis]
982+
for arr in arrs
983+
if arr is not None and arr.size > 0)
984+
return np.hstack(non_empty) if non_empty else None
985+
986+
doms, Ws, table_attrss = all_of(tables, ("domain", "W", "attributes"))
987+
Xs, Ys, Ms = map(stack, all_of(tables, ("X", "Y", "metas")))
988+
if Ys is not None and Ys.shape[0] == 1:
989+
Ys = Ys.flatten()
990+
# pylint: disable=undefined-loop-variable
991+
for W in Ws:
992+
if W.size:
993+
break
994+
995+
parts = all_of(doms, ("attributes", "class_vars", "metas"))
996+
domain = Domain(*(tuple(chain(*lst)) for lst in parts))
997+
table = cls.from_numpy(domain, Xs, Ys, Ms, W, ids=tables[0].ids)
998+
for ta in reversed(table_attrss):
999+
table.attributes.update(ta)
1000+
1001+
return table
1002+
1003+
def add_column(self, variable, data, to_metas=None):
1004+
"""
1005+
Create a new table with an additional column
1006+
1007+
Column's name must be unique
1008+
1009+
Args:
1010+
variable (Variable): variable for the new column
1011+
data (np.ndarray): data for the new column
1012+
to_metas (bool, optional): if `True` the column is added as meta
1013+
column. Otherwise, primitive variables are added to attributes
1014+
and non-primitive to metas.
1015+
1016+
Returns:
1017+
table (Table): a new table with the additional column
1018+
"""
1019+
dom = self.domain
1020+
attrs, classes, metas = dom.attributes, dom.class_vars, dom.metas
1021+
if to_metas or not variable.is_primitive():
1022+
metas += (variable, )
1023+
else:
1024+
attrs += (variable, )
1025+
domain = Domain(attrs, classes, metas)
1026+
new_table = self.transform(domain)
1027+
new_table.get_column_view(variable)[0][:] = data
1028+
return new_table
1029+
9451030
def is_view(self):
9461031
"""
9471032
Return `True` if all arrays represent a view referring to another table

Orange/data/tests/test_table.py

+128
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,134 @@ def test_from_numpy_sparse(self):
108108
t = Table.from_numpy(domain, sp.bsr_matrix(x))
109109
self.assertTrue(sp.isspmatrix_csr(t.X))
110110

111+
@staticmethod
112+
def _new_table(attrs, classes, metas, s):
113+
def nz(x): # pylint: disable=invalid-name
114+
return x if x.size else np.empty((5, 0))
115+
116+
domain = Domain(attrs, classes, metas)
117+
X = np.arange(s, s + len(attrs) * 5).reshape(5, -1)
118+
Y = np.arange(100 + s, 100 + s + len(classes) * 5)
119+
if len(classes) > 1:
120+
Y = Y.reshape(5, -1)
121+
M = np.arange(200 + s, 200 + s + len(metas) * 5).reshape(5, -1)
122+
return Table.from_numpy(domain, nz(X), nz(Y), nz(M))
123+
124+
def test_concatenate_horizontal(self):
125+
a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")
126+
127+
# Common case; one class, no empty's
128+
tab1 = self._new_table((a, b), (c, ), (d, ), 0)
129+
tab2 = self._new_table((e, ), (), (f, g), 1000)
130+
joined = Table.concatenate((tab1, tab2), axis=1)
131+
domain = joined.domain
132+
self.assertEqual(domain.attributes, (a, b, e))
133+
self.assertEqual(domain.class_vars, (c, ))
134+
self.assertEqual(domain.metas, (d, f, g))
135+
np.testing.assert_equal(joined.X, np.hstack((tab1.X, tab2.X)))
136+
np.testing.assert_equal(joined.Y, tab1.Y)
137+
np.testing.assert_equal(joined.metas, np.hstack((tab1.metas, tab2.metas)))
138+
139+
# One part of one table is empty
140+
tab1 = self._new_table((a, b), (), (), 0)
141+
tab2 = self._new_table((), (), (c, ), 1000)
142+
joined = Table.concatenate((tab1, tab2), axis=1)
143+
domain = joined.domain
144+
self.assertEqual(domain.attributes, (a, b))
145+
self.assertEqual(domain.class_vars, ())
146+
self.assertEqual(domain.metas, (c, ))
147+
np.testing.assert_equal(joined.X, np.hstack((tab1.X, tab2.X)))
148+
np.testing.assert_equal(joined.metas, np.hstack((tab1.metas, tab2.metas)))
149+
150+
# Multiple classes, two empty parts are merged
151+
tab1 = self._new_table((a, b), (c, ), (), 0)
152+
tab2 = self._new_table((), (d, ), (), 1000)
153+
joined = Table.concatenate((tab1, tab2), axis=1)
154+
domain = joined.domain
155+
self.assertEqual(domain.attributes, (a, b))
156+
self.assertEqual(domain.class_vars, (c, d))
157+
self.assertEqual(domain.metas, ())
158+
np.testing.assert_equal(joined.X, np.hstack((tab1.X, tab2.X)))
159+
np.testing.assert_equal(joined.Y, np.vstack((tab1.Y, tab2.Y)).T)
160+
161+
# Merging of attributes and selection of weights
162+
tab1 = self._new_table((a, b), (c, ), (), 0)
163+
tab1.attributes = dict(a=5, b=7)
164+
tab2 = self._new_table((d, ), (e, ), (), 1000)
165+
tab2.W = np.arange(5)
166+
tab3 = self._new_table((f, g), (), (), 2000)
167+
tab3.attributes = dict(a=1, c=4)
168+
tab3.W = np.arange(5, 10)
169+
joined = Table.concatenate((tab1, tab2, tab3), axis=1)
170+
domain = joined.domain
171+
self.assertEqual(domain.attributes, (a, b, d, f, g))
172+
self.assertEqual(domain.class_vars, (c, e))
173+
self.assertEqual(domain.metas, ())
174+
np.testing.assert_equal(joined.X, np.hstack((tab1.X, tab2.X, tab3.X)))
175+
np.testing.assert_equal(joined.Y, np.vstack((tab1.Y, tab2.Y)).T)
176+
self.assertEqual(joined.attributes, dict(a=5, b=7, c=4))
177+
np.testing.assert_equal(joined.ids, tab1.ids)
178+
np.testing.assert_equal(joined.W, tab2.W)
179+
180+
# Raise an exception when no tables are given
181+
self.assertRaises(ValueError, Table.concatenate, (), axis=1)
182+
183+
def test_concatenate_invalid_axis(self):
184+
self.assertRaises(ValueError, Table.concatenate, (), axis=2)
185+
186+
def test_with_column(self):
187+
a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")
188+
col = np.arange(9, 14)
189+
colr = col.reshape(5, -1)
190+
tab = self._new_table((a, b, c), (d, ), (e, f), 0)
191+
192+
# Add to attributes
193+
tabw = tab.add_column(g, np.arange(9, 14))
194+
self.assertEqual(tabw.domain.attributes, (a, b, c, g))
195+
np.testing.assert_equal(tabw.X, np.hstack((tab.X, colr)))
196+
np.testing.assert_equal(tabw.Y, tab.Y)
197+
np.testing.assert_equal(tabw.metas, tab.metas)
198+
199+
# Add to metas
200+
tabw = tab.add_column(g, np.arange(9, 14), to_metas=True)
201+
self.assertEqual(tabw.domain.metas, (e, f, g))
202+
np.testing.assert_equal(tabw.X, tab.X)
203+
np.testing.assert_equal(tabw.Y, tab.Y)
204+
np.testing.assert_equal(tabw.metas, np.hstack((tab.metas, colr)))
205+
206+
# Add to empty attributes
207+
tab = self._new_table((), (d, ), (e, f), 0)
208+
tabw = tab.add_column(g, np.arange(9, 14))
209+
self.assertEqual(tabw.domain.attributes, (g, ))
210+
np.testing.assert_equal(tabw.X, colr)
211+
np.testing.assert_equal(tabw.Y, tab.Y)
212+
np.testing.assert_equal(tabw.metas, tab.metas)
213+
214+
# Add to empty metas
215+
tab = self._new_table((a, b, c), (d, ), (), 0)
216+
tabw = tab.add_column(g, np.arange(9, 14), to_metas=True)
217+
self.assertEqual(tabw.domain.metas, (g, ))
218+
np.testing.assert_equal(tabw.X, tab.X)
219+
np.testing.assert_equal(tabw.Y, tab.Y)
220+
np.testing.assert_equal(tabw.metas, colr)
221+
222+
# Pass values as a list
223+
tab = self._new_table((a, ), (d, ), (e, f), 0)
224+
tabw = tab.add_column(g, [4, 2, -1, 2, 5])
225+
self.assertEqual(tabw.domain.attributes, (a, g))
226+
np.testing.assert_equal(
227+
tabw.X, np.array([[0, 1, 2, 3, 4], [4, 2, -1, 2, 5]]).T)
228+
229+
# Add non-primitives as metas; join `float` and `object` to `object`
230+
tab = self._new_table((a, ), (d, ), (e, f), 0)
231+
t = StringVariable("t")
232+
tabw = tab.add_column(t, list("abcde"))
233+
self.assertEqual(tabw.domain.attributes, (a, ))
234+
self.assertEqual(tabw.domain.metas, (e, f, t))
235+
np.testing.assert_equal(
236+
tabw.metas,
237+
np.hstack((tab.metas, np.array(list("abcde")).reshape(5, -1))))
238+
111239

112240
class TestTableFilters(unittest.TestCase):
113241
def setUp(self):

Orange/tests/test_table.py

-1
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,6 @@ def test_concatenate_exceptions(self):
541541
iris = data.Table("iris")
542542

543543
self.assertRaises(ValueError, data.Table.concatenate, [])
544-
self.assertRaises(ValueError, data.Table.concatenate, [zoo], axis=1)
545544
self.assertRaises(ValueError, data.Table.concatenate, [zoo, iris])
546545

547546
def test_concatenate_sparse(self):

0 commit comments

Comments
 (0)