From 050674567552110cb9d63c570e3f04146a96b36d Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sun, 11 Aug 2019 02:55:19 +0000
Subject: [PATCH] Tensor Comprehensions integration

Ref. SINGA-482
---
 cmake/Dependencies.cmake                  |  40 ++++
 include/singa/core/tensor.h               |  89 ++++++++
 src/core/tensor/tensor.cc                 | 115 +++++++++++
 src/model/operation/tc_fn.cc              |  40 ++++
 src/model/operation/tc_fn.h               |  39 ++++
 test/CMakeLists.txt                       |   2 +-
 test/singa/test_operation_tc_fn.cc        | 221 ++++++++++++++++++++++
 test/singa/test_softmax.cc                |   3 +
 tool/docker/devel/ubuntu/cuda9/Dockerfile |  30 ++-
 9 files changed, 576 insertions(+), 3 deletions(-)
 create mode 100644 src/model/operation/tc_fn.cc
 create mode 100644 src/model/operation/tc_fn.h
 create mode 100644 test/singa/test_operation_tc_fn.cc

diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index d1d8060171..201138c62d 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -149,3 +149,43 @@ IF(USE_MKLDNN)
     INCLUDE_DIRECTORIES(${MKLDNN_INCLUDE_DIR})
     LIST(APPEND SINGA_LINKER_LIBS ${MKLDNN_LIBRARIES})
 ENDIF()
+
+
+### Tensor Comprehensions
+INCLUDE_DIRECTORIES(/root/TensorComprehensions)
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/tc/version)
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/build)
+# polyhedral model required
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/isl_interface/include)
+# dlpack
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/third-party/dlpack/include)
+# islpp
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/third-party/islpp/include)
+# gflags
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/build/third-party/googlelibraries/gflags/include)
+# glog
+INCLUDE_DIRECTORIES(/root/TensorComprehensions/build/third-party/googlelibraries/glog)
+# Halide
+INCLUDE_DIRECTORIES(/root/conda/envs/tc_build/include/Halide)
+# llvm
+INCLUDE_DIRECTORIES(/root/conda/envs/tc_build/include)
+# torch ATen header
+INCLUDE_DIRECTORIES(/root/conda/envs/tc_build/lib/python3.6/site-packages/torch/lib/include)
+
+# find Halide lib
+set(HALIDE_PREFIX "/root/conda/envs/tc_build")
+find_library(HALIDE_LIBRARIES NAMES Halide PATHS ${HALIDE_PREFIX} PATH_SUFFIXES lib lib64 NO_DEFAULT_PATH)
+message(STATUS "Found Halide.so file: ${HALIDE_LIBRARIES}")
+
+# find tc lib
+link_directories(/root/TensorComprehensions/build/tc/aten)
+link_directories(/root/TensorComprehensions/build/tc/lang)
+link_directories(/root/TensorComprehensions/build/tc/core)
+link_directories(/root/TensorComprehensions/build/tc/autotuner)
+link_directories(/root/TensorComprehensions/build/tc/proto)
+
+# torch(aten)
+link_directories(/root/conda/envs/tc_build/lib/python3.6/site-packages/torch/lib)
+
+LIST(APPEND SINGA_LINKER_LIBS ${HALIDE_LIBRARIES} tc_aten tc_lang tc_core_cpu tc_cuda tc_core_cuda_no_sdk tc_core tc_autotuner tc_proto ATen)
+### Tensor Comprehensions
diff --git a/include/singa/core/tensor.h b/include/singa/core/tensor.h
index f511212f1b..965ba04dcf 100755
--- a/include/singa/core/tensor.h
+++ b/include/singa/core/tensor.h
@@ -23,6 +23,18 @@
 #include <memory>
 #include <vector>
 
+// tc
+#include <dlpack/dlpack.h>
+#include <tc/core/check.h>
+#include <tc/core/compiler.h>
+#include <tc/core/flags.h>
+#include <tc/core/tensor.h>
+#include <tc/core/cuda/cuda.h>
+#include <tc/core/cuda/cuda_backend.h>
+#include <tc/core/cuda/cuda_tc_executor.h>
+#include <tc/lang/error_report.h>
+// tc
+
 #include "singa/core/common.h"
 #include "singa/core/device.h"
 #include "singa/proto/core.pb.h"
@@ -603,6 +615,83 @@ Tensor ConcatRows(const vector<Tensor> &in);
 Tensor ConcatenateColumns(const vector<Tensor> &in);
 /// Alias name for function ConcatenateColumns
 Tensor ConcatColumns(const vector<Tensor> &in);
+
+
+/// tc integration start
+DLManagedTensor *toDLPack(const Tensor &src);
+
+inline std::vector<tc::DLTensorUPtr>
+makeDLTensors(const std::vector<Tensor> &tensors);
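+
+// Workflow (summary of the declarations below): compileTC() JIT-compiles a
+// TC entry point for the chosen backend and returns an executor;
+// prepareOutputs() allocates SINGA output tensors with the inferred shapes;
+// runTC() launches the executor, converting inputs and outputs to DLPack
+// views via toDLPack() and the makeDL*Tensors() helpers.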
+
+template <typename Backend>
+std::unique_ptr<typename Backend::ExecutorType>
+compileTC(const std::string &tc, const std::string &entryPoint,
+          const std::vector<Tensor> &inputs,
+          const typename Backend::MappingOptionsType &options,
+          const tc::CompilerOptions &compilerOptions = tc::CompilerOptions());
+
+std::vector<tc::TensorInfo>
+inferOutputTensorInfo(const std::string &tc, const std::string &entryPoint,
+                      const std::vector<Tensor> &inputs);
+
+std::vector<Tensor> prepareOutputs(const std::string &tc,
+                                   const std::string &entryPoint,
+                                   const std::vector<Tensor> &inputs);
+
+template <typename Executor>
+void runTC(const Executor &executor, const std::vector<Tensor> &inputs,
+           std::vector<Tensor> &outputs);
+
+// makeDLConstTensors implementation
+inline std::vector<tc::DLConstTensorUPtr>
+makeDLConstTensors(const std::vector<Tensor> &tensors) {
+  std::vector<tc::DLConstTensorUPtr> dlTensors;
+  for (auto tensor : tensors) {
+    auto dlMTensor = toDLPack(tensor);
+    dlTensors.push_back(tc::makeDLConstTensor(&(dlMTensor->dl_tensor)));
+    dlMTensor->deleter(dlMTensor);
+  }
+  return dlTensors;
+}
+
+// makeDLTensors implementation
+inline std::vector<tc::DLTensorUPtr>
+makeDLTensors(const std::vector<Tensor> &tensors) {
+  std::vector<tc::DLTensorUPtr> dlTensors;
+  for (auto tensor : tensors) {
+    auto dlMTensor = toDLPack(tensor);
+    dlTensors.push_back(tc::makeDLTensor(&(dlMTensor->dl_tensor)));
+    dlMTensor->deleter(dlMTensor);
+  }
+  return dlTensors;
+}
+
+// compile implementation
+template <typename Backend>
+std::unique_ptr<typename Backend::ExecutorType>
+compileTC(const std::string &tc, const std::string &entryPoint,
+          const std::vector<Tensor> &inputs,
+          const typename Backend::MappingOptionsType &options,
+          const tc::CompilerOptions &compilerOptions) {
+  auto inputDLTensors = makeDLConstTensors(inputs);
+  return tc::compile<Backend>(tc, entryPoint, extractRawPtrs(inputDLTensors),
+                              options, compilerOptions);
+}
+
+// run implementation
+template <typename Executor>
+void runTC(const Executor &executor, const std::vector<Tensor> &inputs,
+           std::vector<Tensor> &outputs) {
+  auto inputDLTensors = makeDLConstTensors(inputs);
+  auto outputDLTensors = makeDLTensors(outputs);
+  executor.run(extractRawPtrs(inputDLTensors),
+               extractRawPtrs(outputDLTensors));
+}
+
+/// tc integration end
+
 }  // namespace singa
 
 #endif  // SINGA_CORE_TENSOR_H_
diff --git a/src/core/tensor/tensor.cc b/src/core/tensor/tensor.cc
index b4cc8edf61..4b63711397 100755
--- a/src/core/tensor/tensor.cc
+++ b/src/core/tensor/tensor.cc
@@ -24,9 +24,19 @@
 #include <utility>
 #include <algorithm>
 
+// tc
+#include <tc/core/check.h>
+#include <tc/core/compiler.h>
+#include <tc/core/tensor.h>
+#include <tc/lang/error_report.h>
+// tc
+
 #define Noaxis 9999
 
+// a namespace named "lang" already exists in SINGA, so alias the TC "lang"
+// namespace to avoid a clash
+namespace tclang = lang;
+
 namespace singa {
 
 Tensor::~Tensor() {
@@ -1334,4 +1344,109 @@ Tensor Reshape(const Tensor &in, const Shape &s) {
   return out.Reshape(s);
 }
 
+
+/// tc integration start
+struct SingaDLManagedTensor {
+  Tensor handle;
+  // shape and strides live here so the deleter releases them together with
+  // the tensor handle
+  std::vector<int64_t> shape;
+  std::vector<int64_t> strides;
+  DLManagedTensor tensor;
+};
+
+void deleter(DLManagedTensor *arg) {
+  delete static_cast<SingaDLManagedTensor *>(arg->manager_ctx);
+}
+
+static DLDataType getDLDataType(const Tensor &t) {
+  DLDataType dtype;
+  dtype.lanes = 1;
+  dtype.bits = SizeOf(t.data_type()) * 8;
+  switch (t.data_type()) {
+    case kFloat32:
+      dtype.code = DLDataTypeCode::kDLFloat;
+      break;
+    default:
+      throw std::logic_error("only kFloat32 is supported for dlpack conversion");
+  }
+  return dtype;
+}
+
+static DLContext getDLContext(const Tensor &tensor, const int64_t &device_id) {
+  DLContext ctx;
+  ctx.device_id = device_id;
+  if (tensor.device()->lang() == kCuda) {
+    ctx.device_type = DLDeviceType::kDLGPU;
+  } else {
+    ctx.device_type = DLDeviceType::kDLCPU;
+  }
+  return ctx;
+}
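+
+// Ownership note: toDLPack() below returns a view whose lifetime is managed
+// through DLManagedTensor::deleter; the SingaDLManagedTensor keeps a Tensor
+// handle (plus its shape/stride vectors) alive until the consumer invokes
+// the deleter, as makeDLTensors()/makeDLConstTensors() do after wrapping.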
+
+// This function returns a pointer to a memory-managed DLPack tensor
+// constructed out of a SINGA tensor
+DLManagedTensor *toDLPack(const Tensor &src) {
+  SingaDLManagedTensor *singaDLManagedTensor(new SingaDLManagedTensor);
+  singaDLManagedTensor->handle = src;
+  singaDLManagedTensor->tensor.manager_ctx = singaDLManagedTensor;
+  singaDLManagedTensor->tensor.deleter = &deleter;
+  singaDLManagedTensor->tensor.dl_tensor.data = src.block()->mutable_data();
+  int64_t device_id = src.device()->id();
+  singaDLManagedTensor->tensor.dl_tensor.ctx = getDLContext(src, device_id);
+  singaDLManagedTensor->tensor.dl_tensor.ndim = src.nDim();
+  singaDLManagedTensor->tensor.dl_tensor.dtype = getDLDataType(src);
+
+  singaDLManagedTensor->shape.assign(src.shape().begin(), src.shape().end());
+  singaDLManagedTensor->tensor.dl_tensor.shape =
+      singaDLManagedTensor->shape.data();
+
+  singaDLManagedTensor->strides.assign(src.stride().begin(),
+                                       src.stride().end());
+  singaDLManagedTensor->tensor.dl_tensor.strides =
+      singaDLManagedTensor->strides.data();
+
+  singaDLManagedTensor->tensor.dl_tensor.byte_offset = 0;
+  return &(singaDLManagedTensor->tensor);
+}
+
+// prepare output
+std::vector<tc::TensorInfo>
+inferOutputTensorInfo(const std::string &tc, const std::string &entryPoint,
+                      const std::vector<Tensor> &inputs) {
+  auto parsedTcs = tc::detail::parse(tc);
+  if (parsedTcs.count(entryPoint) != 1u) {
+    TC_CHECK_GE(parsedTcs.size(), 1u)
+        << "No TC was parsed, should have thrown earlier";
+    throw tclang::ErrorReport(parsedTcs.begin()->second)
+        << "\nattempting to access undefined entryPoint: " << entryPoint;
+  }
+  auto inputDLTensors = makeDLConstTensors(inputs);
+  return tc::detail::inferOutputTensorInfo(parsedTcs.at(entryPoint),
+                                           extractRawPtrs(inputDLTensors));
+}
+
+std::vector<Tensor> prepareOutputs(const std::string &tc,
+                                   const std::string &entryPoint,
+                                   const std::vector<Tensor> &inputs) {
+  std::vector<Tensor> outputs;
+  auto outTensorInfo = inferOutputTensorInfo(tc, entryPoint, inputs);
+  if (outTensorInfo.size() == 0) {
+    return outputs;
+  }
+  TC_CHECK_GE(inputs.size(), 1u)
+      << "NYI: need >= 1 input tensors to determine the device and data type "
+      << "of the outputs";
+
+  auto dev = inputs[0].device();
+  auto dtype = inputs[0].data_type();
+  for (size_t i = 0; i < outTensorInfo.size(); ++i) {
+    tc::TensorInfo info(outTensorInfo[i]);
+    Shape shape(info.shape.begin(), info.shape.end());
+    Tensor tmp(shape, dev, dtype);
+    outputs.push_back(tmp);
+  }
+  return outputs;
+}
+/// tc integration end
+
+
 }  // namespace singa
diff --git a/src/model/operation/tc_fn.cc b/src/model/operation/tc_fn.cc
new file mode 100644
index 0000000000..52ba592fe4
--- /dev/null
+++ b/src/model/operation/tc_fn.cc
@@ -0,0 +1,40 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#include "./tc_fn.h"
+
+namespace singa {
+
+TcFnHandle::TcFnHandle(std::string tcDefinition, std::string entryFn,
+                       const std::vector<Tensor> &inputs) {
+  tc_string = tcDefinition;
+  tc_name = entryFn;
+  auto naiveOptions =
+      tc::CudaBackend::MappingOptionsType::makeNaiveMappingOptions();
+  pExecutor = singa::compileTC<tc::CudaBackend>(tcDefinition, entryFn, inputs,
+                                                naiveOptions);
+}
+
+Tensor tcExecute(const TcFnHandle &tcFnHandle, const std::vector<Tensor> &inputs) {
+  auto outputs =
+      singa::prepareOutputs(tcFnHandle.tc_string, tcFnHandle.tc_name, inputs);
+  singa::runTC(*(tcFnHandle.pExecutor), inputs, outputs);
+  return outputs[0];
+}
+
+}  // namespace singa
diff --git a/src/model/operation/tc_fn.h b/src/model/operation/tc_fn.h
new file mode 100644
index 0000000000..c02e967b81
--- /dev/null
+++ b/src/model/operation/tc_fn.h
@@ -0,0 +1,39 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+#ifndef SINGA_MODEL_OPERATION_TC_FN_H_
+#define SINGA_MODEL_OPERATION_TC_FN_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "singa/core/tensor.h"
+
+namespace singa {
+
+// A compiled TC entry point: compiles on construction (for the shapes of the
+// given sample inputs) and can then be executed repeatedly via tcExecute().
+class TcFnHandle {
+ public:
+  TcFnHandle(std::string tcDefinition, std::string entryFn,
+             const std::vector<Tensor> &inputs);
+  std::string tc_string;
+  std::string tc_name;
+  std::unique_ptr<tc::CudaTcExecutor> pExecutor;
+};
+
+Tensor tcExecute(const TcFnHandle &tcFnHandle, const std::vector<Tensor> &inputs);
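+
+// Usage sketch (mirrors test/singa/test_operation_tc_fn.cc):
+//
+//   auto cuda = std::make_shared<singa::CudaGPU>();
+//   singa::Tensor x(singa::Shape{2, 2}, cuda);
+//   x.SetValue(1.0f);
+//   std::string tc = R"TC(
+//     def relu(float(B,M) I) -> (O1) {
+//         O1(b, m) = fmax(I(b, m), 0)
+//     }
+//   )TC";
+//   TcFnHandle relu(tc, "relu", {x});        // compile once
+//   singa::Tensor y = tcExecute(relu, {x});  // then run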
+
+}  // namespace singa
+
+#endif  // SINGA_MODEL_OPERATION_TC_FN_H_
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index a056631987..683e092175 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -33,7 +33,7 @@ LIST(REMOVE_ITEM singa_test_source "singa/test_ep.cc")
 ADD_EXECUTABLE(test_singa "gtest/gtest_main.cc" ${singa_test_source})
 ADD_DEPENDENCIES(test_singa singa)
 #MESSAGE(STATUS "link libs" ${singa_linker_libs})
-TARGET_LINK_LIBRARIES(test_singa gtest singa )
+TARGET_LINK_LIBRARIES(test_singa gtest singa ${SINGA_LINKER_LIBS})
 IF(UNIX AND (NOT APPLE))
     LIST(APPEND LINK_FLAGS "-pthread")
 ENDIF()
diff --git a/test/singa/test_operation_tc_fn.cc b/test/singa/test_operation_tc_fn.cc
new file mode 100644
index 0000000000..51c27c98dc
--- /dev/null
+++ b/test/singa/test_operation_tc_fn.cc
@@ -0,0 +1,221 @@
+/*********************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements. See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership. The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied. See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+************************************************************/
+
+#include "../src/model/operation/tc_fn.h"
+#include "gtest/gtest.h"
+#include <chrono>
+#include <iostream>
+
+using namespace singa;
+
+TEST(OperationTCFn, SoftmaxForward) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  singa::Tensor t1(singa::Shape{1, 2}, cuda);
+
+  const float dat1[2] = {1.0f, 2.0f};
+  t1.CopyDataFromHostPtr(dat1, 2);
+
+  std::string tc = R"TC(
+def softmax(float(N, D) I) -> (O, expsum, maxVal) {
+    maxVal(n) max=! I(n, d)
+    expsum(n) +=! exp(I(n, d) - maxVal(n))
+    O(n, d) = exp(I(n, d) - maxVal(n)) / expsum(n)
+}
+)TC";
+  TcFnHandle tfh(tc, "softmax", {t1});
+
+  std::chrono::steady_clock::time_point beginTC = std::chrono::steady_clock::now();
+  Tensor output = tcExecute(tfh, {t1});
+  std::chrono::steady_clock::time_point endTC = std::chrono::steady_clock::now();
+  std::cout << "\nTime "
+            << std::chrono::duration_cast<std::chrono::microseconds>(endTC - beginTC).count()
+            << " [microseconds]" << std::endl;
+
+  output.ToHost();
+
+  EXPECT_EQ(output.shape(0), 1);
+  EXPECT_EQ(output.shape(1), 2);
+  const float *dptr1 = output.data<float>();
+  EXPECT_NEAR(0.26894142f, dptr1[0], 1e-5);
+  EXPECT_NEAR(0.73105858f, dptr1[1], 1e-5);
+}
+
+TEST(OperationTCFn, SoftmaxBackward) {
+  const float x[] = {1.f, 2.f, 0.f, -2.f, -3.f, -1.f};
+  const float grad[] = {2.0f, -3.0f, 1.0f, 3.0f, -1.0f, -2.0f};
+
+  size_t n = sizeof(x) / sizeof(float);
+  size_t batch = 2, c = 3;
+
+  singa::Shape shape = {batch, c};
+  auto cuda = std::make_shared<singa::CudaGPU>();
+
+  singa::Tensor in(shape, cuda);
+  in.CopyDataFromHostPtr(x, n);
+  singa::Tensor output_grad(shape, cuda);
+  output_grad.CopyDataFromHostPtr(grad, n);
+
+  std::string tc_forward_def = R"TC(
+def softmax(float(N, D) I) -> (O, expsum, maxVal) {
+    maxVal(n) max=! I(n, d)
+    expsum(n) +=! exp(I(n, d) - maxVal(n))
+    O(n, d) = exp(I(n, d) - maxVal(n)) / expsum(n)
+}
+)TC";
+  TcFnHandle tfh_forward(tc_forward_def, "softmax", {in});
+  Tensor output = tcExecute(tfh_forward, {in});
+
+  std::string tc = R"TC(
+def softmax_bwd(float(N, D) output, float(N, D) grad_output) -> (grad_input, sigma) {
+    sigma(n) +=! output(n, d) * grad_output(n, d)
+    grad_input(n, d) = (grad_output(n, d) - sigma(n)) * output(n, d)
+}
+)TC";
+  TcFnHandle tfh(tc, "softmax_bwd", {output, output_grad});
+
+  std::chrono::steady_clock::time_point beginTC = std::chrono::steady_clock::now();
+  Tensor in_grad = tcExecute(tfh, {output, output_grad});
+  std::chrono::steady_clock::time_point endTC = std::chrono::steady_clock::now();
+  std::cout << "\nTime "
+            << std::chrono::duration_cast<std::chrono::microseconds>(endTC - beginTC).count()
+            << " [microseconds]" << std::endl;
+
+  in_grad.ToHost();
+  const float *xptr = in_grad.data<float>();
+
+  output.ToHost();
+  const float *yptr = output.data<float>();
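+
+  // Host-side reference: for softmax,
+  // dL/dx(i, j) = (dL/dy(i, j) - sigma(i)) * y(i, j),
+  // where sigma(i) = sum_j dL/dy(i, j) * y(i, j).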
+  float *dx = new float[n];
+  float *sigma = new float[batch];
+  for (size_t i = 0; i < batch; i++)
+    sigma[i] = 0.f;
+  for (size_t i = 0; i < n; i++)
+    sigma[i / c] += grad[i] * yptr[i];
+  for (size_t i = 0; i < batch; i++)
+    for (size_t j = 0; j < c; j++)
+      dx[i * c + j] = (grad[i * c + j] - sigma[i]) * yptr[i * c + j];
+  EXPECT_FLOAT_EQ(dx[0], xptr[0]);
+  EXPECT_FLOAT_EQ(dx[1], xptr[1]);
+  EXPECT_FLOAT_EQ(dx[2], xptr[2]);
+  EXPECT_FLOAT_EQ(dx[3], xptr[3]);
+  EXPECT_FLOAT_EQ(dx[4], xptr[4]);
+  EXPECT_FLOAT_EQ(dx[5], xptr[5]);
+  delete[] dx;
+  delete[] sigma;
+}
+
+TEST(OperationTCFn, ReLU) {
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  singa::Tensor t1(singa::Shape{2, 2}, cuda);
+
+  const float dat1[4] = {-1.0f, 1.0f, -2.0f, 3.0f};
+  t1.CopyDataFromHostPtr(dat1, 4);
+
+  std::string tc = R"TC(
+def relu(float(B,M) I) -> (O1) {
+    O1(b, m) = fmax(I(b, m), 0)
+}
+)TC";
+  TcFnHandle tfh(tc, "relu", {t1});
+
+  std::chrono::steady_clock::time_point beginTC = std::chrono::steady_clock::now();
+  Tensor o1 = tcExecute(tfh, {t1});
+  std::chrono::steady_clock::time_point endTC = std::chrono::steady_clock::now();
+  std::cout << "\nTime "
+            << std::chrono::duration_cast<std::chrono::microseconds>(endTC - beginTC).count()
+            << " [microseconds]" << std::endl;
+  o1.ToHost();
+
+  EXPECT_EQ(o1.shape(0), 2);
+  EXPECT_EQ(o1.shape(1), 2);
+  const float *dptr = o1.data<float>();
+  EXPECT_FLOAT_EQ(0.0f, dptr[0]);
+  EXPECT_FLOAT_EQ(1.0f, dptr[1]);
+  EXPECT_FLOAT_EQ(0.0f, dptr[2]);
+  EXPECT_FLOAT_EQ(3.0f, dptr[3]);
+}
+
+TEST(OperationTCFn, FC) {
+  std::string tc = R"TC(
+def fc(float(B,M) I, float(N,M) W1, float(N) B1) -> (O1) {
+    O1(b, n) +=! I(b, m) * W1(n, m)
+    O1(b, n) = O1(b, n) + B1(n)
+}
+)TC";
+
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  singa::Tensor x(singa::Shape{2, 3}, cuda);
+  singa::Tensor W(singa::Shape{4, 3}, cuda);
+  singa::Tensor b(singa::Shape{4}, cuda);
+  x.SetValue(1.1f);
+  W.SetValue(1.2f);
+  b.SetValue(1.3f);
+
+  TcFnHandle tfh(tc, "fc", {x, W, b});
+
+  std::chrono::steady_clock::time_point beginTC = std::chrono::steady_clock::now();
+  Tensor o1 = tcExecute(tfh, {x, W, b});
+  std::chrono::steady_clock::time_point endTC = std::chrono::steady_clock::now();
+  std::cout << "\nTime "
+            << std::chrono::duration_cast<std::chrono::microseconds>(endTC - beginTC).count()
+            << " [microseconds]" << std::endl;
+  o1.ToHost();
+
+  EXPECT_EQ(o1.shape(0), 2);
+  EXPECT_EQ(o1.shape(1), 4);
+  const float *dptr = o1.data<float>();
+  EXPECT_FLOAT_EQ(5.26f, dptr[0]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[1]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[2]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[3]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[4]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[5]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[6]);
+  EXPECT_FLOAT_EQ(5.26f, dptr[7]);
+}
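+
+// Expected values: every FC output element above is
+// 3 * (1.1 * 1.2) + 1.3 = 5.26, and every MatMul output element below is
+// 2 * (1.1 * 1.2) = 2.64.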
+TEST(OperationTCFn, MatMul) {
+  std::string tc = R"TC(
+def matmul(float(M,N) A, float(N,K) B) -> (output) {
+    output(i, j) +=! A(i, kk) * B(kk, j)
+}
+)TC";
+
+  auto cuda = std::make_shared<singa::CudaGPU>();
+  singa::Tensor t1(singa::Shape{2, 2}, cuda);
+  singa::Tensor t2(singa::Shape{2, 2}, cuda);
+  t1.SetValue(1.1f);
+  t2.SetValue(1.2f);
+
+  TcFnHandle tfh(tc, "matmul", {t1, t2});
+
+  std::chrono::steady_clock::time_point beginTC = std::chrono::steady_clock::now();
+  Tensor o1 = tcExecute(tfh, {t1, t2});
+  std::chrono::steady_clock::time_point endTC = std::chrono::steady_clock::now();
+  std::cout << "\nTime "
+            << std::chrono::duration_cast<std::chrono::microseconds>(endTC - beginTC).count()
+            << " [microseconds]" << std::endl;
+  o1.ToHost();
+
+  EXPECT_EQ(o1.shape(0), 2);
+  EXPECT_EQ(o1.shape(1), 2);
+  const float *dptr = o1.data<float>();
+  EXPECT_FLOAT_EQ(2.64f, dptr[0]);
+  EXPECT_FLOAT_EQ(2.64f, dptr[1]);
+  EXPECT_FLOAT_EQ(2.64f, dptr[2]);
+  EXPECT_FLOAT_EQ(2.64f, dptr[3]);
+}
diff --git a/test/singa/test_softmax.cc b/test/singa/test_softmax.cc
index 8064b80984..f20552d21b 100644
--- a/test/singa/test_softmax.cc
+++ b/test/singa/test_softmax.cc
@@ -95,6 +95,9 @@ TEST(Softmax, Backward) {
     for (size_t j = 0; j < col; j++)
       dx[i * col + j] = (grad[i * col + j] - sigma[i]) * yptr[i * col +j];
   EXPECT_FLOAT_EQ(dx[0], xptr[0]);
+  EXPECT_FLOAT_EQ(dx[1], xptr[1]);
+  EXPECT_FLOAT_EQ(dx[2], xptr[2]);
+  EXPECT_FLOAT_EQ(dx[3], xptr[3]);
   EXPECT_FLOAT_EQ(dx[4], xptr[4]);
   EXPECT_FLOAT_EQ(dx[5], xptr[5]);
   delete[] dx;
diff --git a/tool/docker/devel/ubuntu/cuda9/Dockerfile b/tool/docker/devel/ubuntu/cuda9/Dockerfile
index 98a0a88dfa..a5892cc910 100644
--- a/tool/docker/devel/ubuntu/cuda9/Dockerfile
+++ b/tool/docker/devel/ubuntu/cuda9/Dockerfile
@@ -58,14 +58,14 @@ RUN apt-get update \
 # install swig > 3.0.10
 RUN wget http://prdownloads.sourceforge.net/swig/swig-3.0.10.tar.gz -P /tmp/ \
     && tar zxf /tmp/swig-3.0.10.tar.gz -C /tmp/ \
-    && cd /tmp/swig-3.0.10 && ./configure && make && make install
+    && cd /tmp/swig-3.0.10 && ./configure && make -j4 && make install
 
 # install mkldnn
 RUN wget https://github.com/intel/mkl-dnn/archive/v0.18.tar.gz -P /tmp/ \
     && tar zxf /tmp/v0.18.tar.gz -C /tmp/ \
     && cd /tmp/mkl-dnn-0.18/ \
     && mkdir -p build && cd build && cmake .. \
-    && make && make install
+    && make -j4 && make install
 
 # config ssh service
 RUN mkdir /var/run/sshd \
@@ -77,6 +77,32 @@ RUN mkdir /var/run/sshd \
     # dump environment variables into files, so that ssh can see also
     && env | grep _ >> /etc/environment
 
+# start of TC
+RUN apt-get update && apt-get install -y unzip libgmp3-dev automake curl
+
+RUN wget https://repo.anaconda.com/archive/Anaconda3-5.1.0-Linux-x86_64.sh -O anaconda.sh && \
+    chmod +x anaconda.sh && \
+    ./anaconda.sh -b -p /root/conda && \
+    rm anaconda.sh
+
+RUN . /root/conda/bin/activate && \
+    conda update -y -n base conda && \
+    conda create -y --name tc_build python=3.6 && \
+    conda activate tc_build && \
+    conda install -y pyyaml mkl-include pytest && \
+    conda install -y -c nicolasvasilache llvm-trunk halide && \
+    conda install -y -c pytorch pytorch=0.4.0 torchvision cuda90 && \
+    conda remove -y cudatoolkit --force
+
+RUN cd /root/ && git clone https://github.com/dcslin/TensorComprehensions.git --recursive
+
+ENV CUDA_TOOLKIT_ROOT_DIR /usr/local/cuda
+ENV CLANG_PREFIX /root/conda/envs/tc_build
+
+RUN cd /root/TensorComprehensions && . /root/conda/bin/activate && conda activate tc_build && ./build.sh
+# end of TC
+
 # build incubator singa
 RUN git clone https://github.com/apache/incubator-singa.git $HOME/incubator-singa \
     && cd $HOME/incubator-singa \