|
#include <Python.h>
#include <numpy/arrayobject.h>
#include "table_join.h"
#include <exception>
#include <iostream>
| 5 | + |
| 6 | +/***************************** Module *******************************/ |
| 7 | +#define DOCSTR_LSD_NATIVE_MODULE \ |
| 8 | +"skysurvey.native -- native code accelerators for LSD" |
| 9 | + |
| 10 | +//////////////////////////////////////////////////////////// |
| 11 | + |
| 12 | +statichere PyObject *NativeError; |
| 13 | + |
| 14 | +// C++ exception class that sets a Python exception and throws |
| 15 | +struct E |
| 16 | +{ |
| 17 | + E(PyObject *err = NULL, const std::string &msg = "") |
| 18 | + { |
| 19 | + if(err) |
| 20 | + { |
| 21 | + PyErr_SetString(err, msg.c_str()); |
| 22 | + } |
| 23 | + } |
| 24 | +}; |
| 25 | + |
| 26 | +struct PyOutput |
| 27 | +{ |
| 28 | + /* Aux class for table_join that stores the output directly into NumPy arrays */ |
| 29 | + npy_int64 *idx1, *idx2; |
| 30 | + npy_bool *isnull; |
| 31 | + int64_t size, reserved; |
| 32 | + |
| 33 | + PyArrayObject *o_idx1, *o_idx2, *o_isnull; |
| 34 | + |
| 35 | + PyOutput() : size(0), reserved(0), o_idx1(NULL), o_idx2(NULL), o_isnull(NULL) |
| 36 | + { |
| 37 | + npy_intp dims = reserved; |
| 38 | + |
| 39 | + o_idx1 = (PyArrayObject *)PyArray_SimpleNew(1, &dims, PyArray_INT64); |
| 40 | + o_idx2 = (PyArrayObject *)PyArray_SimpleNew(1, &dims, PyArray_INT64); |
| 41 | + o_isnull = (PyArrayObject *)PyArray_SimpleNew(1, &dims, PyArray_BOOL); |
| 42 | + } |
| 43 | + |
| 44 | + void resize(npy_intp dims) |
| 45 | + { |
| 46 | + PyArray_Dims shape; |
| 47 | + shape.ptr = &dims; |
| 48 | + shape.len = 1; |
| 49 | + |
| 50 | + if(PyArray_Resize(o_idx1, &shape, false, NPY_CORDER) == NULL) throw E(); |
| 51 | + if(PyArray_Resize(o_idx2, &shape, false, NPY_CORDER) == NULL) throw E(); |
| 52 | + if(PyArray_Resize(o_isnull, &shape, false, NPY_CORDER) == NULL) throw E(); |
| 53 | + |
| 54 | + idx1 = (npy_int64 *)o_idx1->data; |
| 55 | + idx2 = (npy_int64 *)o_idx2->data; |
| 56 | + isnull = (npy_bool *)o_isnull->data; |
| 57 | + |
| 58 | + reserved = dims; |
| 59 | + } |
| 60 | + |
| 61 | + void push_back(int64_t i1, int64_t i2, bool in) |
| 62 | + { |
| 63 | + if(size >= reserved) |
| 64 | + { |
| 65 | + resize(2*std::max(size, int64_t(1))); |
| 66 | + } |
| 67 | + idx1[size] = i1; |
| 68 | + idx2[size] = i2; |
| 69 | + isnull[size] = in; |
| 70 | + size++; |
| 71 | + } |
| 72 | + |
| 73 | + ~PyOutput() |
| 74 | + { |
| 75 | + Py_XDECREF(o_idx1); |
| 76 | + Py_XDECREF(o_idx2); |
| 77 | + Py_XDECREF(o_isnull); |
| 78 | + } |
| 79 | +}; |
| 80 | + |
| 81 | +// Python interface: (idx1, idx2, isnull) = table_join(idx1, idx2, m1, m2, join_type) |
| 82 | +#define DOCSTR_TABLE_JOIN \ |
| 83 | +"idx1, idx2, isnull = table_join(id1, id2, m1, m2)\n\ |
| 84 | +\n\ |
| 85 | +Join columns id1 and id2, using linkage information\n\ |
| 86 | +in (m1, m2).\n\ |
| 87 | +\n\ |
| 88 | +:Arguments:\n\ |
| 89 | + - id1 : First table key\n\ |
| 90 | + - id2 : Second table key\n\ |
| 91 | + - m1 : First table link key\n\ |
| 92 | + - m2 : Second table link key\n\ |
| 93 | +\n\ |
| 94 | +The output will be arrays of indices\n\ |
| 95 | +idx1, idx2, and isnull such that:\n\ |
| 96 | +\n\ |
| 97 | + id1[idx1], id2[idx2]\n\ |
| 98 | +\n\ |
| 99 | +(where indexing is performed in NumPy-like vector sense)\n\ |
| 100 | +will form the resulting JOIN-ed table.\n\ |
| 101 | +\n\ |
| 102 | +If join_type=='inner', the result is roughly equivalent\n\ |
| 103 | +to the result of the following SQL fragment:\n\ |
| 104 | +\n\ |
| 105 | + SELECT id1, id2 ... WHERE id1 == m1 and m2 == id2\n\ |
| 106 | +\n\ |
| 107 | +If join_type=='ouuter', the result will include those\n\ |
| 108 | +rows where id1 has no id2 counterparts. For such rows\n\ |
| 109 | +idx2 will be set to 0, but isnull will be true.\n\ |
| 110 | +\n\ |
| 111 | +Both id1 and id2 are allowed to have repeated elements.\n\ |
| 112 | +" |
| 113 | +static PyObject *Py_table_join(PyObject *self, PyObject *args) |
| 114 | +{ |
| 115 | + PyObject *ret = NULL; |
| 116 | + |
| 117 | + PyObject *id1 = NULL, *id2 = NULL, *m1 = NULL, *m2 = NULL; |
| 118 | + const char *join_type = NULL; |
| 119 | + |
| 120 | + try |
| 121 | + { |
| 122 | + PyObject *id1_, *id2_, *m1_, *m2_; |
| 123 | + if (! PyArg_ParseTuple(args, "OOOOs", &id1_, &id2_, &m1_, &m2_, &join_type)) throw E(PyExc_Exception, "Wrong number or type of args"); |
| 124 | + |
| 125 | + if ((id1 = PyArray_ContiguousFromAny(id1_, PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "id1 is not a 1D uint64 NumPy array"); |
| 126 | + if ((id2 = PyArray_ContiguousFromAny(id2_, PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of id2 to 1D NumPy array"); |
| 127 | + if ((m1 = PyArray_ContiguousFromAny(m1_, PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of m1 to 1D NumPy array"); |
| 128 | + if ((m2 = PyArray_ContiguousFromAny(m2_, PyArray_UINT64, 1, 1)) == NULL) throw E(PyExc_Exception, "Could not cast the value of m2 to 1D NumPy array"); |
| 129 | + |
| 130 | + if (PyArray_DIM(m1, 0) != PyArray_DIM(m2, 0)) throw E(PyExc_Exception, "The sizes of len(m1) and len(m2) must be the same"); |
| 131 | + |
| 132 | + #define DATAPTR(type, obj) ((type*)PyArray_DATA(obj)) |
| 133 | + PyOutput o; |
| 134 | + table_join( |
| 135 | + o, |
| 136 | + DATAPTR(uint64_t, id1), PyArray_Size(id1), |
| 137 | + DATAPTR(uint64_t, id2), PyArray_Size(id2), |
| 138 | + DATAPTR(uint64_t, m1), DATAPTR(uint64_t, m2), PyArray_Size(m2), |
| 139 | + join_type |
| 140 | + ); |
| 141 | + #undef DATAPTR |
| 142 | + o.resize(o.size); |
| 143 | + |
| 144 | + ret = PyTuple_New(3); |
| 145 | + // because PyTuple will take ownership (and PyOutput will do a DECREF on destruction). |
| 146 | + Py_INCREF(o.o_idx1); |
| 147 | + Py_INCREF(o.o_idx2); |
| 148 | + Py_INCREF(o.o_isnull); |
| 149 | + PyTuple_SetItem(ret, 0, (PyObject *)o.o_idx1); |
| 150 | + PyTuple_SetItem(ret, 1, (PyObject *)o.o_idx2); |
| 151 | + PyTuple_SetItem(ret, 2, (PyObject *)o.o_isnull); |
| 152 | + } |
| 153 | + catch(const E& e) |
| 154 | + { |
| 155 | + ret = NULL; |
| 156 | + } |
| 157 | + |
| 158 | + Py_XDECREF(id1); |
| 159 | + Py_XDECREF(id2); |
| 160 | + Py_XDECREF(m1); |
| 161 | + Py_XDECREF(m2); |
| 162 | + |
| 163 | + return ret; |
| 164 | +} |
| 165 | + |
| 166 | + |
| 167 | +static PyMethodDef nativeMethods[] = |
| 168 | +{ |
| 169 | + {"table_join", (PyCFunction)Py_table_join, METH_VARARGS, DOCSTR_TABLE_JOIN}, |
| 170 | + {NULL} /* Sentinel */ |
| 171 | +}; |
| 172 | + |
| 173 | +extern "C" PyMODINIT_FUNC initnative(void) |
| 174 | +{ |
| 175 | + // initialize our module |
| 176 | + PyObject *m = Py_InitModule3("native", nativeMethods, DOCSTR_LSD_NATIVE_MODULE); |
| 177 | + |
| 178 | + // initialize numpy |
| 179 | + import_array(); |
| 180 | + |
| 181 | + // add our exception type |
| 182 | + NativeError = PyErr_NewException("native.error", NULL, NULL); |
| 183 | + Py_INCREF(NativeError); |
| 184 | + PyModule_AddObject(m, "error", NativeError); |
| 185 | +} |
0 commit comments