diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dccec8b..72375fd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: GitHub Actions Demo +name: Sharpy CI run-name: ${{ github.actor }} CI for sharpy on: push: @@ -41,7 +41,8 @@ jobs: rm -rf $CONDA_ROOT cd $GITHUB_WORKSPACE/.. rm -f Miniconda3-*.sh - CPKG=Miniconda3-latest-Linux-x86_64.sh + # CPKG=Miniconda3-latest-Linux-x86_64.sh + CPKG=Miniconda3-py311_24.3.0-0-Linux-x86_64.sh wget -q https://repo.anaconda.com/miniconda/$CPKG bash $CPKG -u -b -f -p $CONDA_ROOT export PATH=$CONDA_ROOT/condabin:$CONDA_ROOT/bin:${PATH} diff --git a/conda-recipe/build.sh b/conda-recipe/build.sh index c8f4503..92a3fa0 100644 --- a/conda-recipe/build.sh +++ b/conda-recipe/build.sh @@ -53,13 +53,12 @@ if [ ! -d "${INSTALLED_DIR}/imex/lib" ]; then rm -rf ${INSTALLED_DIR}/imex IMEX_SHA=$(cat imex_version.txt) if [ ! -d "mlir-extensions" ]; then - git clone --recurse-submodules --branch main --single-branch https://github.com/intel/mlir-extensions + git clone --recurse-submodules https://github.com/intel/mlir-extensions fi pushd mlir-extensions git reset --hard HEAD git fetch --prune git checkout $IMEX_SHA - git apply ${RECIPE_DIR}/imex_*.patch LLVM_SHA=$(cat build_tools/llvm_version.txt) # if [ ! -d "llvm-project" ]; then ln -s ~/github/llvm-project .; fi if [ ! -d "llvm-project" ]; then diff --git a/conda-recipe/imex_findsycl.patch b/conda-recipe/imex_findsycl.patch deleted file mode 100644 index b94d070..0000000 --- a/conda-recipe/imex_findsycl.patch +++ /dev/null @@ -1,40 +0,0 @@ -diff --git a/cmake/modules/FindSyclRuntime.cmake b/cmake/modules/FindSyclRuntime.cmake -index 0eefdf6d..8d8fbd62 100644 ---- a/cmake/modules/FindSyclRuntime.cmake -+++ b/cmake/modules/FindSyclRuntime.cmake -@@ -27,20 +27,26 @@ - - include(FindPackageHandleStandardArgs) - --if(NOT DEFINED ENV{CMPLR_ROOT}) -+if(NOT DEFINED ENV{CMPLR_ROOT} AND NOT DEFINED ENV{SYCL_DIR}) - message(WARNING "Please make sure to install Intel DPC++ Compiler and run setvars.(sh/bat)") - message(WARNING "You can download standalone Intel DPC++ Compiler from https://www.intel.com/content/www/us/en/developer/articles/tool/oneapi-standalone-components.html#compilers") -+ message(Warning "Alternatively, you can set environment SYCL_DIR to the install dir of SYCL") - else() -- get_filename_component(ONEAPI_VER "$ENV{CMPLR_ROOT}" NAME) -- if(ONEAPI_VER VERSION_LESS 2024.0) -- if(LINUX OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux")) -- set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/linux") -- elseif(WIN32) -- set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/windows") -- endif() -+ if(DEFINED ENV{SYCL_DIR}) -+ set(SyclRuntime_ROOT "$ENV{SYCL_DIR}") - else() -- set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}") -+ get_filename_component(ONEAPI_VER "$ENV{CMPLR_ROOT}" NAME) -+ if(ONEAPI_VER VERSION_LESS 2024.0) -+ if(LINUX OR (${CMAKE_SYSTEM_NAME} MATCHES "Linux")) -+ set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/linux") -+ elseif(WIN32) -+ set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}/windows") -+ endif() -+ else() -+ set(SyclRuntime_ROOT "$ENV{CMPLR_ROOT}") -+ endif() - endif() -+ - list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include") - list(APPEND SyclRuntime_INCLUDE_DIRS "${SyclRuntime_ROOT}/include/sycl") - diff --git a/examples/stencil-2d.py b/examples/stencil-2d.py index 1ea8bdc..f9c5db9 100644 --- a/examples/stencil-2d.py +++ b/examples/stencil-2d.py @@ -197,8 +197,8 @@ def main(): # * Analyze and output results. # ****************************************************************************** - B = np.spmd.gather(B) - norm = np.linalg.norm(np.reshape(B, n * n), ord=1) + B = np.spmd.gather(np.reshape(B, (n * n,))) + norm = np.linalg.norm(B, ord=1) active_points = (n - 2 * r) ** 2 norm /= active_points diff --git a/imex_version.txt b/imex_version.txt index e3c6a81..f2bc586 100644 --- a/imex_version.txt +++ b/imex_version.txt @@ -1 +1 @@ -617949ac6105f28faeab4fa3018142195d1125c0 +a6109b1005932d8b4c1d2e8ab0ec4abe7411762a diff --git a/setup.py b/setup.py index 2f9b9d8..ab1a4f9 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ def build_cmake(self, ext): os.chdir(str(build_temp)) self.spawn(["cmake", str(cwd)] + cmake_args) if not self.dry_run: - self.spawn(["cmake", "--build", "."] + build_args) + self.spawn(["cmake", "--build", ".", "-j5"] + build_args) # Troubleshooting: if fail on line above then delete all possible # temporary CMake files including "CMakeCache.txt" in top level dir. os.chdir(str(cwd)) diff --git a/sharpy/__init__.py b/sharpy/__init__.py index 35c25dc..56e10cc 100644 --- a/sharpy/__init__.py +++ b/sharpy/__init__.py @@ -107,6 +107,14 @@ def _validate_device(device): f"{func} = lambda start, end, step, endpoint, dtype=float64, device='', team=1: ndarray(_csp.Creator.linspace(start, end, step, endpoint, dtype, _validate_device(device), team))" ) + +for func in api.api_categories["ManipOp"]: + FUNC = func.upper() + if func == "reshape": + exec( + f"{func} = lambda this, shape, cp=None: ndarray(_csp.ManipOp.reshape(this._t, shape, cp))" + ) + for func in api.api_categories["ReduceOp"]: FUNC = func.upper() exec( diff --git a/sharpy/array_api.py b/sharpy/array_api.py index d3c57a2..421d43b 100644 --- a/sharpy/array_api.py +++ b/sharpy/array_api.py @@ -175,7 +175,7 @@ "concat", # (arrays, /, *, axis=0) "expand_dims", # (x, /, *, axis) "flip", # (x, /, *, axis=None) - "reshape", # (x, /, shape) + "reshape", # (x, /, shape, *, copy: bool | None = None) "roll", # (x, /, shift, *, axis=None) "squeeze", # (x, /, axis) "stack", # (arrays, /, *, axis=0) diff --git a/src/ManipOp.cpp b/src/ManipOp.cpp index a1e4ddd..4a69110 100644 --- a/src/ManipOp.cpp +++ b/src/ManipOp.cpp @@ -12,6 +12,7 @@ #include "sharpy/jit/mlir.hpp" #include +#include #include #include @@ -41,7 +42,7 @@ struct DeferredReshape : public Deferred { : ::imex::getIntAttr(builder, COPY_ALWAYS ? true : false, 1); auto aTyp = av.getType().cast<::imex::ndarray::NDArrayType>(); - auto outTyp = aTyp.cloneWith(shape(), aTyp.getElementType()); + auto outTyp = imex::dist::cloneWithShape(aTyp, shape()); auto op = builder.create<::imex::ndarray::ReshapeOp>(loc, outTyp, av, shp, copyA); diff --git a/src/idtr.cpp b/src/idtr.cpp index 3ebacae..7c111f3 100644 --- a/src/idtr.cpp +++ b/src/idtr.cpp @@ -59,31 +59,14 @@ template WaitHandle *mkWaitHandle(T fini) { return new WaitHandle(fini); }; -void _idtr_wait(WaitHandleBase *handle, int64_t lHaloRank, void *lHaloDescr, - int64_t rHaloRank, void *rHaloDescr) { +extern "C" { +void _idtr_wait(WaitHandleBase *handle) { if (handle) { handle->wait(); delete handle; } } -extern "C" { -#define TYPED_WAIT(_sfx) \ - void _idtr_wait_##_sfx(WaitHandleBase *handle, int64_t lHaloRank, \ - void *lHaloDescr, int64_t rHaloRank, \ - void *rHaloDescr) { \ - return _idtr_wait(handle, lHaloRank, lHaloDescr, rHaloRank, rHaloDescr); \ - } \ - _Pragma(STRINGIFY(weak _mlir_ciface__idtr_wait_##_sfx = _idtr_wait_##_sfx)) - -TYPED_WAIT(f64); -TYPED_WAIT(f32); -TYPED_WAIT(i64); -TYPED_WAIT(i32); -TYPED_WAIT(i16); -TYPED_WAIT(i8); -TYPED_WAIT(i1); - #define NO_TRANSCEIVER #ifdef NO_TRANSCEIVER static void initMPIRuntime() { @@ -324,9 +307,9 @@ void copy_(uint64_t d, uint64_t &pos, T *cptr, const int64_t *sizes, } /// copy a number of array elements into a contiguous block of data -void bufferizeN(void *cptr, SHARPY::DTypeId dtype, const int64_t *sizes, - const int64_t *strides, const int64_t *tStarts, - const int64_t *tEnds, uint64_t nd, uint64_t N, void *out) { +void bufferizeN(uint64_t nd, void *cptr, const int64_t *sizes, + const int64_t *strides, SHARPY::DTypeId dtype, uint64_t N, + const int64_t *tStarts, const int64_t *tEnds, void *out) { if (!cptr || !sizes || !strides || !tStarts || !tEnds || !out) { return; } @@ -390,25 +373,29 @@ TYPED_REDUCEALL(i1, bool); /// @brief reshape array /// We assume array is partitioned along the first dimension (only) and /// partitions are ordered by ranks -void _idtr_reshape(SHARPY::DTypeId sharpytype, int64_t lRank, - int64_t *gShapePtr, void *lDataPtr, int64_t *lShapePtr, - int64_t *lStridesPtr, int64_t *lOffsPtr, int64_t oRank, - int64_t *oGShapePtr, void *oDataPtr, int64_t *oShapePtr, - int64_t *oOffsPtr, SHARPY::Transceiver *tc) { +WaitHandleBase *_idtr_copy_reshape(SHARPY::DTypeId sharpytype, + SHARPY::Transceiver *tc, int64_t iNDims, + int64_t *iGShapePtr, int64_t *iOffsPtr, + void *iDataPtr, int64_t *iDataShapePtr, + int64_t *iDataStridesPtr, int64_t oNDims, + int64_t *oGShapePtr, int64_t *oOffsPtr, + void *oDataPtr, int64_t *oDataShapePtr, + int64_t *oDataStridesPtr) { #ifdef NO_TRANSCEIVER initMPIRuntime(); tc = SHARPY::getTransceiver(); #endif - if (!gShapePtr || !lDataPtr || !lShapePtr || !lStridesPtr || !lOffsPtr || - !oGShapePtr || !oDataPtr || !oShapePtr || !oOffsPtr || !tc) { + if (!iGShapePtr || !iOffsPtr || !iDataPtr || !iDataShapePtr || + !iDataStridesPtr || !oGShapePtr || !oOffsPtr || !oDataPtr || + !oDataShapePtr || !oDataStridesPtr || !tc) { throw std::invalid_argument("Fatal: received nullptr in reshape"); } - assert(std::accumulate(&gShapePtr[0], &gShapePtr[lRank], 1, + assert(std::accumulate(&iGShapePtr[0], &iGShapePtr[iNDims], 1, std::multiplies()) == - std::accumulate(&oGShapePtr[0], &oGShapePtr[oRank], 1, + std::accumulate(&oGShapePtr[0], &oGShapePtr[oNDims], 1, std::multiplies())); - assert(std::accumulate(&oOffsPtr[1], &oOffsPtr[oRank], 0, + assert(std::accumulate(&oOffsPtr[1], &oOffsPtr[oNDims], 0, std::plus()) == 0); auto N = tc->nranks(); @@ -417,32 +404,37 @@ void _idtr_reshape(SHARPY::DTypeId sharpytype, int64_t lRank, throw std::out_of_range("Fatal: rank must be < number of ranks"); } - int64_t cSz = std::accumulate(&lShapePtr[1], &lShapePtr[lRank], 1, - std::multiplies()); - int64_t mySz = cSz * lShapePtr[0]; - if (mySz / cSz != lShapePtr[0]) { + int64_t icSz = std::accumulate(&iGShapePtr[1], &iGShapePtr[iNDims], 1, + std::multiplies()); + assert(icSz == std::accumulate(&iDataShapePtr[1], &iDataShapePtr[iNDims], 1, + std::multiplies())); + int64_t mySz = icSz * iDataShapePtr[0]; + if (mySz / icSz != iDataShapePtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } - int64_t myOff = lOffsPtr[0] * cSz; - if (myOff / cSz != lOffsPtr[0]) { + int64_t myOff = iOffsPtr[0] * icSz; + if (myOff / icSz != iOffsPtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } int64_t myEnd = myOff + mySz; if (myEnd < myOff) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } - int64_t tCSz = std::accumulate(&oShapePtr[1], &oShapePtr[oRank], 1, + + int64_t oCSz = std::accumulate(&oGShapePtr[1], &oGShapePtr[oNDims], 1, std::multiplies()); - int64_t myTSz = tCSz * oShapePtr[0]; - if (myTSz / tCSz != oShapePtr[0]) { + assert(oCSz == std::accumulate(&oDataShapePtr[1], &oDataShapePtr[oNDims], 1, + std::multiplies())); + int64_t myOSz = oCSz * oDataShapePtr[0]; + if (myOSz / oCSz != oDataShapePtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } - int64_t myTOff = oOffsPtr[0] * tCSz; - if (myTOff / tCSz != oOffsPtr[0]) { + int64_t myOOff = oOffsPtr[0] * oCSz; + if (myOOff / oCSz != oOffsPtr[0]) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } - int64_t myTEnd = myTOff + myTSz; - if (myTEnd < myTOff) { + int64_t myOEnd = myOOff + myOSz; + if (myOEnd < myOOff) { throw std::overflow_error("Fatal: Integer overflow in reshape"); } @@ -451,8 +443,8 @@ void _idtr_reshape(SHARPY::DTypeId sharpytype, int64_t lRank, ::std::vector buff(4 * N); buff[me * 4 + 0] = myOff; buff[me * 4 + 1] = mySz; - buff[me * 4 + 2] = myTOff; - buff[me * 4 + 3] = myTSz; + buff[me * 4 + 2] = myOOff; + buff[me * 4 + 3] = myOSz; ::std::vector counts(N, 4); ::std::vector dspl(N); for (auto i = 0ul; i < N; ++i) { @@ -490,64 +482,85 @@ void _idtr_reshape(SHARPY::DTypeId sharpytype, int64_t lRank, } // then check if my target part overlaps with the remote local part - if (myTEnd > xOff && myTOff < xEnd) { - auto rOff = std::max(xOff, myTOff); - rszs[i] = std::min(xEnd, myTEnd) - rOff; + if (myOEnd > xOff && myOOff < xEnd) { + auto rOff = std::max(xOff, myOOff); + rszs[i] = std::min(xEnd, myOEnd) - rOff; roffs[i] = i ? roffs[i - 1] + rszs[i - 1] : 0; } } - SHARPY::Buffer outbuff(totSSz * sizeof_dtype(sharpytype), - 2); // FIXME debug value - bufferizeN(lDataPtr, sharpytype, lShapePtr, lStridesPtr, lsOffs.data(), - lsEnds.data(), lRank, N, outbuff.data()); + SHARPY::Buffer outbuff(totSSz * sizeof_dtype(sharpytype), 2); + bufferizeN(iNDims, iDataPtr, iDataShapePtr, iDataStridesPtr, sharpytype, N, + lsOffs.data(), lsEnds.data(), outbuff.data()); auto hdl = tc->alltoall(outbuff.data(), sszs.data(), soffs.data(), sharpytype, oDataPtr, rszs.data(), roffs.data()); - tc->wait(hdl); + + if (true || no_async) { // FIXME remove true once IMEX is fixed + tc->wait(hdl); + return nullptr; + } + + auto wait = [tc = tc, hdl = hdl, outbuff = std::move(outbuff), + sszs = std::move(sszs), soffs = std::move(soffs), + rszs = std::move(rszs), + roffs = std::move(roffs)]() { tc->wait(hdl); }; + assert(outbuff.empty() && sszs.empty() && soffs.empty() && rszs.empty() && + roffs.empty()); + return mkWaitHandle(wait); } /// @brief reshape array template -void _idtr_reshape(int64_t gShapeRank, void *gShapeDescr, int64_t lOffsRank, - void *lOffsDescr, int64_t lRank, void *lDescr, - int64_t oGShapeRank, void *oGShapeDescr, int64_t oOffsRank, - void *oOffsDescr, int64_t oRank, void *oDescr, - SHARPY::Transceiver *tc) { +WaitHandleBase * +_idtr_copy_reshape(SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, + int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, + void *iDataDescr, int64_t oNSzs, void *oGShapeDescr, + int64_t oNOffs, void *oLOffsDescr, int64_t oNDims, + void *oDataDescr) { + + if (!iGShapeDescr || !iLOffsDescr || !iDataDescr || !oGShapeDescr || + !oLOffsDescr || !oDataDescr) { + throw std::invalid_argument( + "Fatal error: received nullptr in update_halo."); + } auto sharpytype = SHARPY::DTYPE::value; - SHARPY::UnrankedMemRefType lData(lRank, lDescr); - SHARPY::UnrankedMemRefType oData(oRank, oDescr); - - _idtr_reshape(sharpytype, lRank, MRIdx1d(gShapeRank, gShapeDescr).data(), - lData.data(), lData.sizes(), lData.strides(), - MRIdx1d(lOffsRank, lOffsDescr).data(), oRank, - MRIdx1d(oGShapeRank, oGShapeDescr).data(), oData.data(), - oData.sizes(), MRIdx1d(oOffsRank, oOffsDescr).data(), tc); + // Construct unranked memrefs for metadata and data + MRIdx1d iGShape(iNSzs, iGShapeDescr); + MRIdx1d iOffs(iNOffs, iLOffsDescr); + SHARPY::UnrankedMemRefType iData(iNDims, iDataDescr); + MRIdx1d oGShape(oNSzs, oGShapeDescr); + MRIdx1d oOffs(oNOffs, oLOffsDescr); + SHARPY::UnrankedMemRefType oData(oNDims, oDataDescr); + + return _idtr_copy_reshape( + sharpytype, tc, iNDims, iGShape.data(), iOffs.data(), iData.data(), + iData.sizes(), iData.strides(), oNDims, oGShape.data(), oOffs.data(), + oData.data(), oData.sizes(), oData.strides()); } extern "C" { - -#define TYPED_RESHAPE(_sfx, _typ) \ - void _idtr_reshape_##_sfx( \ - int64_t gShapeRank, void *gShapeDescr, int64_t lOffsRank, \ - void *lOffsDescr, int64_t rank, void *lDescr, int64_t oGShapeRank, \ - void *oGShapeDescr, int64_t oOffsRank, void *oOffsDescr, int64_t oRank, \ - void *oDescr, SHARPY::Transceiver *tc) { \ - _idtr_reshape<_typ>(gShapeRank, gShapeDescr, lOffsRank, lOffsDescr, rank, \ - lDescr, oGShapeRank, oGShapeDescr, oOffsRank, \ - oOffsDescr, oRank, oDescr, tc); \ +#define TYPED_COPY_RESHAPE(_sfx, _typ) \ + void *_idtr_copy_reshape_##_sfx( \ + SHARPY::Transceiver *tc, int64_t iNSzs, void *iGShapeDescr, \ + int64_t iNOffs, void *iLOffsDescr, int64_t iNDims, void *iLDescr, \ + int64_t oNSzs, void *oGShapeDescr, int64_t oNOffs, void *oLOffsDescr, \ + int64_t oNDims, void *oLDescr) { \ + return _idtr_copy_reshape<_typ>( \ + tc, iNSzs, iGShapeDescr, iNOffs, iLOffsDescr, iNDims, iLDescr, oNSzs, \ + oGShapeDescr, oNOffs, oLOffsDescr, oNDims, oLDescr); \ } \ - _Pragma(STRINGIFY(weak _mlir_ciface__idtr_reshape_##_sfx = \ - _idtr_reshape_##_sfx)) + _Pragma(STRINGIFY(weak _mlir_ciface__idtr_copy_reshape_##_sfx = \ + _idtr_copy_reshape_##_sfx)) -TYPED_RESHAPE(f64, double); -TYPED_RESHAPE(f32, float); -TYPED_RESHAPE(i64, int64_t); -TYPED_RESHAPE(i32, int32_t); -TYPED_RESHAPE(i16, int16_t); -TYPED_RESHAPE(i8, int8_t); -TYPED_RESHAPE(i1, bool); +TYPED_COPY_RESHAPE(f64, double); +TYPED_COPY_RESHAPE(f32, float); +TYPED_COPY_RESHAPE(i64, int64_t); +TYPED_COPY_RESHAPE(i32, int32_t); +TYPED_COPY_RESHAPE(i16, int16_t); +TYPED_COPY_RESHAPE(i8, int8_t); +TYPED_COPY_RESHAPE(i1, bool); } // extern "C" @@ -966,34 +979,3 @@ TYPED_UPDATE_HALO(i8, int8_t); TYPED_UPDATE_HALO(i1, bool); } // extern "C" - -// debug helper -void _idtr_extractslice(int64_t *slcOffs, int64_t *slcSizes, - int64_t *slcStrides, int64_t *tOffs, int64_t *tSizes, - int64_t *lSlcOffsets, int64_t *lSlcSizes, - int64_t *gSlcOffsets) { - if (slcOffs) - std::cerr << "slcOffs: " << slcOffs[0] << " " << slcOffs[1] << std::endl; - if (slcSizes) - std::cerr << "slcSizes: " << slcSizes[0] << " " << slcSizes[1] << std::endl; - if (slcStrides) - std::cerr << "slcStrides: " << slcStrides[0] << " " << slcStrides[1] - << std::endl; - if (tOffs) - std::cerr << "tOffs: " << tOffs[0] << " " << tOffs[1] << std::endl; - if (tSizes) - std::cerr << "tSizes: " << tSizes[0] << " " << tSizes[1] << std::endl; - if (lSlcOffsets) - std::cerr << "lSlcOffsets: " << lSlcOffsets[0] << " " << lSlcOffsets[1] - << std::endl; - if (lSlcSizes) - std::cerr << "lSlcSizes: " << lSlcSizes[0] << " " << lSlcSizes[1] - << std::endl; - if (gSlcOffsets) - std::cerr << "gSlcOffsets: " << gSlcOffsets[0] << " " << gSlcOffsets[1] - << std::endl; -} - -extern "C" { -void _debugFunc() { std::cerr << "_debugfunc\n"; } -} // extern "C" diff --git a/test/test_ewb.py b/test/test_ewb.py index 4c0e57a..3d9af51 100644 --- a/test/test_ewb.py +++ b/test/test_ewb.py @@ -82,7 +82,6 @@ def test_add_shifted1(self): v = 8 * 16 * 3 assert float(r1) == v - @pytest.mark.skip(reason="FIXME reshape") def test_add_shifted2(self): def doit(aapi, **kwargs): a = aapi.reshape( diff --git a/test/test_io.py b/test/test_io.py index c9b8aff..e67c207 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -27,7 +27,6 @@ def test_device_input_invalid(self, device): with pytest.raises(ValueError, match="Invalid device string: *"): sp.ones((4,), device=device) - @pytest.mark.skip(reason="FIXME reshape") def test_to_numpy2d(self): a = sp.reshape( sp.arange(0, 110, 1, dtype=sp.float32, device=device), [11, 10] @@ -44,7 +43,6 @@ def test_to_numpy1d(self): v = np.sum(np.arange(0, 110, 1, dtype=np.float32)) assert float(c) == v - @pytest.mark.skip(reason="FIXME reshape") def test_to_numpy_strided(self): a = sp.reshape( sp.arange(0, 110, 1, dtype=sp.float32, device=device), [11, 10] diff --git a/test/test_manip.py b/test/test_manip.py index 3e64a21..8885037 100644 --- a/test/test_manip.py +++ b/test/test_manip.py @@ -10,7 +10,6 @@ class TestManip: - @pytest.mark.skip(reason="FIXME reshape") def test_reshape1(self): def doit(aapi, **kwargs): a = aapi.arange(0, 12 * 11, 1, aapi.int32, **kwargs) @@ -18,7 +17,6 @@ def doit(aapi, **kwargs): assert runAndCompare(doit) - @pytest.mark.skip(reason="FIXME reshape") def test_reshape2(self): def doit(aapi, **kwargs): a = aapi.arange(0, 12 * 11, 1, aapi.int32, **kwargs) diff --git a/test/test_red.py b/test/test_red.py index 4e70de5..196e55f 100644 --- a/test/test_red.py +++ b/test/test_red.py @@ -1,9 +1,7 @@ -import pytest from utils import runAndCompare class TestRed: - @pytest.mark.skip(reason="FIXME reshape") def test_sum(self): def doit(aapi, **kwargs): a = aapi.arange(0, 64, 1, dtype=aapi.int64, **kwargs) diff --git a/test/test_spmd.py b/test/test_spmd.py index 7507f50..a39eb23 100644 --- a/test/test_spmd.py +++ b/test/test_spmd.py @@ -52,18 +52,13 @@ def test_get_locals_of_view(self): assert float(c) == v MPI.COMM_WORLD.barrier() - @pytest.mark.skip(reason="FIXME reshape") def test_gather1(self): a = sp.reshape( sp.arange(0, 110, 1, dtype=sp.float32, device=device), [11, 10] ) b = sp.spmd.gather(a) c = np.sum(b) - v = np.sum( - np.reshape( - np.arange(0, 110, 1, dtype=np.float32, device=device), (11, 10) - ) - ) + v = np.sum(np.reshape(np.arange(0, 110, 1, dtype=np.float32), (11, 10))) assert float(c) == v MPI.COMM_WORLD.barrier() @@ -81,7 +76,6 @@ def test_gather_0d(self): assert float(b) == 5 MPI.COMM_WORLD.barrier() - @pytest.mark.skip(reason="FIXME reshape") def test_gather_strided1(self): a = sp.reshape( sp.arange(0, 110, 1, dtype=sp.float32, device=device), [11, 10]