Skip to content

Commit 62f26b1

Browse files
committed
Native (C++) implementation of table_join
1 parent 011c1a2 commit 62f26b1

File tree

6 files changed

+452
-4
lines changed

6 files changed

+452
-4
lines changed

.gitignore

+5
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,9 @@
66
*.tgz
77
*.h5
88
*.png
9+
*.bak
10+
*.so
11+
.smhist
12+
EGG-INFO
13+
build
914
DEADJOE

setup.py

+61
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/usr/bin/env python
2+
3+
import os, os.path
4+
5+
def suppress_keyboard_interrupt_message():
    """Install a sys.excepthook that stays silent on KeyboardInterrupt.

    The hook that is active at call time is captured and still receives
    every other uncaught exception unchanged; only KeyboardInterrupt
    (e.g. the user pressing Ctrl-C) is swallowed.
    """
    # Imported locally: the module header only imports os / os.path.
    import sys

    old_excepthook = sys.excepthook

    def new_hook(type, value, traceback):
        # KeyboardInterrupt is a builtin name, so the Python-2-only
        # `exceptions` module is not required to reference it.
        if type != KeyboardInterrupt:
            old_excepthook(type, value, traceback)

    sys.excepthook = new_hook
15+
16+
17+
# Use NUMPY_INCLUDE environment variable to set where to find NumPy.
# The build is aborted early (exit status -1) when numpy/arrayobject.h cannot
# be found under that directory, since the native extension cannot compile
# without it.
import sys  # needed for stderr / exit; the header above only imports os

numpy_include = os.getenv('NUMPY_INCLUDE', '/opt/python2.7/lib/python2.7/site-packages/numpy/core/include/')
numpy_header = numpy_include + '/numpy/arrayobject.h'
if not os.path.isfile(numpy_header):
    # sys.stderr.write behaves the same under Python 2 and Python 3.
    sys.stderr.write("Failed to find " + numpy_header + "\n")
    sys.stderr.write("Error: could not find arrayobject.h. Please set the NumPy include path using NUMPY_INCLUDE environment variable\n")
    sys.exit(-1)
23+
24+
# ------ no changes below! If you need to change, it's a bug! -------
25+
from distutils.core import setup, Extension
26+
from sys import platform
27+
28+
import numpy
29+
# Header search path for the extension: our own sources plus NumPy's C headers.
inc = ['src', numpy_include]

longdesc = """Large Survey Database"""

# Keyword arguments handed to distutils' setup(); kept in a module-level
# dict named `args` so the whole package description lives in one place.
args = dict(
    name="skysurvey",
    version="0.1",
    description="Large Survey Database Python Module",
    long_description=longdesc,
    license="GPLv2",
    author="Mario Juric",
    author_email="[email protected]",
    maintainer="Mario Juric",
    maintainer_email="[email protected]",
    url="http://mwscience.net/lsd",
    download_url="http://mwscience.net/lsd/download",
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: GNU General Public License (GPL)',
        'Programming Language :: C++',
        'Programming Language :: Python :: 2',
        'Programming Language :: Python :: 2.7',
        'Operating System :: POSIX :: Linux',
        'Topic :: Database',
        'Topic :: Scientific/Engineering :: Astronomy',
    ],
    packages=['skysurvey'],
    # Single C++ extension module, importable as skysurvey.native.
    ext_modules=[
        Extension('skysurvey.native', ['src/native.cpp'], include_dirs=inc),
    ],
)

setup(**args)

skysurvey/catalog.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1229,7 +1229,7 @@ def extract_columns(rows, cols=All):
12291229

12301230
return ret
12311231

1232-
def table_join(id1, id2, m1, m2, join_type='outer'):
1232+
def table_join_py(id1, id2, m1, m2, join_type):
12331233
# The algorithm assumes id1 and id2 have no
12341234
# duplicated elements
12351235
if False:
@@ -1320,6 +1320,9 @@ def table_join(id1, id2, m1, m2, join_type='outer'):
13201320

13211321
return (idx1, idx2, isnull)
13221322

