Skip to content

Commit 8ac409b

Browse files
committed
Table.concatenate: Refactor
1 parent 6ae06fc commit 8ac409b

File tree

2 files changed

+36
-26
lines changed

2 files changed

+36
-26
lines changed

Diff for: Orange/data/table.py

+24-26
Original file line numberDiff line numberDiff line change
@@ -898,11 +898,11 @@ def __repr__(self):
898898
@classmethod
899899
def concatenate(cls, tables, axis=0):
900900
"""
901-
Concatenate tables into a new table, either horizontally or vertically.
901+
Concatenate tables into a new table, either vertically or horizontally.
902902
903-
If axis=0 (horizontal concatenate), all tables must have the same domain.
903+
If axis=0 (vertical concatenate), all tables must have the same domain.
904904
905-
If axis=1 (vertical),
905+
If axis=1 (horizontal),
906906
- all variable names must be unique.
907907
- ids are copied from the first table.
908908
- weights are copied from the first table in which they are defined.
@@ -915,12 +915,28 @@ def concatenate(cls, tables, axis=0):
915915
Returns:
916916
table (Table)
917917
"""
918+
if axis not in (0, 1):
919+
raise ValueError("invalid axis")
920+
if not tables:
921+
raise ValueError('need at least one table to concatenate')
922+
923+
if len(tables) == 1:
924+
return tables[0].copy()
925+
918926
if axis == 0:
919-
return cls._concatenate_vertical(tables)
920-
elif axis == 1:
921-
return cls._concatenate_horizontal(tables)
927+
conc = cls._concatenate_vertical(tables)
922928
else:
923-
raise ValueError("invalid axis")
929+
conc = cls._concatenate_horizontal(tables)
930+
931+
# TODO: Add attributes = {} to __init__
932+
conc.attributes = getattr(conc, "attributes", {})
933+
for table in reversed(tables):
934+
conc.attributes.update(table.attributes)
935+
936+
names = [table.name for table in tables if table.name != "untitled"]
937+
if names:
938+
conc.name = names[0]
939+
return conc
924940

925941
@classmethod
926942
def _concatenate_vertical(cls, tables):
@@ -941,10 +957,6 @@ def merge1d(arrs):
941957
def collect(attr):
942958
return [getattr(arr, attr) for arr in tables]
943959

944-
if not tables:
945-
raise ValueError('need at least one table to concatenate')
946-
if len(tables) == 1:
947-
return tables[0].copy()
948960
domain = tables[0].domain
949961
if any(table.domain != domain for table in tables):
950962
raise ValueError('concatenated tables must have the same domain')
@@ -957,22 +969,12 @@ def collect(attr):
957969
merge1d(collect("W"))
958970
)
959971
conc.ids = np.hstack([t.ids for t in tables])
960-
names = [table.name for table in tables if table.name != "untitled"]
961-
if names:
962-
conc.name = names[0]
963-
# TODO: Add attributes = {} to __init__
964-
conc.attributes = getattr(conc, "attributes", {})
965-
for table in reversed(tables):
966-
conc.attributes.update(table.attributes)
967972
return conc
968973

969974
@classmethod
970975
def _concatenate_horizontal(cls, tables):
971976
"""
972977
"""
973-
if not tables:
974-
raise ValueError('need at least one table to join')
975-
976978
def all_of(objs, names):
977979
return (tuple(getattr(obj, name) for obj in objs)
978980
for name in names)
@@ -992,11 +994,7 @@ def stack(arrs):
992994

993995
parts = all_of(doms, ("attributes", "class_vars", "metas"))
994996
domain = Domain(*(tuple(chain(*lst)) for lst in parts))
995-
table = cls.from_numpy(domain, Xs, Ys, Ms, W, ids=tables[0].ids)
996-
for ta in reversed(table_attrss):
997-
table.attributes.update(ta)
998-
999-
return table
997+
return cls.from_numpy(domain, Xs, Ys, Ms, W, ids=tables[0].ids)
1000998

1001999
def add_column(self, variable, data, to_metas=None):
10021000
"""

Diff for: Orange/data/tests/test_table.py

+12
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,18 @@ def test_concatenate_horizontal(self):
183183
def test_concatenate_invalid_axis(self):
184184
self.assertRaises(ValueError, Table.concatenate, (), axis=2)
185185

186+
def test_concatenate_names(self):
187+
a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")
188+
189+
tab1 = self._new_table((a, ), (c, ), (d, ), 0)
190+
tab2 = self._new_table((e, ), (), (f, g), 1000)
191+
tab3 = self._new_table((b, ), (), (), 1000)
192+
tab2.name = "tab2"
193+
tab3.name = "tab3"
194+
195+
joined = Table.concatenate((tab1, tab2, tab3), axis=1)
196+
self.assertEqual(joined.name, "tab2")
197+
186198
def test_with_column(self):
187199
a, b, c, d, e, f, g = map(ContinuousVariable, "abcdefg")
188200
col = np.arange(9, 14)

0 commit comments

Comments
 (0)