From 827658a93f84854770ca7b46ab88eb4c5c395285 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sun, 11 Aug 2019 02:55:19 +0000 Subject: [PATCH] tc comprehension integration Ref. SINGA-482 --- cmake/Dependencies.cmake | 44 ++++++ include/singa/core/tensor.h | 101 +++++++++++++ src/api/core_tensor.i | 22 +++ src/core/tensor/tensor.cc | 165 ++++++++++++++++++++++ test/CMakeLists.txt | 4 +- test/singa/test_tensor_math.cc | 147 +++++++++++++++++++ tool/docker/devel/ubuntu/cuda9/Dockerfile | 59 +++++++- 7 files changed, 539 insertions(+), 3 deletions(-) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index d1d8060171..6729ac3f8f 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -149,3 +149,47 @@ IF(USE_MKLDNN) INCLUDE_DIRECTORIES(${MKLDNN_INCLUDE_DIR}) LIST(APPEND SINGA_LINKER_LIBS ${MKLDNN_LIBRARIES}) ENDIF() + + +### Tensor comprehensions +### Tensor comprehensions +### Tensor comprehensions +# the path should be consistent with the inlcude path in src +INCLUDE_DIRECTORIES(/root/TensorComprehensions) +INCLUDE_DIRECTORIES(/root/TensorComprehensions/tc/version) +INCLUDE_DIRECTORIES(/root/TensorComprehensions/build) + +# polyhedral model required +INCLUDE_DIRECTORIES(/root/TensorComprehensions/isl_interface/include) + +# dlpack +INCLUDE_DIRECTORIES(/root/TensorComprehensions/third-party/dlpack/include) +# Halide +INCLUDE_DIRECTORIES(/root/conda/envs/tc_build/include/Halide) + +# llvm +INCLUDE_DIRECTORIES(/root/conda/envs/tc_build/include) + +# torch ATen header TO DELETE +INCLUDE_DIRECTORIES(/root/conda/envs/tc_build/lib/python3.6/site-packages/torch/lib/include) + +# find Halide lib +set(HALIDE_PREFIX "/root/conda/envs/tc_build") +find_library(HALIDE_LIBRARIES REQUIRED NAMES Halide PATHS ${HALIDE_PREFIX} PATH_SUFFIXES lib lib64 NO_DEFAULT_PATH) +message(STATUS "Found Halide.so file: ${HALIDE_LIBRARIES}") + +# find tc lib +link_directories(/root/TensorComprehensions/build/tc/aten) +link_directories(/root/TensorComprehensions/build/tc/lang) +link_directories(/root/TensorComprehensions/build/tc/core) +link_directories(/root/TensorComprehensions/build/tc/autotuner) +link_directories(/root/TensorComprehensions/build/tc/proto) + +# torch(aten) lib to delete +link_directories(/root/conda/envs/tc_build/lib/python3.6/site-packages/torch/lib) + +LIST(APPEND SINGA_LINKER_LIBS ${HALIDE_LIBRARIES} tc_aten tc_lang tc_core_cpu tc_cuda tc_core_cuda_no_sdk tc_core tc_autotuner tc_proto ATen) + +### Tensor comprehensions +### Tensor comprehensions +### Tensor comprehensions diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h index f511212f1b..8dc1a22244 100755 --- a/include/singa/core/tensor.h +++ b/include/singa/core/tensor.h @@ -23,6 +23,14 @@ #include #include +#include +#include +#include +#include +#include +#include +#include "tc/core/cuda/cuda_tc_executor.h" + #include "singa/core/common.h" #include "singa/core/device.h" #include "singa/proto/core.pb.h" @@ -147,6 +155,7 @@ class Tensor { /// Return average L2 norm float L2() const; + // -------------------------------------------------------------------------- // ---Following methods changes the internal data // -------------------------------------------------------------------------- @@ -603,6 +612,98 @@ Tensor ConcatRows(const vector &in); Tensor ConcatenateColumns(const vector &in); /// Alias name for function ConcatenateColumns Tensor ConcatColumns(const vector &in); + + + + +/// tc integration start +DLManagedTensor* toDLPack(const Tensor& src); +//Tensor fromDLPack(const DLManagedTensor* src); + +inline std::vector makeDLTensors( + const std::vector& tensors); + +template +std::unique_ptr compileTC( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs, + const typename Backend::MappingOptionsType& options, + const tc::CompilerOptions& compilerOptions = tc::CompilerOptions()); + + +std::vector inferOutputTensorInfo( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs); + +std::vector prepareOutputs( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs); + +template +void runTC( + const Executor& executor, + const std::vector& inputs, + std::vector& outputs); + + +// tensor comprehension operations +Tensor SoftMaxTC(const Tensor &in); +Tensor ReluTC(const Tensor &in); +Tensor MatMulTC(const Tensor &in1,const Tensor &in2); + + +// makeDLConstTensors implementation +inline std::vector makeDLConstTensors(const std::vector& tensors) { + std::vector dlTensors; + for (auto tensor : tensors) { + auto dlMTensor = toDLPack(tensor); + dlTensors.push_back(tc::makeDLConstTensor(&(dlMTensor->dl_tensor))); + dlMTensor->deleter(dlMTensor); + } + return dlTensors; +} + +// makeDLTensors implementation +inline std::vector makeDLTensors( const std::vector& tensors) { + std::vector dlTensors; + for (auto tensor : tensors) { + auto dlMTensor = toDLPack(tensor); + dlTensors.push_back(tc::makeDLTensor(&(dlMTensor->dl_tensor))); + dlMTensor->deleter(dlMTensor); + } + return dlTensors; +} + + +// compile implementation +template +std::unique_ptr compileTC( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs, + const typename Backend::MappingOptionsType& options, + const tc::CompilerOptions& compilerOptions) { + auto inputDLTensors = makeDLConstTensors(inputs); + return tc::compile( + tc, entryPoint, extractRawPtrs(inputDLTensors), options, compilerOptions); +} + +// run implementation +template +void runTC( + const Executor& executor, + const std::vector& inputs, + std::vector& outputs) { + auto inputDLTensors = makeDLConstTensors(inputs); + auto outputDLTensors = makeDLTensors(outputs); + return executor.run( extractRawPtrs(inputDLTensors), extractRawPtrs(outputDLTensors)); +} + +/// tc integration end + } // namespace singa #endif // SINGA_CORE_TENSOR_H_ diff --git a/src/api/core_tensor.i b/src/api/core_tensor.i index a52fb376e7..fbb8cf9276 100755 --- a/src/api/core_tensor.i +++ b/src/api/core_tensor.i @@ -345,4 +345,26 @@ namespace singa{ Tensor CrossEntropyFwd(const Tensor& p, const Tensor& t); Tensor SoftmaxCrossEntropyBwd(const Tensor& p, const Tensor& t); + + /* ============ Tensor Comprehensions ============ */ + /* /root/incubator-singa/build/src/api/singa_wrap.cxx:14938:166: error: use of deleted function */ + /* due to below issue, abort this approach + std::vector prepareOutputs( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs); + + template + void runTC( const Executor& executor, const std::vector& inputs, std::vector& outputs); + %template(runTCCuda) runTC; + + template + std::unique_ptr compileTC( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs, + const typename Backend::MappingOptionsType& options, + const tc::CompilerOptions& compilerOptions = tc::CompilerOptions()); + %template(compileTCCuda) compileTC; + */ } diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc index b4cc8edf61..2cb8226be8 100755 --- a/src/core/tensor/tensor.cc +++ b/src/core/tensor/tensor.cc @@ -21,12 +21,23 @@ #include "./tensor_math_cpp.h" #include "./tensor_math_cuda.h" #include "./tensor_math_opencl.h" + #include #include +//#include +//#include +#include "tc/core/check.h" +#include "tc/core/compiler.h" +#include "tc/core/tc_executor.h" +#include "tc/core/tensor.h" #define Noaxis 9999 +// namespace is already exist in singa +// aliasing to avoid duplicates +namespace tclang = lang; + namespace singa { Tensor::~Tensor() { @@ -1334,4 +1345,158 @@ Tensor Reshape(const Tensor &in, const Shape &s) { return out.Reshape(s); } + +/// tc integration start +struct SingaDLManagedTensor { + Tensor handle; + DLManagedTensor tensor; +}; + +void deleter(DLManagedTensor* arg) { + delete static_cast(arg->manager_ctx); +} + +static DLDataType getDLDataType(const Tensor& t) { + DLDataType dtype; + dtype.lanes = 1; + // TODO: get the number of bytes of the datatype + //dtype.bits = t.data_type() * 8; + dtype.bits = 4 * 8; + switch (t.data_type()) { + case kFloat32: + dtype.code = DLDataTypeCode::kDLFloat; + break; + default: + throw std::logic_error("only kFloat32 is supported for dlpack conversion"); + break; + } + return dtype; +} + +static DLContext getDLContext(const Tensor& tensor, const int64_t& device_id) { + DLContext ctx; + ctx.device_id = device_id; + ctx.device_type = DLDeviceType::kDLGPU; + //TODO: fix this + //if (tensor.is_cuda()) { + // ctx.device_type = DLDeviceType::kDLGPU; + //} else { + // ctx.device_type = DLDeviceType::kDLCPU; + //} + return ctx; +} + +// This function returns a shared_ptr to memory managed DLpack tensor +// constructed out of ATen tensor +DLManagedTensor* toDLPack(const Tensor& src) { + SingaDLManagedTensor* singaDLManagedTensor(new SingaDLManagedTensor); + singaDLManagedTensor->handle = src; + singaDLManagedTensor->tensor.manager_ctx = singaDLManagedTensor; + singaDLManagedTensor->tensor.deleter = &deleter; + singaDLManagedTensor->tensor.dl_tensor.data = src.block()->mutable_data(); + int64_t device_id = 0; + // TODO: fix this + //if (src.is_cuda()) { + // device_id = src.get_device(); + //} + singaDLManagedTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id); + singaDLManagedTensor->tensor.dl_tensor.ndim = src.nDim(); + singaDLManagedTensor->tensor.dl_tensor.dtype = getDLDataType(src); + + auto shapeVec = new std::vector(src.shape().begin(),src.shape().end()); + singaDLManagedTensor->tensor.dl_tensor.shape = shapeVec->data(); + + auto strideVec = new std::vector(src.stride().begin(),src.stride().end()); + singaDLManagedTensor->tensor.dl_tensor.strides = strideVec->data(); + + singaDLManagedTensor->tensor.dl_tensor.byte_offset = 0; + return &(singaDLManagedTensor->tensor); +} + +// prepare output +std::vector inferOutputTensorInfo( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs) { + auto parsedTcs = tc::detail::parse(tc); + if (parsedTcs.count(entryPoint) != 1u) { + TC_CHECK_GE(parsedTcs.size(), 1u) + << "No TC was parsed, should have thrown earlier"; + throw tclang::ErrorReport(parsedTcs.begin()->second) + << "\nattempting to access undefined entryPoint: " << entryPoint; + } + auto inputDLTensors = makeDLConstTensors(inputs); + return makeDLTensorVector(tc::detail::inferOutputTensorInfo(parsedTcs.at(entryPoint), extractRawPtrs(inputDLTensors))); +} + +std::vector prepareOutputs( + const std::string& tc, + const std::string& entryPoint, + const std::vector& inputs) { + std::vector outputs; + auto outTensorInfo = inferOutputTensorInfo(tc, entryPoint, inputs); + if (outTensorInfo.size() == 0) { + return outputs; + } + TC_CHECK_GE(inputs.size(), 1u) + << "NYI: Need >= 1 input tensors to determine " + << "backend and prepare ATen outputs. Add an overload with just an ATen " + << "backend"; + + auto dev = inputs[0].device(); + auto dtype = inputs[0].data_type(); + for (size_t i = 0; i < outTensorInfo.size(); ++i) { + tc::TensorInfo info(outTensorInfo[i]); + Shape shape(info.shape.begin(), info.shape.end()); + + Tensor tmp(shape, dev, dtype); + outputs.push_back(tmp); + } + return outputs; +} + + +// examples of TC operations +Tensor SoftMaxTC(const Tensor &in) { + std::string tc= R"TC( +def softmax(float(N, D) I) -> (O, expsum) { + expsum(n) +=! exp(I(n, d)) + O(n, d) = exp(I(n, d)) / expsum(n) +} +)TC"; + auto naiveOptions = tc::CudaBackend::MappingOptionsType::makeNaiveMappingOptions(); + auto pExecutor = singa::compileTC(tc, "softmax", {in}, {naiveOptions}); + auto outputs = singa::prepareOutputs(tc, "softmax", {in}); + singa::runTC(*pExecutor, {in}, outputs); + return outputs[0]; +} + +Tensor ReluTC(const Tensor &in) { + std::string tc = R"TC( +def relu(float(B,M) I) -> (O1){ + O1(b, m) = fmax(I(b, m), 0) +} + )TC"; + auto naiveOptions = tc::CudaBackend::MappingOptionsType::makeNaiveMappingOptions(); + auto pExecutor = singa::compileTC(tc, "relu", {in}, {naiveOptions}); + auto outputs = singa::prepareOutputs(tc, "relu", {in}); + singa::runTC(*pExecutor, {in}, outputs); + return outputs[0]; +} + +Tensor MatMulTC(const Tensor &in1,const Tensor &in2) { + std::string tc = R"TC( +def matmul(float(M,N) A, float(N,K) B) -> (output) { + output(i, j) +=! A(i, kk) * B(kk, j) +} + )TC"; + auto naiveOptions = tc::CudaBackend::MappingOptionsType::makeNaiveMappingOptions(); + auto pExecutor = singa::compileTC(tc, "matmul", {in1, in2}, {naiveOptions}); + auto outputs = singa::prepareOutputs(tc, "matmul", {in1, in2}); + singa::runTC(*pExecutor, {in1, in2}, outputs); + return outputs[0]; +} +/// tc integration end + + } // namespace singa diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a056631987..f9344f7064 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -19,6 +19,8 @@ INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) INCLUDE_DIRECTORIES(${CMAKE_BINARY_DIR}/include) + + IF(ENABLE_DIST) ADD_EXECUTABLE(test_ep "singa/test_ep.cc") ADD_DEPENDENCIES(test_ep singa) @@ -33,7 +35,7 @@ LIST(REMOVE_ITEM singa_test_source "singa/test_ep.cc") ADD_EXECUTABLE(test_singa "gtest/gtest_main.cc" ${singa_test_source}) ADD_DEPENDENCIES(test_singa singa) #MESSAGE(STATUS "link libs" ${singa_linker_libs}) -TARGET_LINK_LIBRARIES(test_singa gtest singa ) +TARGET_LINK_LIBRARIES(test_singa gtest singa ${SINGA_LINKER_LIBS}) IF(UNIX AND (NOT APPLE)) LIST(APPEND LINK_FLAGS "-pthread") ENDIF() diff --git a/test/singa/test_tensor_math.cc b/test/singa/test_tensor_math.cc index 367e49b13a..5ca332c550 100644 --- a/test/singa/test_tensor_math.cc +++ b/test/singa/test_tensor_math.cc @@ -18,6 +18,22 @@ #include "gtest/gtest.h" #include "singa/core/tensor.h" + +#include + +// tensor comprehensions +#include "tc/core/cuda/cuda_mapping_options.h" +#include +#include "tc/examples/common.h" +#include "tc/aten/aten.h" +#include "tc/aten/aten_autotuner.h" +#include "tc/aten/aten_compiler.h" +#include "tc/autotuner/genetic_search.h" +#include "tc/core/check.h" +#include "tc/core/cuda/cuda_tc_executor.h" +#include "tc/core/flags.h" +// tensor comprehensions + using singa::Tensor; using singa::Shape; using singa::Device; @@ -40,6 +56,137 @@ class TensorMath : public ::testing::Test { const float dat2[6] = {1.1f, 2.1f, 3.1f, 4.1f, 5.1f, 6.1f}; }; +// tensor comprehensions starts +// smoke test on ATen tensordot +TEST_F(TensorMath, TCATenTensordot) { + std::string tc = R"TC( +def tensordot(float(N, C1, C2, H, W) I0, + float(N, C2, C3, H, W) I1) -> (O) +{ + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) +} + )TC"; + + auto naiveOptions = tc::CudaBackend::MappingOptionsType::makeNaiveMappingOptions(); + + at::Tensor I0 = makeATenTensor({16, 8, 16, 17, 25}); + at::Tensor I1 = makeATenTensor({16, 16, 2, 17, 25}); + + auto pExecutor = tc::aten::compile(tc, "tensordot", {I0, I1}, {naiveOptions}); + auto outputs = tc::aten::prepareOutputs(tc, "tensordot", {I0, I1}); + + tc::aten::run(*pExecutor, {I0, I1}, outputs); +} + +// compare dlpack tensor conversion between aten, singa +TEST_F(TensorMath, TCToDLPack) { + at::Tensor I0 = makeATenTensor({3,4,5}); + auto dl_target = at::toDLPack(I0); + + auto cuda = std::make_shared(); + singa::Tensor t1(singa::Shape{3,4,5}, cuda); + t1.SetValue(1.1f); + auto dl_output = toDLPack(t1); + + EXPECT_EQ( dl_target->dl_tensor.ndim , dl_output->dl_tensor.ndim); + EXPECT_EQ( dl_target->dl_tensor.dtype.code , dl_output->dl_tensor.dtype.code); + EXPECT_EQ( dl_target->dl_tensor.dtype.bits , dl_output->dl_tensor.dtype.bits); + EXPECT_EQ( dl_target->dl_tensor.dtype.lanes , dl_output->dl_tensor.dtype.lanes); + EXPECT_EQ( dl_target->dl_tensor.shape[0] , dl_output->dl_tensor.shape[0]); + EXPECT_EQ( dl_target->dl_tensor.shape[1] , dl_output->dl_tensor.shape[1]); + EXPECT_EQ( dl_target->dl_tensor.shape[2] , dl_output->dl_tensor.shape[2]); + EXPECT_EQ( dl_target->dl_tensor.strides[0] , dl_output->dl_tensor.strides[0]); + EXPECT_EQ( dl_target->dl_tensor.strides[1] , dl_output->dl_tensor.strides[1]); + EXPECT_EQ( dl_target->dl_tensor.strides[2] , dl_output->dl_tensor.strides[2]); + EXPECT_EQ( dl_target->dl_tensor.byte_offset , dl_output->dl_tensor.byte_offset); +} + +TEST_F(TensorMath, TCTensordot) { + auto cuda = std::make_shared(); + singa::Tensor t1(singa::Shape{16, 8, 16, 17, 25}, cuda); + singa::Tensor t2(singa::Shape{16, 16, 2, 17, 25}, cuda); + + t1.SetValue(1.1f); + t2.SetValue(1.2f); + + std::string tc = R"TC( +def tensordot(float(N, C1, C2, H, W) I0, + float(N, C2, C3, H, W) I1) -> (O) +{ + O(n, c1, c3, h, w) +=! I0(n, c1, r_c2, h, w) * I1(n, r_c2, c3, h, w) +} + )TC"; + + auto naiveOptions = tc::CudaBackend::MappingOptionsType::makeNaiveMappingOptions(); + + auto pExecutor = singa::compileTC(tc, "tensordot", {t1, t2}, {naiveOptions}); + auto outputs = singa::prepareOutputs(tc, "tensordot", {t1, t2}); + singa::runTC(*pExecutor, {t1, t2}, outputs); +} + +TEST_F(TensorMath, TCRelu) { + auto cuda = std::make_shared(); + singa::Tensor t1(singa::Shape{2,2}, cuda); + + const float dat1[4] = {-1.0f, 1.0f, -2.0f, 3.0f}; + t1.CopyDataFromHostPtr(dat1, 4); + + auto o1 = ReluTC(t1).ToHost(); + EXPECT_EQ(o1.shape(0), 2); + EXPECT_EQ(o1.shape(1), 2); + const float *dptr = o1.data(); + EXPECT_FLOAT_EQ(0.0f, dptr[0]); + EXPECT_FLOAT_EQ(1.0f, dptr[1]); + EXPECT_FLOAT_EQ(0.0f, dptr[2]); + EXPECT_FLOAT_EQ(3.0f, dptr[3]); +} + +TEST_F(TensorMath, TCMatmul) { + auto cuda = std::make_shared(); + singa::Tensor t1(singa::Shape{2,2}, cuda); + singa::Tensor t2(singa::Shape{2,2}, cuda); + t1.SetValue(1.1f); + t2.SetValue(1.2f); + + auto o1 = MatMulTC(t1,t2).ToHost(); + EXPECT_EQ(o1.shape(0), 2); + EXPECT_EQ(o1.shape(1), 2); + const float *dptr = o1.data(); + EXPECT_FLOAT_EQ(2.64f, dptr[0]); + EXPECT_FLOAT_EQ(2.64f, dptr[1]); + EXPECT_FLOAT_EQ(2.64f, dptr[2]); + EXPECT_FLOAT_EQ(2.64f, dptr[3]); +} + +/* TODO: segment fault +TEST_F(TensorMath, TCSoftmax) { + auto cuda = std::make_shared(); + singa::Tensor t1(singa::Shape{2}, cuda); + + const float dat1[6] = {1.0f, 2.0f}; + t1.CopyDataFromHostPtr(dat1, 2); + + auto output=SoftMaxTC(t1); + output.ToHost(); + + auto optr1=output.data(); + + EXPECT_EQ(output.shape(0), 6); + const float *dptr1 = output.data(); + float sum = 0; + for (int i = 0; i < 6; i++) sum += (float)exp(i + 1); + + EXPECT_NEAR(exp(1) / sum, dptr1[0], 1e-5); + EXPECT_NEAR(exp(3) / sum, dptr1[2], 1e-5); + EXPECT_NEAR(exp(5) / sum, dptr1[4], 1e-5); + EXPECT_NEAR(exp(2) / sum, dptr1[1], 1e-5); + EXPECT_NEAR(exp(4) / sum, dptr1[3], 1e-5); + EXPECT_NEAR(exp(6) / sum, dptr1[5], 1e-5); +} +*/ +// Tensor comprehensions ends + + TEST_F(TensorMath, AbsCpp) { Tensor aa = a.Clone(); Tensor bb = b.Clone(); diff --git a/tool/docker/devel/ubuntu/cuda9/Dockerfile b/tool/docker/devel/ubuntu/cuda9/Dockerfile index 98a0a88dfa..fabf9a4a2a 100644 --- a/tool/docker/devel/ubuntu/cuda9/Dockerfile +++ b/tool/docker/devel/ubuntu/cuda9/Dockerfile @@ -58,14 +58,14 @@ RUN apt-get update \ # install swig > 3.0.10 RUN wget http://prdownloads.sourceforge.net/swig/swig-3.0.10.tar.gz -P /tmp/ \ && tar zxf /tmp/swig-3.0.10.tar.gz -C /tmp/ \ - && cd /tmp/swig-3.0.10 && ./configure && make && make install + && cd /tmp/swig-3.0.10 && ./configure && make -j8 && make install # install mkldnn RUN wget https://github.com/intel/mkl-dnn/archive/v0.18.tar.gz -P /tmp/ \ && tar zxf /tmp/v0.18.tar.gz -C /tmp/ \ && cd /tmp/mkl-dnn-0.18/ \ && mkdir -p build && cd build && cmake .. \ - && make && make install + && make -j8 && make install # config ssh service RUN mkdir /var/run/sshd \ @@ -84,6 +84,61 @@ RUN git clone https://github.com/apache/incubator-singa.git $HOME/incubator-sing && cmake -DENABLE_TEST=ON -DUSE_CUDA=ON -DUSE_MKLDNN=ON -DUSE_PYTHON3=ON .. RUN cd $HOME/incubator-singa/build && make && make install + +# start of TC +RUN apt-get update && apt-get install -y libgmp3-dev cmake automake libtool + +RUN wget https://repo.anaconda.com/archive/Anaconda3-5.1.0-Linux-x86_64.sh -O anaconda.sh && \ + chmod +x anaconda.sh && \ + ./anaconda.sh -b -p /root/conda && \ + rm anaconda.sh + +RUN . /root/conda/bin/activate && \ + conda update -y -n base conda && \ + conda create -y --name tc_build python=3.6 && \ + conda activate tc_build && \ + conda install -y pyyaml mkl-include pytest && \ + conda install -y -c nicolasvasilache llvm-trunk halide && \ + conda install -y -c pytorch pytorch=0.4.0 torchvision cuda90 && \ + conda remove -y cudatoolkit --force + +RUN cd /root/ && git clone https://github.com/dcslin/TensorComprehensions.git --recursive && \ + cd TensorComprehensions + +ENV CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda + +RUN apt-get remove -y libprotobuf-dev protobuf-compiler && \ + apt-get update && \ + apt-get install -y unzip automake curl less + +ENV CLANG_PREFIX /root/conda/envs/tc_build +RUN cd /root/TensorComprehensions && . /root/conda/bin/activate && conda activate tc_build && ./build.sh + +############# singa build +# protobuf +RUN cd /root/TensorComprehensions/third-party/googlelibraries/protobuf-3.5.2 \ + && ./autogen.sh \ + && ./configure \ + && make -j8 \ + && make check \ + && make install \ + && ldconfig + +# isl +RUN cd /root/TensorComprehensions/third-party/islpp \ + && ./autogen.sh \ + && ./configure \ + && make -j8 \ + && make install + +# gflags +RUN cd /root/TensorComprehensions/third-party/googlelibraries/gflags \ + && mkdir build && cd build \ + && cmake .. \ + && make -j8 \ + && make install +# end of TC + WORKDIR $HOME/incubator-singa EXPOSE 22