1323+
from native import table_join
1324+
#table_join = table_join_py
1325+
13231326
def in_array(needles, haystack):
13241327
""" Return a boolean array of len(needles) set to
13251328
True for each needle that is found in the haystack.
@@ -1589,9 +1592,9 @@ def fetch_cached_tablet(cat, cell_id, table):
15891592
self.orig_rows[cat2.name] = len(rows2)
15901593

15911594
# Join the tables (jmap and rows2), using (m1, m2) linkage information
1592-
table_join.cell_id = cell_id # debugging (remove once happy)
1593-
table_join.cat = cat # debugging (remove once happy)
1594-
(idx1, idx2, isnull) = table_join(self.keys, id2, m1, m2, join_type=join_type)
1595+
#table_join.cell_id = cell_id # debugging (remove once happy)
1596+
#table_join.cat = cat # debugging (remove once happy)
1597+
(idx1, idx2, isnull) = table_join(self.keys, id2, m1, m2, join_type)
15951598

15961599
# Reject rows that are out of the time interval in this table.
15971600
# We have to do this here as well, to support filtering on time in static_sky->temporal_sky joins

src/native.cpp

+185
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
#include <Python.h>
2+
#include <numpy/arrayobject.h>
3+
#include "table_join.h"
4+
#include <iostream>
5+
6+
/***************************** Module *******************************/
7+
// Module-level docstring passed to Py_InitModule3.
#define DOCSTR_LSD_NATIVE_MODULE "skysurvey.native -- native code accelerators for LSD"
9+
10+
////////////////////////////////////////////////////////////
11+
12+
statichere PyObject *NativeError;
13+
14+
// C++ exception class that sets a Python exception and throws
15+
struct E
16+
{
17+
E(PyObject *err = NULL, const std::string &msg = "")
18+
{
19+
if(err)
20+
{
21+
PyErr_SetString(err, msg.c_str());
22+
}
23+
}
24+
};
25+
26+
struct PyOutput
27+
{
28+
/* Aux class for table_join that stores the output directly into NumPy arrays */
29+
npy_int64 *idx1, *idx2;
30+
npy_bool *isnull;
31+
int64_t size, reserved;
32+
33+
PyArrayObject *o_idx1, *o_idx2, *o_isnull;
34+
35+
PyOutput() : size(0), reserved(0), o_idx1(NULL), o_idx2(NULL), o_isnull(NULL)
36+
{
37+
npy_intp dims = reserved;
38+
39+
o_idx1 = (PyArrayObject *)PyArray_SimpleNew(1, &dims, PyArray_INT64);
40+
o_idx2 = (PyArrayObject *)PyArray_SimpleNew(1, &dims, PyArray_INT64);
41+
o_isnull = (PyArrayObject *)PyArray_SimpleNew(1, &dims, PyArray_BOOL);
42+
}
43+
44+
void resize(npy_intp dims)
45+
{
46+
PyArray_Dims shape;
47+
shape.ptr = &dims;
48+
shape.len = 1;
49+
50+
if(PyArray_Resize(o_idx1, &shape, false, NPY_CORDER) == NULL) throw E();
51+
if(PyArray_Resize(o_idx2, &shape, false, NPY_CORDER) == NULL) throw E();
52+
if(PyArray_Resize(o_isnull, &shape, false, NPY_CORDER) == NULL) throw E();
53+
54+
idx1 = (npy_int64 *)o_idx1->data;
55+
idx2 = (npy_int64 *)o_idx2->data;
56+
isnull = (npy_bool *)o_isnull->data;
57+
58+
reserved = dims;
59+
}
60+
61+
void push_back(int64_t i1, int64_t i2, bool in)
62+
{
63+
if(size >= reserved)
64+
{
65+
resize(2*std::max(size, int64_t(1)));
66+
}
67+
idx1[size] = i1;
68+
idx2[size] = i2;
69+
isnull[size] = in;
70+
size++;
71+
}
72+
73+
~PyOutput()
74+
{
75+
Py_XDECREF(o_idx1);
76+
Py_XDECREF(o_idx2);
77+
Py_XDECREF(o_isnull);
78+
}
79+
};
80+
81+
// Python interface: (idx1, idx2, isnull) = table_join(idx1, idx2, m1, m2, join_type)
82+
#define DOCSTR_TABLE_JOIN \
83+
"idx1, idx2, isnull = table_join(id1, id2, m1, m2)\n\
84+
\n\
85+
Join columns id1 and id2, using linkage information\n\
86+
in (m1, m2).\n\
87+
\n\
88+
:Arguments:\n\
89+
- id1 : First table key\n\
90+
- id2 : Second table key\n\
91+
- m1 : First table link key\n\
92+
- m2 : Second table link key\n\
93+
\n\
94+
The output will be arrays of indices\n\
95+
idx1, idx2, and isnull such that:\n\
96+
\n\
97+
id1[idx1], id2[idx2]\n\
98+
\n\
99+
(where indexing is performed in NumPy-like vector sense)\n\
100+
will form the resulting JOIN-ed table.\n\
101+
\n\
102+
If join_type=='inner', the result is roughly equivalent\n\
103+
to the result of the following SQL fragment:\n\
104+
\n\
105+
SELECT id1, id2 ... WHERE id1 == m1 and m2 == id2\n\
106+
\n\
107+
If join_type=='ouuter', the result will include those\n\
108+
rows where id1 has no id2 counterparts. For such rows\n\
109+
idx2 will be set to 0, but isnull will be true.\n\
110+
\n\
111+
Both id1 and id2 are allowed to have repeated elements.\n\
112+
"
113+
// Python entry point for table_join(id1, id2, m1, m2, join_type).
//
// Converts the four array-like arguments to contiguous 1-D uint64 arrays,
// runs the native table_join() with a PyOutput sink, and returns a 3-tuple
// (idx1, idx2, isnull) of NumPy arrays. On failure returns NULL with a
// Python exception set.
static PyObject *Py_table_join(PyObject *self, PyObject *args)
{
    PyObject *ret = NULL;

    // Owned references to the converted input arrays; released at the end.
    PyObject *id1 = NULL, *id2 = NULL, *m1 = NULL, *m2 = NULL;
    const char *join_type = NULL;

    try
    {
        // Borrowed references to the raw arguments.
        PyObject *id1_, *id2_, *m1_, *m2_;
        if (! PyArg_ParseTuple(args, "OOOOs", &id1_, &id2_, &m1_, &m2_, &join_type)) throw E(PyExc_Exception, "Wrong number or type of args");

        if ((id1 = PyArray_ContiguousFromAny(id1_, PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of id1 to 1D NumPy array");
        if ((id2 = PyArray_ContiguousFromAny(id2_, PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of id2 to 1D NumPy array");
        if ((m1  = PyArray_ContiguousFromAny(m1_,  PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of m1 to 1D NumPy array");
        if ((m2  = PyArray_ContiguousFromAny(m2_,  PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of m2 to 1D NumPy array");

        // m1[i] <-> m2[i] form the link pairs, so the lengths must agree.
        if (PyArray_DIM(m1, 0) != PyArray_DIM(m2, 0)) throw E(PyExc_Exception, "The sizes of len(m1) and len(m2) must be the same");

        #define DATAPTR(type, obj) ((type*)PyArray_DATA(obj))
        PyOutput o;
        table_join(
            o,
            DATAPTR(uint64_t, id1), PyArray_Size(id1),
            DATAPTR(uint64_t, id2), PyArray_Size(id2),
            DATAPTR(uint64_t, m1), DATAPTR(uint64_t, m2), PyArray_Size(m2),
            join_type
        );
        #undef DATAPTR

        // Trim the over-allocated output arrays to the rows actually written.
        o.resize(o.size);

        // PyTuple_Pack takes its own reference to every item, so the tuple
        // stays valid after PyOutput's destructor drops its references.
        // It returns NULL (with an exception set) on allocation failure.
        ret = PyTuple_Pack(3, (PyObject *)o.o_idx1, (PyObject *)o.o_idx2, (PyObject *)o.o_isnull);
    }
    catch(const E &)
    {
        // The Python exception has already been set.
        ret = NULL;
    }
    catch(...)
    {
        // Never let a foreign C++ exception unwind through the
        // interpreter's C frames; report it as a Python error instead.
        if(!PyErr_Occurred())
        {
            PyErr_SetString(PyExc_RuntimeError, "Unexpected C++ exception in table_join");
        }
        ret = NULL;
    }

    Py_XDECREF(id1);
    Py_XDECREF(id2);
    Py_XDECREF(m1);
    Py_XDECREF(m2);

    return ret;
}
165+
166+
167+
// Method table of the extension module; terminated by an all-zero entry.
static PyMethodDef nativeMethods[] =
{
    { "table_join", (PyCFunction)Py_table_join, METH_VARARGS, DOCSTR_TABLE_JOIN },
    { NULL, NULL, 0, NULL }  /* Sentinel */
};
172+
173+
// Module initialisation function (Python 2 naming convention: init<name>).
// Creates the module object, initialises the NumPy C-API and publishes the
// module's exception type as `error`. On any failure it returns early with
// the Python exception left set for the importer to report.
extern "C" PyMODINIT_FUNC initnative(void)
{
    // initialize our module
    PyObject *m = Py_InitModule3("native", nativeMethods, DOCSTR_LSD_NATIVE_MODULE);
    if(m == NULL)
    {
        return;
    }

    // initialize numpy
    import_array();

    // add our exception type
    NativeError = PyErr_NewException("native.error", NULL, NULL);
    if(NativeError == NULL)
    {
        return;
    }

    // PyModule_AddObject steals a reference on success; take an extra one so
    // the NativeError global keeps owning the object as well.
    Py_INCREF(NativeError);
    if(PyModule_AddObject(m, "error", NativeError) < 0)
    {
        // The reference was not stolen on failure; give it back.
        Py_DECREF(NativeError);
    }
}

src/table_join.cpp

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#include "table_join.h"

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <new>
4+
5+
/*
6+
C++ test code for table_join
7+
*/
8+
9+
// Growable output sink for table_join backed by three parallel malloc'ed
// arrays. `size` is the number of rows stored, `reserved` the number of rows
// allocated. The arrays are owned by the object and freed on destruction.
struct Output
{
    int64_t *idx1, *idx2;
    bool *isnull;
    int64_t size, reserved;

    // Initialisers are listed in declaration order, which is the order in
    // which they actually run.
    Output() : idx1(NULL), idx2(NULL), isnull(NULL), size(0), reserved(0) {}

    // Append one row, doubling the capacity when it is exhausted.
    // Any non-zero `in` marks the row as null.
    // Throws std::bad_alloc when the arrays cannot be grown.
    void push_back(int64_t i1, int64_t i2, int64_t in)
    {
        if(size >= reserved)
        {
            grow(2*std::max(size, int64_t(1)));
        }
        idx1[size] = i1;
        idx2[size] = i2;
        isnull[size] = (in != 0);
        size++;
    }

    ~Output()
    {
        free(idx1); free(idx2); free(isnull);
    }

private:
    // Non-copyable: a copy would free the same buffers twice.
    Output(const Output &);
    Output &operator=(const Output &);

    // Reallocate one array to hold n elements. The old pointer is left
    // intact when realloc fails, so nothing leaks.
    template<typename T>
    static void grow_one(T *&p, int64_t n)
    {
        void *q = realloc(p, sizeof(T)*n);
        if(q == NULL) throw std::bad_alloc();
        p = static_cast<T *>(q);
    }

    // Grow all three arrays to n rows; `reserved` is only updated once every
    // array has been grown successfully.
    void grow(int64_t n)
    {
        grow_one(idx1, n);
        grow_one(idx2, n);
        grow_one(isnull, n);
        reserved = n;
    }
};
37+
38+
int main(int argc, char **argv)
39+
{
40+
Output o;
41+
uint64_t t1[] = {3, 11, 4, 2, 7, 4};
42+
uint64_t m1[] = {2, 4, 11, 4, 8, 422, 0, 4};
43+
uint64_t m2[] = {7, 1, 2, 3, 2, 321, 6, 42};
44+
uint64_t t2[] = {2, 7, 1, 3, 1, 578, 422};
45+
46+
table_join(o,
47+
t1, sizeof(t1)/sizeof(t1[0]),
48+
t2, sizeof(t2)/sizeof(t2[0]),
49+
m1, m2, sizeof(m2)/sizeof(m2[0]),
50+
"outer"
51+
);
52+
53+
for(int i = 0; i != o.size; i++)
54+
{
55+
std::cout << "(" << o.idx1[i] << ") " << t1[o.idx1[i]] << " -- " << t2[o.idx2[i]] << " (" << o.idx2[i] << ") isnull=" << o.isnull[i] << "\n";
56+
}
57+
}

0 commit comments

Comments
 (0)