From 797f6227b7dd2523808b503b76814329f9de9ab9 Mon Sep 17 00:00:00 2001
From: Gregwar
Date: Wed, 30 Mar 2022 18:08:12 +0200
Subject: [PATCH 01/14] Exporting model to C++ (wip)

---
 cpp/.gitignore                              |  10 +
 cpp/CMakeLists.txt                          |  36 ++
 cpp/assets/.gitkeep                         |   0
 cpp/cmake/CMakeRC.cmake                     | 644 ++++++++++++++++++++
 cpp/model_template.cpp                      |  22 +
 cpp/model_template.h                        |  14 +
 cpp/src/baselines3_models/predictor.cpp     |  39 ++
 cpp/src/baselines3_models/preprocessing.cpp |  24 +
 cpp/src/predict.cpp                         |  18 +
 enjoy.py                                    |  14 +-
 10 files changed, 820 insertions(+), 1 deletion(-)
 create mode 100644 cpp/.gitignore
 create mode 100644 cpp/CMakeLists.txt
 create mode 100644 cpp/assets/.gitkeep
 create mode 100644 cpp/cmake/CMakeRC.cmake
 create mode 100644 cpp/model_template.cpp
 create mode 100644 cpp/model_template.h
 create mode 100644 cpp/src/baselines3_models/predictor.cpp
 create mode 100644 cpp/src/baselines3_models/preprocessing.cpp
 create mode 100644 cpp/src/predict.cpp

diff --git a/cpp/.gitignore b/cpp/.gitignore
new file mode 100644
index 000000000..7b99f6aaa
--- /dev/null
+++ b/cpp/.gitignore
@@ -0,0 +1,10 @@
+assets/*
+!assets/.gitkeep
+
+include/baselines3_models/*
+!include/baselines3_models/predictor.h
+!include/baselines3_models/preprocessing.h
+
+src/baselines3_models/*
+!src/baselines3_models/predictor.cpp
+!src/baselines3_models/preprocessing.cpp
diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
new file mode 100644
index 000000000..17b6fd89f
--- /dev/null
+++ b/cpp/CMakeLists.txt
@@ -0,0 +1,36 @@
+cmake_minimum_required(VERSION 3.16.3)
+project(baselines3_models)
+include(cmake/CMakeRC.cmake)
+
+cmrc_add_resource_library(baselines3_model_resources
+  ALIAS baselines3_model::rc
+  NAMESPACE baselines3_model
+  assets/approach_v0_model.pt
+  )
+
+# Install PyTorch C++ first, see: https://pytorch.org/cppdocs/installing.html
+# Don't forget to add it to your CMAKE_PREFIX_PATH
+find_package(Torch REQUIRED)
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+
+# Enable C++17
+set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Wall -Wextra -fPIC")
+
+option(BASELINES3_BIN "Building bin" OFF)
+
+set(ALL_SOURCES
+  src/baselines3_models/predictor.cpp
+  src/baselines3_models/preprocessing.cpp
+  src/baselines3_models/approach_v0.cpp
+)
+
+add_library(baselines3_models SHARED ${ALL_SOURCES})
+target_link_libraries(baselines3_models "${TORCH_LIBRARIES}" baselines3_model::rc)
+target_include_directories(baselines3_models PUBLIC
+  $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
+)
+
+if (BASELINES3_BIN)
+  add_executable(predict ${CMAKE_CURRENT_SOURCE_DIR}/src/predict.cpp)
+  target_link_libraries(predict baselines3_models)
+endif()
\ No newline at end of file
diff --git a/cpp/assets/.gitkeep b/cpp/assets/.gitkeep
new file mode 100644
index 000000000..e69de29bb
diff --git a/cpp/cmake/CMakeRC.cmake b/cpp/cmake/CMakeRC.cmake
new file mode 100644
index 000000000..1a034b5f1
--- /dev/null
+++ b/cpp/cmake/CMakeRC.cmake
@@ -0,0 +1,644 @@
+# This block is executed when generating an intermediate resource file, not when
+# running in CMake configure mode
+if(_CMRC_GENERATE_MODE)
+    # Read in the digits
+    file(READ "${INPUT_FILE}" bytes HEX)
+    # Format each pair into a character literal. 
Heuristics seem to favor doing + # the conversion in groups of five for fastest conversion + string(REGEX REPLACE "(..)(..)(..)(..)(..)" "'\\\\x\\1','\\\\x\\2','\\\\x\\3','\\\\x\\4','\\\\x\\5'," chars "${bytes}") + # Since we did this in groups, we have some leftovers to clean up + string(LENGTH "${bytes}" n_bytes2) + math(EXPR n_bytes "${n_bytes2} / 2") + math(EXPR remainder "${n_bytes} % 5") # <-- '5' is the grouping count from above + set(cleanup_re "$") + set(cleanup_sub ) + while(remainder) + set(cleanup_re "(..)${cleanup_re}") + set(cleanup_sub "'\\\\x\\${remainder}',${cleanup_sub}") + math(EXPR remainder "${remainder} - 1") + endwhile() + if(NOT cleanup_re STREQUAL "$") + string(REGEX REPLACE "${cleanup_re}" "${cleanup_sub}" chars "${chars}") + endif() + string(CONFIGURE [[ + namespace { const char file_array[] = { @chars@ 0 }; } + namespace cmrc { namespace @NAMESPACE@ { namespace res_chars { + extern const char* const @SYMBOL@_begin = file_array; + extern const char* const @SYMBOL@_end = file_array + @n_bytes@; + }}} + ]] code) + file(WRITE "${OUTPUT_FILE}" "${code}") + # Exit from the script. Nothing else needs to be processed + return() +endif() + +set(_version 2.0.0) + +cmake_minimum_required(VERSION 3.3) +include(CMakeParseArguments) + +if(COMMAND cmrc_add_resource_library) + if(NOT DEFINED _CMRC_VERSION OR NOT (_version STREQUAL _CMRC_VERSION)) + message(WARNING "More than one CMakeRC version has been included in this project.") + endif() + # CMakeRC has already been included! Don't do anything + return() +endif() + +set(_CMRC_VERSION "${_version}" CACHE INTERNAL "CMakeRC version. Used for checking for conflicts") + +set(_CMRC_SCRIPT "${CMAKE_CURRENT_LIST_FILE}" CACHE INTERNAL "Path to CMakeRC script") + +function(_cmrc_normalize_path var) + set(path "${${var}}") + file(TO_CMAKE_PATH "${path}" path) + while(path MATCHES "//") + string(REPLACE "//" "/" path "${path}") + endwhile() + string(REGEX REPLACE "/+$" "" path "${path}") + set("${var}" "${path}" PARENT_SCOPE) +endfunction() + +get_filename_component(_inc_dir "${CMAKE_BINARY_DIR}/_cmrc/include" ABSOLUTE) +set(CMRC_INCLUDE_DIR "${_inc_dir}" CACHE INTERNAL "Directory for CMakeRC include files") +# Let's generate the primary include file +file(MAKE_DIRECTORY "${CMRC_INCLUDE_DIR}/cmrc") +set(hpp_content [==[ +#ifndef CMRC_CMRC_HPP_INCLUDED +#define CMRC_CMRC_HPP_INCLUDED + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if !(defined(__EXCEPTIONS) || defined(__cpp_exceptions) || defined(_CPPUNWIND) || defined(CMRC_NO_EXCEPTIONS)) +#define CMRC_NO_EXCEPTIONS 1 +#endif + +namespace cmrc { namespace detail { struct dummy; } } + +#define CMRC_DECLARE(libid) \ + namespace cmrc { namespace detail { \ + struct dummy; \ + static_assert(std::is_same::value, "CMRC_DECLARE() must only appear at the global namespace"); \ + } } \ + namespace cmrc { namespace libid { \ + cmrc::embedded_filesystem get_filesystem(); \ + } } static_assert(true, "") + +namespace cmrc { + +class file { + const char* _begin = nullptr; + const char* _end = nullptr; + +public: + using iterator = const char*; + using const_iterator = iterator; + iterator begin() const noexcept { return _begin; } + iterator cbegin() const noexcept { return _begin; } + iterator end() const noexcept { return _end; } + iterator cend() const noexcept { return _end; } + std::size_t size() const { return static_cast(std::distance(begin(), end())); } + + file() = default; + file(iterator beg, iterator end) noexcept : _begin(beg), _end(end) {} 
+}; + +class directory_entry; + +namespace detail { + +class directory; +class file_data; + +class file_or_directory { + union _data_t { + class file_data* file_data; + class directory* directory; + } _data; + bool _is_file = true; + +public: + explicit file_or_directory(file_data& f) { + _data.file_data = &f; + } + explicit file_or_directory(directory& d) { + _data.directory = &d; + _is_file = false; + } + bool is_file() const noexcept { + return _is_file; + } + bool is_directory() const noexcept { + return !is_file(); + } + const directory& as_directory() const noexcept { + assert(!is_file()); + return *_data.directory; + } + const file_data& as_file() const noexcept { + assert(is_file()); + return *_data.file_data; + } +}; + +class file_data { +public: + const char* begin_ptr; + const char* end_ptr; + file_data(const file_data&) = delete; + file_data(const char* b, const char* e) : begin_ptr(b), end_ptr(e) {} +}; + +inline std::pair split_path(const std::string& path) { + auto first_sep = path.find("/"); + if (first_sep == path.npos) { + return std::make_pair(path, ""); + } else { + return std::make_pair(path.substr(0, first_sep), path.substr(first_sep + 1)); + } +} + +struct created_subdirectory { + class directory& directory; + class file_or_directory& index_entry; +}; + +class directory { + std::list _files; + std::list _dirs; + std::map _index; + + using base_iterator = std::map::const_iterator; + +public: + + directory() = default; + directory(const directory&) = delete; + + created_subdirectory add_subdir(std::string name) & { + _dirs.emplace_back(); + auto& back = _dirs.back(); + auto& fod = _index.emplace(name, file_or_directory{back}).first->second; + return created_subdirectory{back, fod}; + } + + file_or_directory* add_file(std::string name, const char* begin, const char* end) & { + assert(_index.find(name) == _index.end()); + _files.emplace_back(begin, end); + return &_index.emplace(name, file_or_directory{_files.back()}).first->second; + } + + const file_or_directory* get(const std::string& path) const { + auto pair = split_path(path); + auto child = _index.find(pair.first); + if (child == _index.end()) { + return nullptr; + } + auto& entry = child->second; + if (pair.second.empty()) { + // We're at the end of the path + return &entry; + } + + if (entry.is_file()) { + // We can't traverse into a file. Stop. 
+ return nullptr; + } + // Keep going down + return entry.as_directory().get(pair.second); + } + + class iterator { + base_iterator _base_iter; + base_iterator _end_iter; + public: + using value_type = directory_entry; + using difference_type = std::ptrdiff_t; + using pointer = const value_type*; + using reference = const value_type&; + using iterator_category = std::input_iterator_tag; + + iterator() = default; + explicit iterator(base_iterator iter, base_iterator end) : _base_iter(iter), _end_iter(end) {} + + iterator begin() const noexcept { + return *this; + } + + iterator end() const noexcept { + return iterator(_end_iter, _end_iter); + } + + inline value_type operator*() const noexcept; + + bool operator==(const iterator& rhs) const noexcept { + return _base_iter == rhs._base_iter; + } + + bool operator!=(const iterator& rhs) const noexcept { + return !(*this == rhs); + } + + iterator operator++() noexcept { + auto cp = *this; + ++_base_iter; + return cp; + } + + iterator& operator++(int) noexcept { + ++_base_iter; + return *this; + } + }; + + using const_iterator = iterator; + + iterator begin() const noexcept { + return iterator(_index.begin(), _index.end()); + } + + iterator end() const noexcept { + return iterator(); + } +}; + +inline std::string normalize_path(std::string path) { + while (path.find("/") == 0) { + path.erase(path.begin()); + } + while (!path.empty() && (path.rfind("/") == path.size() - 1)) { + path.pop_back(); + } + auto off = path.npos; + while ((off = path.find("//")) != path.npos) { + path.erase(path.begin() + static_cast(off)); + } + return path; +} + +using index_type = std::map; + +} // detail + +class directory_entry { + std::string _fname; + const detail::file_or_directory* _item; + +public: + directory_entry() = delete; + explicit directory_entry(std::string filename, const detail::file_or_directory& item) + : _fname(filename) + , _item(&item) + {} + + const std::string& filename() const & { + return _fname; + } + std::string filename() const && { + return std::move(_fname); + } + + bool is_file() const { + return _item->is_file(); + } + + bool is_directory() const { + return _item->is_directory(); + } +}; + +directory_entry detail::directory::iterator::operator*() const noexcept { + assert(begin() != end()); + return directory_entry(_base_iter->first, _base_iter->second); +} + +using directory_iterator = detail::directory::iterator; + +class embedded_filesystem { + // Never-null: + const cmrc::detail::index_type* _index; + const detail::file_or_directory* _get(std::string path) const { + path = detail::normalize_path(path); + auto found = _index->find(path); + if (found == _index->end()) { + return nullptr; + } else { + return found->second; + } + } + +public: + explicit embedded_filesystem(const detail::index_type& index) + : _index(&index) + {} + + file open(const std::string& path) const { + auto entry_ptr = _get(path); + if (!entry_ptr || !entry_ptr->is_file()) { +#ifdef CMRC_NO_EXCEPTIONS + fprintf(stderr, "Error no such file or directory: %s\n", path.c_str()); + abort(); +#else + throw std::system_error(make_error_code(std::errc::no_such_file_or_directory), path); +#endif + } + auto& dat = entry_ptr->as_file(); + return file{dat.begin_ptr, dat.end_ptr}; + } + + bool is_file(const std::string& path) const noexcept { + auto entry_ptr = _get(path); + return entry_ptr && entry_ptr->is_file(); + } + + bool is_directory(const std::string& path) const noexcept { + auto entry_ptr = _get(path); + return entry_ptr && entry_ptr->is_directory(); + } + + bool 
exists(const std::string& path) const noexcept { + return !!_get(path); + } + + directory_iterator iterate_directory(const std::string& path) const { + auto entry_ptr = _get(path); + if (!entry_ptr) { +#ifdef CMRC_NO_EXCEPTIONS + fprintf(stderr, "Error no such file or directory: %s\n", path.c_str()); + abort(); +#else + throw std::system_error(make_error_code(std::errc::no_such_file_or_directory), path); +#endif + } + if (!entry_ptr->is_directory()) { +#ifdef CMRC_NO_EXCEPTIONS + fprintf(stderr, "Error not a directory: %s\n", path.c_str()); + abort(); +#else + throw std::system_error(make_error_code(std::errc::not_a_directory), path); +#endif + } + return entry_ptr->as_directory().begin(); + } +}; + +} + +#endif // CMRC_CMRC_HPP_INCLUDED +]==]) + +set(cmrc_hpp "${CMRC_INCLUDE_DIR}/cmrc/cmrc.hpp" CACHE INTERNAL "") +set(_generate 1) +if(EXISTS "${cmrc_hpp}") + file(READ "${cmrc_hpp}" _current) + if(_current STREQUAL hpp_content) + set(_generate 0) + endif() +endif() +file(GENERATE OUTPUT "${cmrc_hpp}" CONTENT "${hpp_content}" CONDITION ${_generate}) + +add_library(cmrc-base INTERFACE) +target_include_directories(cmrc-base INTERFACE $) +# Signal a basic C++11 feature to require C++11. +target_compile_features(cmrc-base INTERFACE cxx_nullptr) +set_property(TARGET cmrc-base PROPERTY INTERFACE_CXX_EXTENSIONS OFF) +add_library(cmrc::base ALIAS cmrc-base) + +function(cmrc_add_resource_library name) + set(args ALIAS NAMESPACE TYPE) + cmake_parse_arguments(ARG "" "${args}" "" "${ARGN}") + # Generate the identifier for the resource library's namespace + set(ns_re "[a-zA-Z_][a-zA-Z0-9_]*") + if(NOT DEFINED ARG_NAMESPACE) + # Check that the library name is also a valid namespace + if(NOT name MATCHES "${ns_re}") + message(SEND_ERROR "Library name is not a valid namespace. Specify the NAMESPACE argument") + endif() + set(ARG_NAMESPACE "${name}") + else() + if(NOT ARG_NAMESPACE MATCHES "${ns_re}") + message(SEND_ERROR "NAMESPACE for ${name} is not a valid C++ namespace identifier (${ARG_NAMESPACE})") + endif() + endif() + set(libname "${name}") + # Check that type is either "STATIC" or "OBJECT", or default to "STATIC" if + # not set + if(NOT DEFINED ARG_TYPE) + set(ARG_TYPE STATIC) + elseif(NOT "${ARG_TYPE}" MATCHES "^(STATIC|OBJECT)$") + message(SEND_ERROR "${ARG_TYPE} is not a valid TYPE (STATIC and OBJECT are acceptable)") + set(ARG_TYPE STATIC) + endif() + # Generate a library with the compiled in character arrays. 
+ string(CONFIGURE [=[ + #include + #include + #include + + namespace cmrc { + namespace @ARG_NAMESPACE@ { + + namespace res_chars { + // These are the files which are available in this resource library + $, + > + } + + namespace { + + const cmrc::detail::index_type& + get_root_index() { + static cmrc::detail::directory root_directory_; + static cmrc::detail::file_or_directory root_directory_fod{root_directory_}; + static cmrc::detail::index_type root_index; + root_index.emplace("", &root_directory_fod); + struct dir_inl { + class cmrc::detail::directory& directory; + }; + dir_inl root_directory_dir{root_directory_}; + (void)root_directory_dir; + $, + > + $, + > + return root_index; + } + + } + + cmrc::embedded_filesystem get_filesystem() { + static auto& index = get_root_index(); + return cmrc::embedded_filesystem{index}; + } + + } // @ARG_NAMESPACE@ + } // cmrc + ]=] cpp_content @ONLY) + get_filename_component(libdir "${CMAKE_CURRENT_BINARY_DIR}/__cmrc_${name}" ABSOLUTE) + get_filename_component(lib_tmp_cpp "${libdir}/lib_.cpp" ABSOLUTE) + string(REPLACE "\n " "\n" cpp_content "${cpp_content}") + file(GENERATE OUTPUT "${lib_tmp_cpp}" CONTENT "${cpp_content}") + get_filename_component(libcpp "${libdir}/lib.cpp" ABSOLUTE) + add_custom_command(OUTPUT "${libcpp}" + DEPENDS "${lib_tmp_cpp}" "${cmrc_hpp}" + COMMAND ${CMAKE_COMMAND} -E copy_if_different "${lib_tmp_cpp}" "${libcpp}" + COMMENT "Generating ${name} resource loader" + ) + # Generate the actual static library. Each source file is just a single file + # with a character array compiled in containing the contents of the + # corresponding resource file. + add_library(${name} ${ARG_TYPE} ${libcpp}) + set_property(TARGET ${name} PROPERTY CMRC_LIBDIR "${libdir}") + set_property(TARGET ${name} PROPERTY CMRC_NAMESPACE "${ARG_NAMESPACE}") + target_link_libraries(${name} PUBLIC cmrc::base) + set_property(TARGET ${name} PROPERTY CMRC_IS_RESOURCE_LIBRARY TRUE) + if(ARG_ALIAS) + add_library("${ARG_ALIAS}" ALIAS ${name}) + endif() + cmrc_add_resources(${name} ${ARG_UNPARSED_ARGUMENTS}) +endfunction() + +function(_cmrc_register_dirs name dirpath) + if(dirpath STREQUAL "") + return() + endif() + # Skip this dir if we have already registered it + get_target_property(registered "${name}" _CMRC_REGISTERED_DIRS) + if(dirpath IN_LIST registered) + return() + endif() + # Register the parent directory first + get_filename_component(parent "${dirpath}" DIRECTORY) + if(NOT parent STREQUAL "") + _cmrc_register_dirs("${name}" "${parent}") + endif() + # Now generate the registration + set_property(TARGET "${name}" APPEND PROPERTY _CMRC_REGISTERED_DIRS "${dirpath}") + _cm_encode_fpath(sym "${dirpath}") + if(parent STREQUAL "") + set(parent_sym root_directory) + else() + _cm_encode_fpath(parent_sym "${parent}") + endif() + get_filename_component(leaf "${dirpath}" NAME) + set_property( + TARGET "${name}" + APPEND PROPERTY CMRC_MAKE_DIRS + "static auto ${sym}_dir = ${parent_sym}_dir.directory.add_subdir(\"${leaf}\")\;" + "root_index.emplace(\"${dirpath}\", &${sym}_dir.index_entry)\;" + ) +endfunction() + +function(cmrc_add_resources name) + get_target_property(is_reslib ${name} CMRC_IS_RESOURCE_LIBRARY) + if(NOT TARGET ${name} OR NOT is_reslib) + message(SEND_ERROR "cmrc_add_resources called on target '${name}' which is not an existing resource library") + return() + endif() + + set(options) + set(args WHENCE PREFIX) + set(list_args) + cmake_parse_arguments(ARG "${options}" "${args}" "${list_args}" "${ARGN}") + + if(NOT ARG_WHENCE) + set(ARG_WHENCE 
${CMAKE_CURRENT_SOURCE_DIR}) + endif() + _cmrc_normalize_path(ARG_WHENCE) + get_filename_component(ARG_WHENCE "${ARG_WHENCE}" ABSOLUTE) + + # Generate the identifier for the resource library's namespace + get_target_property(lib_ns "${name}" CMRC_NAMESPACE) + + get_target_property(libdir ${name} CMRC_LIBDIR) + get_target_property(target_dir ${name} SOURCE_DIR) + file(RELATIVE_PATH reldir "${target_dir}" "${CMAKE_CURRENT_SOURCE_DIR}") + if(reldir MATCHES "^\\.\\.") + message(SEND_ERROR "Cannot call cmrc_add_resources in a parent directory from the resource library target") + return() + endif() + + foreach(input IN LISTS ARG_UNPARSED_ARGUMENTS) + _cmrc_normalize_path(input) + get_filename_component(abs_in "${input}" ABSOLUTE) + # Generate a filename based on the input filename that we can put in + # the intermediate directory. + file(RELATIVE_PATH relpath "${ARG_WHENCE}" "${abs_in}") + if(relpath MATCHES "^\\.\\.") + # For now we just error on files that exist outside of the soure dir. + message(SEND_ERROR "Cannot add file '${input}': File must be in a subdirectory of ${ARG_WHENCE}") + continue() + endif() + if(DEFINED ARG_PREFIX) + _cmrc_normalize_path(ARG_PREFIX) + endif() + if(ARG_PREFIX AND NOT ARG_PREFIX MATCHES "/$") + set(ARG_PREFIX "${ARG_PREFIX}/") + endif() + get_filename_component(dirpath "${ARG_PREFIX}${relpath}" DIRECTORY) + _cmrc_register_dirs("${name}" "${dirpath}") + get_filename_component(abs_out "${libdir}/intermediate/${relpath}.cpp" ABSOLUTE) + # Generate a symbol name relpath the file's character array + _cm_encode_fpath(sym "${relpath}") + # Get the symbol name for the parent directory + if(dirpath STREQUAL "") + set(parent_sym root_directory) + else() + _cm_encode_fpath(parent_sym "${dirpath}") + endif() + # Generate the rule for the intermediate source file + _cmrc_generate_intermediate_cpp(${lib_ns} ${sym} "${abs_out}" "${abs_in}") + target_sources(${name} PRIVATE "${abs_out}") + set_property(TARGET ${name} APPEND PROPERTY CMRC_EXTERN_DECLS + "// Pointers to ${input}" + "extern const char* const ${sym}_begin\;" + "extern const char* const ${sym}_end\;" + ) + get_filename_component(leaf "${relpath}" NAME) + set_property( + TARGET ${name} + APPEND PROPERTY CMRC_MAKE_FILES + "root_index.emplace(" + " \"${ARG_PREFIX}${relpath}\"," + " ${parent_sym}_dir.directory.add_file(" + " \"${leaf}\"," + " res_chars::${sym}_begin," + " res_chars::${sym}_end" + " )" + ")\;" + ) + endforeach() +endfunction() + +function(_cmrc_generate_intermediate_cpp lib_ns symbol outfile infile) + add_custom_command( + # This is the file we will generate + OUTPUT "${outfile}" + # These are the primary files that affect the output + DEPENDS "${infile}" "${_CMRC_SCRIPT}" + COMMAND + "${CMAKE_COMMAND}" + -D_CMRC_GENERATE_MODE=TRUE + -DNAMESPACE=${lib_ns} + -DSYMBOL=${symbol} + "-DINPUT_FILE=${infile}" + "-DOUTPUT_FILE=${outfile}" + -P "${_CMRC_SCRIPT}" + COMMENT "Generating intermediate file for ${infile}" + ) +endfunction() + +function(_cm_encode_fpath var fpath) + string(MAKE_C_IDENTIFIER "${fpath}" ident) + string(MD5 hash "${fpath}") + string(SUBSTRING "${hash}" 0 4 hash) + set(${var} f_${hash}_${ident} PARENT_SCOPE) +endfunction() diff --git a/cpp/model_template.cpp b/cpp/model_template.cpp new file mode 100644 index 000000000..42855a62f --- /dev/null +++ b/cpp/model_template.cpp @@ -0,0 +1,22 @@ +#include "baselines3_models/FILE_NAME.h" +#include "baselines3_models/preprocessing.h" + +namespace baselines3_models { + +CLASS_NAME::CLASS_NAME() : Predictor("MODEL_FNAME") { + policy_type = 
POLICY_TYPE;
+}
+
+torch::Tensor CLASS_NAME::preprocess_observation(torch::Tensor &observation) {
+  torch::Tensor result;
+  PREPROCESS_OBSERVATION
+  return result;
+}
+
+torch::Tensor CLASS_NAME::process_action(torch::Tensor &action) {
+  torch::Tensor result;
+  PROCESS_ACTION
+  return result;
+}
+
+} // namespace baselines3_models
\ No newline at end of file
diff --git a/cpp/model_template.h b/cpp/model_template.h
new file mode 100644
index 000000000..d7d8f884a
--- /dev/null
+++ b/cpp/model_template.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "baselines3_models/predictor.h"
+#include "torch/script.h"
+
+namespace baselines3_models {
+class CLASS_NAME : public Predictor {
+public:
+  CLASS_NAME();
+
+  torch::Tensor preprocess_observation(torch::Tensor &observation) override;
+  torch::Tensor process_action(torch::Tensor &action) override;
+};
+} // namespace baselines3_models
\ No newline at end of file
diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp
new file mode 100644
index 000000000..2c1c91bd8
--- /dev/null
+++ b/cpp/src/baselines3_models/predictor.cpp
@@ -0,0 +1,39 @@
+#include "baselines3_models/predictor.h"
+#include "cmrc/cmrc.hpp"
+
+CMRC_DECLARE(baselines3_model);
+
+namespace baselines3_models {
+Predictor::Predictor(std::string model_filename) {
+  auto fs = cmrc::baselines3_model::get_filesystem();
+  auto f = fs.open(model_filename);
+  std::string data(f.begin(), f.end());
+  std::istringstream stream(data);
+  module = torch::jit::load(stream);
+}
+
+torch::Tensor Predictor::predict(torch::Tensor &observation) {
+  c10::InferenceMode guard;
+  torch::Tensor processed_observation = preprocess_observation(observation);
+
+  if (policy_type == ACTOR_DETERMINISTIC) {
+    std::vector<torch::jit::IValue> inputs;
+    inputs.push_back(processed_observation);
+
+    at::Tensor action = module.forward(inputs).toTensor();
+  } else {
+    throw std::runtime_error("Unknown policy type");
+  }
+
+  return process_action(action);
+}
+
+torch::Tensor Predictor::preprocess_observation(torch::Tensor &observation) {
+  return observation;
+}
+
+torch::Tensor Predictor::process_action(torch::Tensor &action) {
+  return action;
+}
+
+} // namespace baselines3_models
\ No newline at end of file
diff --git a/cpp/src/baselines3_models/preprocessing.cpp b/cpp/src/baselines3_models/preprocessing.cpp
new file mode 100644
index 000000000..e0707e450
--- /dev/null
+++ b/cpp/src/baselines3_models/preprocessing.cpp
@@ -0,0 +1,24 @@
+#include "baselines3_models/preprocessing.h"
+
+using namespace torch::indexing;
+
+namespace baselines3_models {
+
+torch::Tensor multi_one_hot(torch::Tensor &input, torch::Tensor &classes) {
+  int entries = torch::sum(classes).item<int>();
+
+  torch::Tensor result =
+      torch::zeros({1, entries}, torch::TensorOptions().dtype(torch::kLong));
+
+  int offset = 0;
+  for (int k = 0; k < classes.sizes()[0]; k++) {
+    int n = classes[k].item<int>();
+
+    result.index({0, Slice(offset, offset + n)}) = torch::one_hot(input[k], n);
+    offset += n;
+  }
+
+  return result;
+}
+
+} // namespace baselines3_models
\ No newline at end of file
diff --git a/cpp/src/predict.cpp b/cpp/src/predict.cpp
new file mode 100644
index 000000000..326039a9e
--- /dev/null
+++ b/cpp/src/predict.cpp
@@ -0,0 +1,18 @@
+#include "baselines3_models/preprocessing.h"
+#include "baselines3_models/approach_v0.h"
+#include <torch/script.h>
+#include <torch/torch.h>
+#include <iostream>
+#include "cmrc/cmrc.hpp"
+
+using namespace baselines3_models;
+using namespace torch::indexing;
+
+int main(int argc, const char *argv[]) {
+  approach_v0 approach;
+
+  torch::Tensor observation = torch::tensor({-1., 0., 0., 1., 0., 1., 0., 0., 0.});
+  torch::Tensor action = approach.predict(observation);
+
+  std::cout << (action) << std::endl;
+}
\ No newline at end of file
diff --git a/enjoy.py b/enjoy.py
index c92304fb2..773aad4c4 100644
--- a/enjoy.py
+++ b/enjoy.py
@@ -12,6 +12,7 @@
 import utils.import_envs  # noqa: F401 pylint: disable=unused-import
 from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams
 from utils.exp_manager import ExperimentManager
+from utils.cpp_exporter import CppExporter
 from utils.utils import StoreDict
 
 
@@ -30,6 +31,7 @@ def main():  # noqa: C901
     )
     parser.add_argument("--deterministic", action="store_true", default=False, help="Use deterministic actions")
     parser.add_argument("--device", help="PyTorch device to be use (ex: cpu, cuda...)", default="auto", type=str)
+    parser.add_argument("--export-cpp", help="Export to C++ code", default="", type=str)
     parser.add_argument(
         "--load-best", action="store_true", default=False, help="Load best model instead of last model if available"
     )
@@ -70,6 +72,10 @@ def main():  # noqa: C901
     env_id = args.env
     algo = args.algo
     folder = args.folder
+    device = args.device
+
+    if args.export_cpp:
+        device = "cpu"
 
     if args.exp_id == 0:
         args.exp_id = get_latest_run_id(os.path.join(folder, algo), env_id)
@@ -176,10 +182,16 @@ def step_count(checkpoint_path: str) -> int:
         "clip_range": lambda _: 0.0,
     }
 
-    model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, device=args.device, **kwargs)
+    model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, device=device, **kwargs)
 
     obs = env.reset()
 
+    if args.export_cpp:
+        print("Exporting to C++...")
+        exporter = CppExporter(model, args.export_cpp, args.env)
+        exporter.export()
+        exit()
+
     # Deterministic by default except for atari games
     stochastic = args.stochastic or is_atari and not args.deterministic
     deterministic = not stochastic

From a309c43c93d7867aa5766c6f9b339cfc6f719bc6 Mon Sep 17 00:00:00 2001
From: Gregwar
Date: Wed, 30 Mar 2022 18:43:07 +0200
Subject: [PATCH 02/14] CPP Export (wip)

---
 cpp/include/baselines3_models/predictor.h     |  24 ++++
 cpp/include/baselines3_models/preprocessing.h |  10 ++
 cpp/src/baselines3_models/predictor.cpp       |   5 +-
 utils/cpp_exporter.py                         | 105 ++++++++++++++++++
 4 files changed, 142 insertions(+), 2 deletions(-)
 create mode 100644 cpp/include/baselines3_models/predictor.h
 create mode 100644 cpp/include/baselines3_models/preprocessing.h
 create mode 100644 utils/cpp_exporter.py

diff --git a/cpp/include/baselines3_models/predictor.h b/cpp/include/baselines3_models/predictor.h
new file mode 100644
index 000000000..a8667fcaf
--- /dev/null
+++ b/cpp/include/baselines3_models/predictor.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <torch/script.h>
+#include <torch/torch.h>
+
+namespace baselines3_models {
+class Predictor {
+public:
+  enum PolicyType {
+    ACTOR_MU
+  };
+
+  Predictor(std::string model_filename);
+
+  torch::Tensor predict(torch::Tensor &observation);
+
+  virtual torch::Tensor preprocess_observation(torch::Tensor &observation);
+  virtual torch::Tensor process_action(torch::Tensor &action);
+
+protected:
+  torch::jit::script::Module module;
+  PolicyType policy_type;
+};
+} // namespace baselines3_models
\ No newline at end of file
diff --git a/cpp/include/baselines3_models/preprocessing.h b/cpp/include/baselines3_models/preprocessing.h
new file mode 100644
index 000000000..4dfcc098a
--- /dev/null
+++ b/cpp/include/baselines3_models/preprocessing.h
@@ -0,0 +1,10 @@
+#pragma once
+
+#include <torch/script.h>
+#include <torch/torch.h>
+
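+// multi_one_hot builds the concatenated one-hot encoding used for
+// MultiDiscrete observation spaces: one block of size classes[k] per
+// input component (see the implementation in preprocessing.cpp).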
+namespace baselines3_models {
+
+torch::Tensor multi_one_hot(torch::Tensor &input, torch::Tensor &classes);
+
+}
\ No newline at end of file
diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp
index 2c1c91bd8..5c094d280 100644
--- a/cpp/src/baselines3_models/predictor.cpp
+++ b/cpp/src/baselines3_models/predictor.cpp
@@ -15,12 +15,13 @@ Predictor::Predictor(std::string model_filename) {
 torch::Tensor Predictor::predict(torch::Tensor &observation) {
   c10::InferenceMode guard;
   torch::Tensor processed_observation = preprocess_observation(observation);
+  at::Tensor action;
 
-  if (policy_type == ACTOR_DETERMINISTIC) {
+  if (policy_type == ACTOR_MU) {
     std::vector<torch::jit::IValue> inputs;
     inputs.push_back(processed_observation);
 
-    at::Tensor action = module.forward(inputs).toTensor();
+    action = module.forward(inputs).toTensor();
   } else {
     throw std::runtime_error("Unknown policy type");
   }
diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py
new file mode 100644
index 000000000..6b817dc66
--- /dev/null
+++ b/utils/cpp_exporter.py
@@ -0,0 +1,105 @@
+import torch as th
+from gym import spaces
+from stable_baselines3.common.base_class import BaseAlgorithm
+from stable_baselines3.common.preprocessing import is_image_space
+from stable_baselines3.common.policies import ActorCriticPolicy
+from stable_baselines3.td3.policies import TD3Policy
+from stable_baselines3.sac.policies import SACPolicy
+
+
+class CppExporter(object):
+    def __init__(self, model: BaseAlgorithm, directory: str, name: str):
+        self.model = model
+        self.directory = directory
+        self.name = name.replace("-", "_")
+        self.template_directory = dir = "/".join(__file__.split("/")[:-1] + ["..", "cpp"])
+        self.vars = {}
+
+    def generate_observation_preprocessing(self):
+        observation_space = self.model.env.observation_space
+        policy = self.model.policy
+        preprocess_observation = ""
+
+        if isinstance(observation_space, spaces.Box):
+            if is_image_space(observation_space) and policy.normalize_images:
+                preprocess_observation += "result = observation / 255.;\n"
+            else:
+                preprocess_observation += "result = observation;\n"
+        elif isinstance(observation_space, spaces.Discrete):
+            preprocess_observation += f"result = torch::one_hot(observation, {observation_space.n});\n"
+        elif isinstance(observation_space, spaces.MultiDiscrete):
+            classes = ",".join(map(str, observation_space.nvec))
+            preprocess_observation += "torch::Tensor classes = torch::tensor({%s});\n" % classes
+            preprocess_observation += f"result = multi_one_hot(observation, classes);\n"
+        else:
+            raise NotImplementedError(f"C++ exporting does not support observation {observation_space}")
+
+        self.vars["PREPROCESS_OBSERVATION"] = preprocess_observation
+
+    def generate_action_processing(self):
+        action_space = self.model.env.action_space
+
+        if isinstance(action_space, spaces.Box):
+            if self.model.policy.squash_output:
+                low_values = ",".join([f"(float){x}" for x in action_space.low])
+                process_action = "torch::Tensor action_low = torch::tensor({%s});\n" % low_values
+                high_values = ",".join([f"(float){x}" for x in action_space.high])
+                process_action += "torch::Tensor action_high = torch::tensor({%s});\n" % high_values
+
+                process_action += "result = action_low + (0.5 * (action + 1.0) * (action_high - action_low));\n"
+            else:
+                process_action = "result = action;\n"
+        elif isinstance(action_space, spaces.Box):
+            process_action = "result = action;\n"
+        else:
+            raise NotImplementedError(f"C++ exporting does not support processing action {action_space}")
+
+        
self.vars["PROCESS_ACTION"] = process_action + + def export_code(self): + self.vars["CLASS_NAME"] = self.name + fname = self.name.lower() + self.vars["FILE_NAME"] = fname + target_header = self.directory + f"/include/baselines3_models/{fname}.h" + target_cpp = self.directory + f"/src/baselines3_models/{fname}.cpp" + + self.generate_observation_preprocessing() + self.generate_action_processing() + + self.render("model_template.h", target_header) + self.render("model_template.cpp", target_cpp) + + def render(self, template: str, target: str): + with open(self.template_directory + "/" + template, "r") as template_f: + with open(target, "w") as target_f: + data = template_f.read() + for var in self.vars: + data = data.replace(var, self.vars[var]) + target_f.write(data) + + print("Generated " + target) + + def export_model(self): + policy = self.model.policy + obs = th.Tensor(self.model.env.reset()) + asset_fname = f"assets/{self.name}_model.pt" + fname = self.directory + "/" + asset_fname + traced_script_module = None + + if isinstance(policy, TD3Policy): + traced_script_module = th.jit.trace(policy.actor.mu, policy.actor.extract_features(obs)) + self.vars["POLICY_TYPE"] = "ACTOR_MU" + if isinstance(policy, ActorCriticPolicy) or isinstance(policy, SACPolicy): + model = th.nn.Sequential(policy.actor.latent_pi, policy.actor.mu) + traced_script_module = th.jit.trace(model, policy.actor.extract_features(obs)) + self.vars["POLICY_TYPE"] = "ACTOR_MU" + else: + raise NotImplementedError(f"C++ exporting does not support policy {policy}") + + print(f"Generated {fname}") + traced_script_module.save(fname) + self.vars["MODEL_FNAME"] = asset_fname + + def export(self): + self.export_model() + self.export_code() From 6b708abca6d0838aa36538242e48259956380214 Mon Sep 17 00:00:00 2001 From: Gregwar Date: Wed, 30 Mar 2022 19:18:44 +0200 Subject: [PATCH 03/14] CPP Export (wip) --- cpp/CMakeLists.txt | 8 ++++++- cpp/include/baselines3_models/predictor.h | 4 +++- cpp/model_template.h | 9 ++++++++ cpp/src/baselines3_models/predictor.cpp | 14 +++++++++++- cpp/src/predict.cpp | 16 +++++++++---- utils/cpp_exporter.py | 28 +++++++++++++++++------ 6 files changed, 64 insertions(+), 15 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 17b6fd89f..cafc2fd36 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -5,7 +5,10 @@ include(cmake/CMakeRC.cmake) cmrc_add_resource_library(baselines3_model_resources ALIAS baselines3_model::rc NAMESPACE baselines3_model + + # XXX: Generate assets/approach_v0_model.pt + assets/CartPole_v1_model.pt ) # Install PyTorch C++ first, see: https://pytorch.org/cppdocs/installing.html @@ -21,7 +24,10 @@ option(BASELINES3_BIN "Building bin" OFF) set(ALL_SOURCES src/baselines3_models/predictor.cpp src/baselines3_models/preprocessing.cpp - src/baselines3_models/approach_v0.cpp + + # XXX: Generate + # src/baselines3_models/approach_v0.cpp + src/baselines3_models/cartpole_v1.cpp ) add_library(baselines3_models SHARED ${ALL_SOURCES}) diff --git a/cpp/include/baselines3_models/predictor.h b/cpp/include/baselines3_models/predictor.h index a8667fcaf..d62d82fb0 100644 --- a/cpp/include/baselines3_models/predictor.h +++ b/cpp/include/baselines3_models/predictor.h @@ -7,7 +7,8 @@ namespace baselines3_models { class Predictor { public: enum PolicyType { - ACTOR_MU + ACTOR_MU, + QNET_SCAN }; Predictor(std::string model_filename); @@ -16,6 +17,7 @@ class Predictor { virtual torch::Tensor preprocess_observation(torch::Tensor &observation); virtual torch::Tensor 
process_action(torch::Tensor &action); + virtual std::vector enumerate_actions(); protected: torch::jit::script::Module module; diff --git a/cpp/model_template.h b/cpp/model_template.h index d7d8f884a..5601638ed 100644 --- a/cpp/model_template.h +++ b/cpp/model_template.h @@ -8,7 +8,16 @@ class CLASS_NAME : public Predictor { public: CLASS_NAME(); + /** + * Observation space is: + * OBSERVATION_SPACE + */ torch::Tensor preprocess_observation(torch::Tensor &observation) override; + + /** + * Action space is: + * ACTION_SPACE + */ torch::Tensor process_action(torch::Tensor &action) override; }; } // namespace baselines3_models \ No newline at end of file diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp index 5c094d280..b3b3b18d7 100644 --- a/cpp/src/baselines3_models/predictor.cpp +++ b/cpp/src/baselines3_models/predictor.cpp @@ -22,11 +22,18 @@ torch::Tensor Predictor::predict(torch::Tensor &observation) { inputs.push_back(processed_observation); action = module.forward(inputs).toTensor(); + action = process_action(action); + } else if (policy_type == QNET_SCAN) { + std::vector inputs; + inputs.push_back(processed_observation); + + auto q_values = module.forward(inputs).toTensor(); + action = torch::argmax(q_values); } else { throw std::runtime_error("Unknown policy type"); } - return process_action(action); + return action; } torch::Tensor Predictor::preprocess_observation(torch::Tensor &observation) { @@ -37,4 +44,9 @@ torch::Tensor Predictor::process_action(torch::Tensor &action) { return action; } +std::vector Predictor::enumerate_actions() { + std::vector result; + return result; +} + } // namespace baselines3_models \ No newline at end of file diff --git a/cpp/src/predict.cpp b/cpp/src/predict.cpp index 326039a9e..fe2d69155 100644 --- a/cpp/src/predict.cpp +++ b/cpp/src/predict.cpp @@ -1,5 +1,6 @@ #include "baselines3_models/preprocessing.h" #include "baselines3_models/approach_v0.h" +#include "baselines3_models/cartpole_v1.h" #include #include #include @@ -8,11 +9,16 @@ using namespace baselines3_models; using namespace torch::indexing; -int main(int argc, const char *argv[]) { - approach_v0 approach; +int main(int argc, const char *argv[]) { + CartPole_v1 cartpole; - torch::Tensor observation = torch::tensor({-1., 0., 0., 1., 0., 1., 0., 0., 0.}); - torch::Tensor action = approach.predict(observation); + torch::Tensor observation = torch::tensor({0., 0., 0., 0.}); + cartpole.predict(observation); - std::cout << (action) << std::endl; + // approach_v0 approach; + + // torch::Tensor observation = torch::tensor({-1., 0., 0., 1., 0., 1., 0., 0., 0.}); + // torch::Tensor action = approach.predict(observation); + + // std::cout << (action) << std::endl; } \ No newline at end of file diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index 6b817dc66..b1479938f 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -5,6 +5,7 @@ from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.dqn.policies import DQNPolicy class CppExporter(object): @@ -19,6 +20,7 @@ def generate_observation_preprocessing(self): observation_space = self.model.env.observation_space policy = self.model.policy preprocess_observation = "" + self.vars["OBSERVATION_SPACE"] = repr(observation_space) if isinstance(observation_space, spaces.Box): if is_image_space(observation_space) and policy.normalize_images: @@ -38,19 +40,23 
@@ def generate_observation_preprocessing(self): def generate_action_processing(self): action_space = self.model.env.action_space + process_action = "" + self.vars["ACTION_SPACE"] = repr(action_space) if isinstance(action_space, spaces.Box): if self.model.policy.squash_output: low_values = ",".join([f"(float){x}" for x in action_space.low]) - process_action = "torch::Tensor action_low = torch::tensor({%s});\n" % low_values + process_action += "torch::Tensor action_low = torch::tensor({%s});\n" % low_values high_values = ",".join([f"(float){x}" for x in action_space.high]) process_action += "torch::Tensor action_high = torch::tensor({%s});\n" % high_values process_action += "result = action_low + (0.5 * (action + 1.0) * (action_high - action_low));\n" else: - process_action = "result = action;\n" + process_action += "result = action;\n" + if isinstance(action_space, spaces.Discrete): + process_action += "result = action;\n" elif isinstance(action_space, spaces.Box): - process_action = "result = action;\n" + process_action += "result = action;\n" else: raise NotImplementedError(f"C++ exporting does not support processing action {action_space}") @@ -89,16 +95,24 @@ def export_model(self): if isinstance(policy, TD3Policy): traced_script_module = th.jit.trace(policy.actor.mu, policy.actor.extract_features(obs)) self.vars["POLICY_TYPE"] = "ACTOR_MU" - if isinstance(policy, ActorCriticPolicy) or isinstance(policy, SACPolicy): + elif isinstance(policy, SACPolicy): model = th.nn.Sequential(policy.actor.latent_pi, policy.actor.mu) traced_script_module = th.jit.trace(model, policy.actor.extract_features(obs)) self.vars["POLICY_TYPE"] = "ACTOR_MU" + elif isinstance(policy, ActorCriticPolicy): + model = th.nn.Sequential(policy.mlp_extractor.policy_net, policy.action_net) + traced_script_module = th.jit.trace(model, policy.extract_features(obs)) + self.vars["POLICY_TYPE"] = "ACTOR_MU" + elif isinstance(policy, DQNPolicy): + traced_script_module = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) + self.vars["POLICY_TYPE"] = "QNET_SCAN" else: raise NotImplementedError(f"C++ exporting does not support policy {policy}") - print(f"Generated {fname}") - traced_script_module.save(fname) - self.vars["MODEL_FNAME"] = asset_fname + if traced_script_module is not None: + print(f"Generated {fname}") + traced_script_module.save(fname) + self.vars["MODEL_FNAME"] = asset_fname def export(self): self.export_model() From 1b460839463367a982967066237ffa4b2dcb39f7 Mon Sep 17 00:00:00 2001 From: Gregwar Date: Fri, 1 Apr 2022 14:45:47 +0200 Subject: [PATCH 04/14] CPP Export (+pybind to test) --- cpp/CMakeLists.txt | 43 +++++--- cpp/include/baselines3_models/predictor.h | 6 +- cpp/model_template.cpp | 18 ++++ cpp/model_template.h | 4 + cpp/src/baselines3_models/predictor.cpp | 19 ++-- cpp/src/predict.cpp | 20 ++-- utils/cpp_exporter.py | 118 ++++++++++++++++++++-- 7 files changed, 185 insertions(+), 43 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index cafc2fd36..f10361c27 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -3,13 +3,12 @@ project(baselines3_models) include(cmake/CMakeRC.cmake) cmrc_add_resource_library(baselines3_model_resources - ALIAS baselines3_model::rc - NAMESPACE baselines3_model +ALIAS baselines3_model::rc +NAMESPACE baselines3_model - # XXX: Generate - assets/approach_v0_model.pt - assets/CartPole_v1_model.pt - ) +#static +#!static +) # Install PyTorch C++ first, see: https://pytorch.org/cppdocs/installing.html # Don't forget to add it to your 
CMAKE_PREFIX_PATH @@ -20,18 +19,22 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}") set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17 -Wall -Wextra -fPIC") option(BASELINES3_BIN "Building bin" OFF) +option(BASELINES3_PYBIND "Build python bindings (requires pybind11)" OFF) set(ALL_SOURCES - src/baselines3_models/predictor.cpp - src/baselines3_models/preprocessing.cpp +src/baselines3_models/predictor.cpp +src/baselines3_models/preprocessing.cpp + +#sources +#!sources +) - # XXX: Generate - # src/baselines3_models/approach_v0.cpp - src/baselines3_models/cartpole_v1.cpp +set(LIBRARIES +"${TORCH_LIBRARIES}" baselines3_model::rc ) add_library(baselines3_models SHARED ${ALL_SOURCES}) -target_link_libraries(baselines3_models "${TORCH_LIBRARIES}" baselines3_model::rc) +target_link_libraries(baselines3_models ${LIBRARIES}) target_include_directories(baselines3_models PUBLIC $ ) @@ -39,4 +42,20 @@ target_include_directories(baselines3_models PUBLIC if (BASELINES3_BIN) add_executable(predict ${CMAKE_CURRENT_SOURCE_DIR}/src/predict.cpp) target_link_libraries(predict baselines3_models) +endif() + +if (BASELINES3_PYBIND) + set (Python_EXECUTABLE "/usr/bin/python3.8") + # apt-get install python3-dev + find_package(Python COMPONENTS Interpreter Development) + # apt-get install python3-pybind11 + find_package(pybind11 REQUIRED) + + pybind11_add_module(baselines3_py ${ALL_SOURCES}) + target_link_libraries(baselines3_py PRIVATE ${LIBRARIES}) + target_compile_definitions(baselines3_py PUBLIC -DEXPORT_PYBIND) + + target_include_directories(baselines3_py PUBLIC + $ + ) endif() \ No newline at end of file diff --git a/cpp/include/baselines3_models/predictor.h b/cpp/include/baselines3_models/predictor.h index d62d82fb0..1e2e74326 100644 --- a/cpp/include/baselines3_models/predictor.h +++ b/cpp/include/baselines3_models/predictor.h @@ -7,14 +7,18 @@ namespace baselines3_models { class Predictor { public: enum PolicyType { + // The module is an actor's µ: directly outputs action from state ACTOR_MU, - QNET_SCAN + // The network is a Q-Network: outputs Q(s,a) for all a for a given s + QNET_ALL }; Predictor(std::string model_filename); torch::Tensor predict(torch::Tensor &observation); + std::vector predict_vector(std::vector obs); + virtual torch::Tensor preprocess_observation(torch::Tensor &observation); virtual torch::Tensor process_action(torch::Tensor &action); virtual std::vector enumerate_actions(); diff --git a/cpp/model_template.cpp b/cpp/model_template.cpp index 42855a62f..cec275eec 100644 --- a/cpp/model_template.cpp +++ b/cpp/model_template.cpp @@ -1,5 +1,13 @@ +/*** + * This file was AUTOGENERATED by Stable Baselines3 Zoo + * https://github.com/DLR-RM/rl-baselines3-zoo + */ #include "baselines3_models/FILE_NAME.h" #include "baselines3_models/preprocessing.h" +#ifdef EXPORT_PYBIND +#include +#include +#endif namespace baselines3_models { @@ -19,4 +27,14 @@ torch::Tensor CLASS_NAME::process_action(torch::Tensor &action) { return result; } +#ifdef EXPORT_PYBIND +namespace py = pybind11; + +PYBIND11_MODULE(baselines3_py, m) { + py::class_(m, "CLASS_NAME") + .def(py::init()) + .def("predict", &CLASS_NAME::predict_vector); +} +#endif + } // namespace baselines3_models \ No newline at end of file diff --git a/cpp/model_template.h b/cpp/model_template.h index 5601638ed..d4bec8d82 100644 --- a/cpp/model_template.h +++ b/cpp/model_template.h @@ -1,3 +1,7 @@ +/*** + * This file was AUTOGENERATED by Stable Baselines3 Zoo + * https://github.com/DLR-RM/rl-baselines3-zoo + */ #pragma once #include 
"baselines3_models/predictor.h" diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp index b3b3b18d7..d1c09ca5a 100644 --- a/cpp/src/baselines3_models/predictor.cpp +++ b/cpp/src/baselines3_models/predictor.cpp @@ -16,17 +16,13 @@ torch::Tensor Predictor::predict(torch::Tensor &observation) { c10::InferenceMode guard; torch::Tensor processed_observation = preprocess_observation(observation); at::Tensor action; + std::vector inputs; + inputs.push_back(processed_observation); if (policy_type == ACTOR_MU) { - std::vector inputs; - inputs.push_back(processed_observation); - action = module.forward(inputs).toTensor(); action = process_action(action); - } else if (policy_type == QNET_SCAN) { - std::vector inputs; - inputs.push_back(processed_observation); - + } else if (policy_type == QNET_ALL) { auto q_values = module.forward(inputs).toTensor(); action = torch::argmax(q_values); } else { @@ -36,6 +32,15 @@ torch::Tensor Predictor::predict(torch::Tensor &observation) { return action; } +std::vector Predictor::predict_vector(std::vector obs) { + torch::Tensor observation = torch::from_blob(obs.data(), obs.size()); + torch::Tensor action = predict(observation); + action = action.contiguous().to(torch::kFloat32); + std::vector result(action.data_ptr(), + action.data_ptr() + action.numel()); + return result; +} + torch::Tensor Predictor::preprocess_observation(torch::Tensor &observation) { return observation; } diff --git a/cpp/src/predict.cpp b/cpp/src/predict.cpp index fe2d69155..f854b79e9 100644 --- a/cpp/src/predict.cpp +++ b/cpp/src/predict.cpp @@ -1,24 +1,16 @@ -#include "baselines3_models/preprocessing.h" -#include "baselines3_models/approach_v0.h" +// This file is just a demonstration, you can adapt to test your model +// First, include your model: #include "baselines3_models/cartpole_v1.h" -#include -#include -#include -#include "cmrc/cmrc.hpp" using namespace baselines3_models; -using namespace torch::indexing; int main(int argc, const char *argv[]) { + // Create an instance of it: CartPole_v1 cartpole; + // Build an observation: torch::Tensor observation = torch::tensor({0., 0., 0., 0.}); - cartpole.predict(observation); - // approach_v0 approach; - - // torch::Tensor observation = torch::tensor({-1., 0., 0., 1., 0., 1., 0., 0., 0.}); - // torch::Tensor action = approach.predict(observation); - - // std::cout << (action) << std::endl; + // You can now check the prediction: + std::cout << cartpole.predict(observation) << std::endl; } \ No newline at end of file diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index b1479938f..42fca20a8 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -1,4 +1,7 @@ import torch as th +import re +import shutil +import os from gym import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.preprocessing import is_image_space @@ -10,13 +13,45 @@ class CppExporter(object): def __init__(self, model: BaseAlgorithm, directory: str, name: str): + """ + C++ module exporter + + :param BaseAlgorithm model: The algorithm that should be exported + :param str directory: Output directory + :param str name: The module name + """ self.model = model self.directory = directory self.name = name.replace("-", "_") - self.template_directory = dir = "/".join(__file__.split("/")[:-1] + ["..", "cpp"]) + + # Templates directory is found relatively to this script (../cpp) + self.template_directory = "/".join(__file__.split("/")[:-1] + ["..", "cpp"]) self.vars = {} + # Name 
of the asset (.pt file)
+        self.asset_fname = None
+        self.cpp_fname = None
+
+    def generate_directory(self):
+        """
+        Generates the target directory if it doesn't exist
+        """
+
+        def ignore(directory, files):
+            if directory == self.template_directory:
+                return ['.gitignore', 'model_template.h', 'model_template.cpp']
+
+            return []
+
+        if not os.path.isdir(self.directory):
+            shutil.copytree(self.template_directory, self.directory, ignore=ignore)
+
     def generate_observation_preprocessing(self):
+        """
+        Generates observation preprocessing code for this model
+
+        :raises NotImplementedError: If the observation space is not supported
+        """
         observation_space = self.model.env.observation_space
         policy = self.model.policy
         preprocess_observation = ""
@@ -24,12 +59,16 @@ def generate_observation_preprocessing(self):
 
         if isinstance(observation_space, spaces.Box):
             if is_image_space(observation_space) and policy.normalize_images:
+                # Normalizing image pixels
                 preprocess_observation += "result = observation / 255.;\n"
             else:
+                # Keeping the observation as it is
                 preprocess_observation += "result = observation;\n"
         elif isinstance(observation_space, spaces.Discrete):
+            # Applying a one-hot representation
            preprocess_observation += f"result = torch::one_hot(observation, {observation_space.n});\n"
         elif isinstance(observation_space, spaces.MultiDiscrete):
+            # Applying multiple one-hot representations (using the C++ multi_one_hot helper)
             classes = ",".join(map(str, observation_space.nvec))
             preprocess_observation += "torch::Tensor classes = torch::tensor({%s});\n" % classes
             preprocess_observation += f"result = multi_one_hot(observation, classes);\n"
@@ -39,12 +78,19 @@ def generate_action_processing(self):
+        """
+        Generates the action post-processing
+
+        :raises NotImplementedError: If the action space is not supported
+        """
         action_space = self.model.env.action_space
         process_action = ""
         self.vars["ACTION_SPACE"] = repr(action_space)
 
         if isinstance(action_space, spaces.Box):
             if self.model.policy.squash_output:
+                # Unscaling the action assuming it lies in [-1, 1], since squashed networks use Tanh
+                # as the final activation function
                 low_values = ",".join([f"(float){x}" for x in action_space.low])
                 process_action += "torch::Tensor action_low = torch::tensor({%s});\n" % low_values
                 high_values = ",".join([f"(float){x}" for x in action_space.high])
@@ -53,9 +99,8 @@ def generate_action_processing(self):
 
             process_action += "result = action_low + (0.5 * (action + 1.0) * (action_high - action_low));\n"
         else:
             process_action += "result = action;\n"
-        if isinstance(action_space, spaces.Discrete):
-            process_action += "result = action;\n"
-        elif isinstance(action_space, spaces.Box):
+        elif isinstance(action_space, spaces.Discrete) or isinstance(action_space, spaces.MultiDiscrete):
+            # Keeping the input as it is
             process_action += "result = action;\n"
         else:
             raise NotImplementedError(f"C++ exporting does not support processing action {action_space}")
@@ -63,11 +108,16 @@ def generate_action_processing(self):
         self.vars["PROCESS_ACTION"] = process_action
 
     def export_code(self):
+        """
+        Exports the C++ code
+        """
         self.vars["CLASS_NAME"] = self.name
         fname = self.name.lower()
+
         self.vars["FILE_NAME"] = fname
+        self.cpp_fname = f"src/baselines3_models/{fname}.cpp"
         target_header = self.directory + f"/include/baselines3_models/{fname}.h"
-        target_cpp = self.directory + f"/src/baselines3_models/{fname}.cpp"
+        target_cpp = self.directory + "/" + self.cpp_fname
self.generate_observation_preprocessing() self.generate_action_processing() @@ -76,6 +126,12 @@ def export_code(self): self.render("model_template.cpp", target_cpp) def render(self, template: str, target: str): + """ + Renders some template, replacing self.vars variables by their values + + :param str template: The template name + :param str target: The target file + """ with open(self.template_directory + "/" + template, "r") as template_f: with open(target, "w") as target_f: data = template_f.read() @@ -85,11 +141,53 @@ def render(self, template: str, target: str): print("Generated " + target) + def update_cmake(self): + """ + Updates the target's CMakeLists.txt, adding files in static and sources section + + :raises ValueError: If a section can't be found in the CMakeLists + """ + cmake_contents = open(self.directory + "/CMakeLists.txt", "r").read() + + def add_to_section(section_name:str, fname:str, contents:str): + pattern = f"#{section_name}(.+)#!{section_name}" + flags=re.MULTILINE + re.DOTALL + + match = re.search(pattern, cmake_contents, flags=flags) + + if match is None: + raise ValueError(f"Couldn't find {section_name} section in CMakeLists.txt") + + files = match[1].strip() + if files: + files = list(map(str.strip, files.split("\n"))) + else: + files = [] + + if fname not in files: + print(f"Adding {fname} to CMake {section_name}") + files.append(fname) + + new_section = f"#{section_name}\n" + ("\n".join(files)) + "\n" + f"#!{section_name}" + + return re.sub(pattern, new_section, contents, flags=flags) + + cmake_contents = add_to_section("static", self.asset_fname, cmake_contents) + cmake_contents = add_to_section("sources", self.cpp_fname, cmake_contents) + + with open(self.directory + "/CMakeLists.txt", "w") as f: + f.write(cmake_contents) + def export_model(self): + """ + Export the Algorithm's model using Pytorch's JIT script tracer + + :raises NotImplementedError: If the policy is not supported + """ policy = self.model.policy obs = th.Tensor(self.model.env.reset()) - asset_fname = f"assets/{self.name}_model.pt" - fname = self.directory + "/" + asset_fname + self.asset_fname = f"assets/{self.name.lower()}_model.pt" + fname = self.directory + "/" + self.asset_fname traced_script_module = None if isinstance(policy, TD3Policy): @@ -105,15 +203,17 @@ def export_model(self): self.vars["POLICY_TYPE"] = "ACTOR_MU" elif isinstance(policy, DQNPolicy): traced_script_module = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) - self.vars["POLICY_TYPE"] = "QNET_SCAN" + self.vars["POLICY_TYPE"] = "QNET_ALL" else: raise NotImplementedError(f"C++ exporting does not support policy {policy}") if traced_script_module is not None: print(f"Generated {fname}") traced_script_module.save(fname) - self.vars["MODEL_FNAME"] = asset_fname + self.vars["MODEL_FNAME"] = self.asset_fname def export(self): + self.generate_directory() self.export_model() self.export_code() + self.update_cmake() From ef4b1ef04e50f903b6caf3121fc9004da103a97c Mon Sep 17 00:00:00 2001 From: Gregwar Date: Fri, 1 Apr 2022 14:49:13 +0200 Subject: [PATCH 05/14] Reformating --- utils/cpp_exporter.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index 42fca20a8..4d38674eb 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -1,14 +1,15 @@ -import torch as th +import os import re import shutil -import os + +import torch as th from gym import spaces from stable_baselines3.common.base_class import BaseAlgorithm 
-from stable_baselines3.common.preprocessing import is_image_space from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.td3.policies import TD3Policy -from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.common.preprocessing import is_image_space from stable_baselines3.dqn.policies import DQNPolicy +from stable_baselines3.sac.policies import SACPolicy +from stable_baselines3.td3.policies import TD3Policy class CppExporter(object): @@ -39,7 +40,7 @@ def generate_directory(self): def ignore(directory, files): if directory == self.template_directory: - return ['.gitignore', 'model_template.h', 'model_template.cpp'] + return [".gitignore", "model_template.h", "model_template.cpp"] return [] @@ -146,12 +147,12 @@ def update_cmake(self): Updates the target's CMakeLists.txt, adding files in static and sources section :raises ValueError: If a section can't be found in the CMakeLists - """ + """ cmake_contents = open(self.directory + "/CMakeLists.txt", "r").read() - def add_to_section(section_name:str, fname:str, contents:str): + def add_to_section(section_name: str, fname: str, contents: str): pattern = f"#{section_name}(.+)#!{section_name}" - flags=re.MULTILINE + re.DOTALL + flags = re.MULTILINE + re.DOTALL match = re.search(pattern, cmake_contents, flags=flags) @@ -163,13 +164,13 @@ def add_to_section(section_name:str, fname:str, contents:str): files = list(map(str.strip, files.split("\n"))) else: files = [] - + if fname not in files: print(f"Adding {fname} to CMake {section_name}") files.append(fname) new_section = f"#{section_name}\n" + ("\n".join(files)) + "\n" + f"#!{section_name}" - + return re.sub(pattern, new_section, contents, flags=flags) cmake_contents = add_to_section("static", self.asset_fname, cmake_contents) From de9d721a4ba7f6484cc2eaccbec2d2ce983c3fdd Mon Sep 17 00:00:00 2001 From: Gregwar Date: Wed, 6 Apr 2022 18:06:46 +0200 Subject: [PATCH 06/14] Outputing value function (wip) --- cpp/include/baselines3_models/predictor.h | 16 ++++-- cpp/model_template.cpp | 2 +- cpp/src/baselines3_models/predictor.cpp | 69 +++++++++++++++++++---- utils/cpp_exporter.py | 65 +++++++++++++++------ 4 files changed, 118 insertions(+), 34 deletions(-) diff --git a/cpp/include/baselines3_models/predictor.h b/cpp/include/baselines3_models/predictor.h index 1e2e74326..ce20fb097 100644 --- a/cpp/include/baselines3_models/predictor.h +++ b/cpp/include/baselines3_models/predictor.h @@ -7,15 +7,19 @@ namespace baselines3_models { class Predictor { public: enum PolicyType { - // The module is an actor's µ: directly outputs action from state - ACTOR_MU, + // The first network is an actor and the second a value network + ACTOR_VALUE, + ACTOR_VALUE_DISCRETE, + // The first network is an actor and the second a Q network + ACTOR_Q, // The network is a Q-Network: outputs Q(s,a) for all a for a given s QNET_ALL }; - Predictor(std::string model_filename); + Predictor(std::string actor_filename, std::string q_filename, std::string v_filename); - torch::Tensor predict(torch::Tensor &observation); + torch::Tensor predict(torch::Tensor &observation, bool unscale_action = true); + double value(torch::Tensor &observation); std::vector predict_vector(std::vector obs); @@ -24,7 +28,9 @@ class Predictor { virtual std::vector enumerate_actions(); protected: - torch::jit::script::Module module; + torch::jit::script::Module model_actor; + torch::jit::script::Module model_q; + torch::jit::script::Module model_v; PolicyType policy_type; }; } // namespace 
baselines3_models \ No newline at end of file diff --git a/cpp/model_template.cpp b/cpp/model_template.cpp index cec275eec..276429ae4 100644 --- a/cpp/model_template.cpp +++ b/cpp/model_template.cpp @@ -11,7 +11,7 @@ namespace baselines3_models { -CLASS_NAME::CLASS_NAME() : Predictor("MODEL_FNAME") { +CLASS_NAME::CLASS_NAME() : Predictor("MODEL_ACTOR", "MODEL_Q", "MODEL_V") { policy_type = POLICY_TYPE; } diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp index d1c09ca5a..fa7b7213c 100644 --- a/cpp/src/baselines3_models/predictor.cpp +++ b/cpp/src/baselines3_models/predictor.cpp @@ -4,26 +4,44 @@ CMRC_DECLARE(baselines3_model); namespace baselines3_models { -Predictor::Predictor(std::string model_filename) { - auto fs = cmrc::baselines3_model::get_filesystem(); - auto f = fs.open(model_filename); - std::string data(f.begin(), f.end()); - std::istringstream stream(data); - module = torch::jit::load(stream); + +static void _load_model(std::string filename, + torch::jit::script::Module &model) { + if (filename != "") { + auto fs = cmrc::baselines3_model::get_filesystem(); + auto f = fs.open(filename); + std::string data(f.begin(), f.end()); + std::istringstream stream(data); + model = torch::jit::load(stream); + } +} + +Predictor::Predictor(std::string actor_filename, std::string q_filename, + std::string v_filename) { + + _load_model(actor_filename, model_actor); + _load_model(q_filename, model_q); + _load_model(v_filename, model_v); } -torch::Tensor Predictor::predict(torch::Tensor &observation) { +torch::Tensor Predictor::predict(torch::Tensor &observation, + bool unscale_action) { c10::InferenceMode guard; torch::Tensor processed_observation = preprocess_observation(observation); at::Tensor action; std::vector inputs; inputs.push_back(processed_observation); - if (policy_type == ACTOR_MU) { - action = module.forward(inputs).toTensor(); - action = process_action(action); + if (policy_type == ACTOR_Q || policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { + action = model_actor.forward(inputs).toTensor(); + if (unscale_action) { + action = process_action(action); + } + if (policy_type == ACTOR_VALUE_DISCRETE) { + action = torch::argmax(action); + } } else if (policy_type == QNET_ALL) { - auto q_values = module.forward(inputs).toTensor(); + auto q_values = model_q.forward(inputs).toTensor(); action = torch::argmax(q_values); } else { throw std::runtime_error("Unknown policy type"); @@ -32,6 +50,35 @@ torch::Tensor Predictor::predict(torch::Tensor &observation) { return action; } +double Predictor::value(torch::Tensor &observation) { + double value = 0.0; + + torch::Tensor processed_observation = preprocess_observation(observation); + at::Tensor action; + std::vector inputs; + + if (policy_type == ACTOR_Q) { + auto action = predict(observation, false); + std::vector tensor_vec{ processed_observation, action }; + inputs.push_back(torch::cat({ tensor_vec })); + + auto q = model_q.forward(inputs).toTensor(); + value = q.data_ptr()[0]; + } else if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { + inputs.push_back(processed_observation); + auto v = model_v.forward(inputs).toTensor(); + value = v.data_ptr()[0]; + } else if (policy_type == QNET_ALL) { + inputs.push_back(processed_observation); + auto q = model_q.forward(inputs).toTensor(); + value = torch::max(q).data_ptr()[0]; + } else { + throw std::runtime_error("Unknown policy type"); + } + + return value; +} + std::vector Predictor::predict_vector(std::vector obs) 
{ torch::Tensor observation = torch::from_blob(obs.data(), obs.size()); torch::Tensor action = predict(observation); diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index 4d38674eb..6521286b3 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -30,7 +30,7 @@ def __init__(self, model: BaseAlgorithm, directory: str, name: str): self.vars = {} # Name of the asset (.pt file) - self.asset_fname = None + self.asset_fnames = [] self.cpp_fname = None def generate_directory(self): @@ -173,7 +173,8 @@ def add_to_section(section_name: str, fname: str, contents: str): return re.sub(pattern, new_section, contents, flags=flags) - cmake_contents = add_to_section("static", self.asset_fname, cmake_contents) + for asset in self.asset_fnames: + cmake_contents = add_to_section("static", asset, cmake_contents) cmake_contents = add_to_section("sources", self.cpp_fname, cmake_contents) with open(self.directory + "/CMakeLists.txt", "w") as f: @@ -187,31 +188,61 @@ def export_model(self): """ policy = self.model.policy obs = th.Tensor(self.model.env.reset()) - self.asset_fname = f"assets/{self.name.lower()}_model.pt" - fname = self.directory + "/" + self.asset_fname - traced_script_module = None + + def get_fname(suffix): + asset_fname = f"assets/{self.name.lower()}_{suffix}.pt" + fname = self.directory + "/" + asset_fname + return asset_fname, fname + + traced = { + 'actor': None, + 'q': None, + 'v': None, + } if isinstance(policy, TD3Policy): - traced_script_module = th.jit.trace(policy.actor.mu, policy.actor.extract_features(obs)) - self.vars["POLICY_TYPE"] = "ACTOR_MU" + features = policy.actor.extract_features(obs) + traced['actor'] = th.jit.trace(policy.actor.mu, features) + + action = policy.actor.mu(features) + traced['q'] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) + self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, SACPolicy): + features = policy.actor.extract_features(obs) model = th.nn.Sequential(policy.actor.latent_pi, policy.actor.mu) - traced_script_module = th.jit.trace(model, policy.actor.extract_features(obs)) - self.vars["POLICY_TYPE"] = "ACTOR_MU" + traced['actor'] = th.jit.trace(model, features) + + action = model(features) + traced['q'] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) + self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, ActorCriticPolicy): - model = th.nn.Sequential(policy.mlp_extractor.policy_net, policy.action_net) - traced_script_module = th.jit.trace(model, policy.extract_features(obs)) - self.vars["POLICY_TYPE"] = "ACTOR_MU" + actor_model = th.nn.Sequential(policy.mlp_extractor.policy_net, policy.action_net) + traced['actor'] = th.jit.trace(actor_model, policy.extract_features(obs)) + + # action = policy.predict(obs) + value_model = th.nn.Sequential(policy.mlp_extractor.value_net, policy.value_net) + traced['v'] = th.jit.trace(value_model, policy.extract_features(obs)) + + if isinstance(self.model.env.action_space, spaces.Discrete): + self.vars["POLICY_TYPE"] = "ACTOR_VALUE_DISCRETE" + else: + self.vars["POLICY_TYPE"] = "ACTOR_VALUE" elif isinstance(policy, DQNPolicy): - traced_script_module = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) + traced['q'] = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) self.vars["POLICY_TYPE"] = "QNET_ALL" else: raise NotImplementedError(f"C++ exporting does not support policy {policy}") - if traced_script_module is not None: - print(f"Generated {fname}") - 
traced_script_module.save(fname) - self.vars["MODEL_FNAME"] = self.asset_fname + for entry in traced.keys(): + var = f"MODEL_{entry.upper()}" + if traced[entry] is None: + self.vars[var] = '' + else: + asset_fname, fname = get_fname(entry) + traced[entry].save(fname) + print(f"Generated {fname}") + self.asset_fnames.append(asset_fname) + self.vars[var] = asset_fname def export(self): self.generate_directory() From e758ce64843600e38cb11309cd617ee2072656df Mon Sep 17 00:00:00 2001 From: Gregwar Date: Wed, 6 Apr 2022 18:08:14 +0200 Subject: [PATCH 07/14] Formatting --- utils/cpp_exporter.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index 6521286b3..b39c13d67 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -188,47 +188,47 @@ def export_model(self): """ policy = self.model.policy obs = th.Tensor(self.model.env.reset()) - + def get_fname(suffix): asset_fname = f"assets/{self.name.lower()}_{suffix}.pt" fname = self.directory + "/" + asset_fname return asset_fname, fname traced = { - 'actor': None, - 'q': None, - 'v': None, + "actor": None, + "q": None, + "v": None, } if isinstance(policy, TD3Policy): features = policy.actor.extract_features(obs) - traced['actor'] = th.jit.trace(policy.actor.mu, features) + traced["actor"] = th.jit.trace(policy.actor.mu, features) action = policy.actor.mu(features) - traced['q'] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) + traced["q"] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, SACPolicy): features = policy.actor.extract_features(obs) model = th.nn.Sequential(policy.actor.latent_pi, policy.actor.mu) - traced['actor'] = th.jit.trace(model, features) + traced["actor"] = th.jit.trace(model, features) action = model(features) - traced['q'] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) + traced["q"] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, ActorCriticPolicy): actor_model = th.nn.Sequential(policy.mlp_extractor.policy_net, policy.action_net) - traced['actor'] = th.jit.trace(actor_model, policy.extract_features(obs)) + traced["actor"] = th.jit.trace(actor_model, policy.extract_features(obs)) # action = policy.predict(obs) value_model = th.nn.Sequential(policy.mlp_extractor.value_net, policy.value_net) - traced['v'] = th.jit.trace(value_model, policy.extract_features(obs)) + traced["v"] = th.jit.trace(value_model, policy.extract_features(obs)) if isinstance(self.model.env.action_space, spaces.Discrete): self.vars["POLICY_TYPE"] = "ACTOR_VALUE_DISCRETE" else: self.vars["POLICY_TYPE"] = "ACTOR_VALUE" elif isinstance(policy, DQNPolicy): - traced['q'] = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) + traced["q"] = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) self.vars["POLICY_TYPE"] = "QNET_ALL" else: raise NotImplementedError(f"C++ exporting does not support policy {policy}") @@ -236,7 +236,7 @@ def get_fname(suffix): for entry in traced.keys(): var = f"MODEL_{entry.upper()}" if traced[entry] is None: - self.vars[var] = '' + self.vars[var] = "" else: asset_fname, fname = get_fname(entry) traced[entry].save(fname) From a0dcb776e8456d4cf82a8d6479572fc0b872b364 Mon Sep 17 00:00:00 2001 From: Gregwar Date: Wed, 6 Apr 2022 18:45:38 +0200 Subject: [PATCH 08/14] 
Adding feature extractors in exported models (not tested, wip) --- utils/cpp_exporter.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index b39c13d67..6703d25c0 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -6,7 +6,7 @@ from gym import spaces from stable_baselines3.common.base_class import BaseAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.preprocessing import is_image_space, preprocess_obs from stable_baselines3.dqn.policies import DQNPolicy from stable_baselines3.sac.policies import SACPolicy from stable_baselines3.td3.policies import TD3Policy @@ -200,35 +200,38 @@ def get_fname(suffix): "v": None, } + obs = preprocess_obs(obs, self.model.env.observation_space) + if isinstance(policy, TD3Policy): features = policy.actor.extract_features(obs) - traced["actor"] = th.jit.trace(policy.actor.mu, features) + model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.mu) + traced["actor"] = th.jit.trace(model, obs) action = policy.actor.mu(features) traced["q"] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, SACPolicy): features = policy.actor.extract_features(obs) - model = th.nn.Sequential(policy.actor.latent_pi, policy.actor.mu) - traced["actor"] = th.jit.trace(model, features) + model = th.nn.Sequential(obs, policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu) + traced["actor"] = th.jit.trace(model, obs) action = model(features) traced["q"] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1)) self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, ActorCriticPolicy): - actor_model = th.nn.Sequential(policy.mlp_extractor.policy_net, policy.action_net) - traced["actor"] = th.jit.trace(actor_model, policy.extract_features(obs)) + actor_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.policy_net, policy.action_net) + traced["actor"] = th.jit.trace(actor_model, obs) - # action = policy.predict(obs) value_model = th.nn.Sequential(policy.mlp_extractor.value_net, policy.value_net) - traced["v"] = th.jit.trace(value_model, policy.extract_features(obs)) + traced["v"] = th.jit.trace(value_model, obs) if isinstance(self.model.env.action_space, spaces.Discrete): self.vars["POLICY_TYPE"] = "ACTOR_VALUE_DISCRETE" else: self.vars["POLICY_TYPE"] = "ACTOR_VALUE" elif isinstance(policy, DQNPolicy): - traced["q"] = th.jit.trace(policy.q_net.q_net, policy.q_net.extract_features(obs)) + q_model = th.nn.Sequential(policy.q_net.features_extractor, policy.q_net.q_net) + traced["q"] = th.jit.trace(q_model, obs) self.vars["POLICY_TYPE"] = "QNET_ALL" else: raise NotImplementedError(f"C++ exporting does not support policy {policy}") From 87f0aa975f57deda36e9128f18c5e830c3175351 Mon Sep 17 00:00:00 2001 From: Gregwar Date: Wed, 6 Apr 2022 19:01:51 +0200 Subject: [PATCH 09/14] Unqqueezing since feature extractor is now embedded in model (and starts with flatten) --- cpp/src/baselines3_models/predictor.cpp | 8 ++++---- utils/cpp_exporter.py | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp index fa7b7213c..5f7a5b127 100644 --- a/cpp/src/baselines3_models/predictor.cpp +++ 
b/cpp/src/baselines3_models/predictor.cpp
@@ -30,7 +30,7 @@ torch::Tensor Predictor::predict(torch::Tensor &observation,
   torch::Tensor processed_observation = preprocess_observation(observation);
   at::Tensor action;
   std::vector<torch::jit::IValue> inputs;
-  inputs.push_back(processed_observation);
+  inputs.push_back(processed_observation.unsqueeze(0));
 
   if (policy_type == ACTOR_Q || policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) {
     action = model_actor.forward(inputs).toTensor();
@@ -60,16 +60,16 @@ double Predictor::value(torch::Tensor &observation) {
   if (policy_type == ACTOR_Q) {
     auto action = predict(observation, false);
     std::vector<torch::Tensor> tensor_vec{ processed_observation, action };
-    inputs.push_back(torch::cat({ tensor_vec }));
+    inputs.push_back(torch::cat({ tensor_vec }).unsqueeze(0));
 
     auto q = model_q.forward(inputs).toTensor();
     value = q.data_ptr<float>()[0];
   } else if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) {
-    inputs.push_back(processed_observation);
+    inputs.push_back(processed_observation.unsqueeze(0));
     auto v = model_v.forward(inputs).toTensor();
     value = v.data_ptr<float>()[0];
   } else if (policy_type == QNET_ALL) {
-    inputs.push_back(processed_observation);
+    inputs.push_back(processed_observation.unsqueeze(0));
     auto q = model_q.forward(inputs).toTensor();
     value = torch::max(q).data_ptr<float>()[0];
   } else {
diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py
index 6703d25c0..cb3a90664 100644
--- a/utils/cpp_exporter.py
+++ b/utils/cpp_exporter.py
@@ -216,13 +216,14 @@ def get_fname(suffix):
             traced["actor"] = th.jit.trace(model, obs)
 
             action = model(features)
-            traced["q"] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1))
+            q_model = th.nn.Sequential(policy.critic.features_extractor, policy.critic.q_networks[0])
+            traced["q"] = th.jit.trace(q_model, th.cat([features, action], dim=1))
             self.vars["POLICY_TYPE"] = "ACTOR_Q"
         elif isinstance(policy, ActorCriticPolicy):
             actor_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.policy_net, policy.action_net)
             traced["actor"] = th.jit.trace(actor_model, obs)
 
-            value_model = th.nn.Sequential(policy.mlp_extractor.value_net, policy.value_net)
+            value_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.value_net, policy.value_net)
             traced["v"] = th.jit.trace(value_model, obs)
 
             if isinstance(self.model.env.action_space, spaces.Discrete):

From 3d38afd6fbe289a4fbf556273e0f05f9f43b730e Mon Sep 17 00:00:00 2001
From: Gregwar
Date: Wed, 6 Apr 2022 23:35:06 +0200
Subject: [PATCH 10/14] Features for TD3

---
 utils/cpp_exporter.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py
index cb3a90664..a141d8a50 100644
--- a/utils/cpp_exporter.py
+++ b/utils/cpp_exporter.py
@@ -208,7 +208,8 @@ def get_fname(suffix):
             traced["actor"] = th.jit.trace(model, obs)
 
             action = policy.actor.mu(features)
-            traced["q"] = th.jit.trace(policy.critic.q_networks[0], th.cat([features, action], dim=1))
+            q_model = th.nn.Sequential(policy.critic.features_extractor, policy.critic.q_networks[0])
+            traced["q"] = th.jit.trace(q_model, th.cat([features, action], dim=1))
             self.vars["POLICY_TYPE"] = "ACTOR_Q"
         elif isinstance(policy, SACPolicy):
             features = policy.actor.extract_features(obs)

From 29dedc09fff3832e771638f7ed585c1a142c979b Mon Sep 17 00:00:00 2001
From: Gregwar
Date: Wed, 6 Apr 2022 23:38:20 +0200
Subject: [PATCH 11/14] Features extractors wip

---
 utils/cpp_exporter.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)
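Editor's note: the feature-extractor patches above fold policy.features_extractor into
each traced th.nn.Sequential, so the exported TorchScript graph maps preprocessed
observations straight to actions and the C++ side needs no feature extraction of its
own. A minimal sketch of the idea, assuming SB3's default FlattenExtractor; the
environment and output path are illustrative, not taken from the patches:

import torch as th
from stable_baselines3 import TD3

model = TD3("MlpPolicy", "Pendulum-v1")
policy = model.policy
obs = th.as_tensor(model.env.reset(), dtype=th.float32)

# Prepending the features extractor makes the traced module self-contained.
actor_model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.mu)
traced = th.jit.trace(actor_model, obs)
traced.save("/tmp/pendulum_actor.pt")  # hypothetical output path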
diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index a141d8a50..becf1e09c 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -203,22 +203,20 @@ def get_fname(suffix): obs = preprocess_obs(obs, self.model.env.observation_space) if isinstance(policy, TD3Policy): - features = policy.actor.extract_features(obs) model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.mu) traced["actor"] = th.jit.trace(model, obs) - action = policy.actor.mu(features) + action = policy.actor.mu(policy.actor.extract_features(obs)) q_model = th.nn.Sequential(policy.critic.features_extractor, policy.critic.q_networks[0]) - traced["q"] = th.jit.trace(q_model, th.cat([features, action], dim=1)) + traced["q"] = th.jit.trace(q_model, th.cat([obs, action], dim=1)) self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, SACPolicy): - features = policy.actor.extract_features(obs) - model = th.nn.Sequential(obs, policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu) + model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu) traced["actor"] = th.jit.trace(model, obs) - action = model(features) + action = model(obs) q_model = th.nn.Sequential(policy.critic.features_extractor, policy.critic.q_networks[0]) - traced["q"] = th.jit.trace(q_model, th.cat([features, action], dim=1)) + traced["q"] = th.jit.trace(q_model, th.cat([obs, action], dim=1)) self.vars["POLICY_TYPE"] = "ACTOR_Q" elif isinstance(policy, ActorCriticPolicy): actor_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.policy_net, policy.action_net) From e21cc8b476df024f6d75700067763b05947bb818 Mon Sep 17 00:00:00 2001 From: Gregwar Date: Thu, 7 Apr 2022 22:50:42 +0200 Subject: [PATCH 12/14] Tracing value modules --- cpp/include/baselines3_models/predictor.h | 4 +- cpp/src/baselines3_models/predictor.cpp | 11 +--- utils/cpp_exporter.py | 71 +++++++++++++++++------ 3 files changed, 57 insertions(+), 29 deletions(-) diff --git a/cpp/include/baselines3_models/predictor.h b/cpp/include/baselines3_models/predictor.h index ce20fb097..77d023492 100644 --- a/cpp/include/baselines3_models/predictor.h +++ b/cpp/include/baselines3_models/predictor.h @@ -7,11 +7,9 @@ namespace baselines3_models { class Predictor { public: enum PolicyType { - // The first network is an actor and the second a value network + // If we have an actor network and a value network ACTOR_VALUE, ACTOR_VALUE_DISCRETE, - // The first network is an actor and the second a Q network - ACTOR_Q, // The network is a Q-Network: outputs Q(s,a) for all a for a given s QNET_ALL }; diff --git a/cpp/src/baselines3_models/predictor.cpp b/cpp/src/baselines3_models/predictor.cpp index 5f7a5b127..429ccedba 100644 --- a/cpp/src/baselines3_models/predictor.cpp +++ b/cpp/src/baselines3_models/predictor.cpp @@ -32,7 +32,7 @@ torch::Tensor Predictor::predict(torch::Tensor &observation, std::vector inputs; inputs.push_back(processed_observation.unsqueeze(0)); - if (policy_type == ACTOR_Q || policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { + if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { action = model_actor.forward(inputs).toTensor(); if (unscale_action) { action = process_action(action); @@ -57,14 +57,7 @@ double Predictor::value(torch::Tensor &observation) { at::Tensor action; std::vector inputs; - if (policy_type == ACTOR_Q) { - auto action = predict(observation, false); - std::vector tensor_vec{ processed_observation, action }; - 
inputs.push_back(torch::cat({ tensor_vec }).unsqueeze(0)); - - auto q = model_q.forward(inputs).toTensor(); - value = q.data_ptr()[0]; - } else if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { + if (policy_type == ACTOR_VALUE || policy_type == ACTOR_VALUE_DISCRETE) { inputs.push_back(processed_observation.unsqueeze(0)); auto v = model_v.forward(inputs).toTensor(); value = v.data_ptr()[0]; diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index becf1e09c..cc04f5327 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -89,17 +89,19 @@ def generate_action_processing(self): self.vars["ACTION_SPACE"] = repr(action_space) if isinstance(action_space, spaces.Box): + # Handling action clipping + low_values = ",".join([f"(float){x}" for x in action_space.low]) + process_action += "torch::Tensor action_low = torch::tensor({%s});\n" % low_values + high_values = ",".join([f"(float){x}" for x in action_space.high]) + process_action += "torch::Tensor action_high = torch::tensor({%s});\n" % high_values + if self.model.policy.squash_output: # Unscaling the action assuming it lies in [-1, 1], since squash networks use Tanh as # final activation functions - low_values = ",".join([f"(float){x}" for x in action_space.low]) - process_action += "torch::Tensor action_low = torch::tensor({%s});\n" % low_values - high_values = ",".join([f"(float){x}" for x in action_space.high]) - process_action += "torch::Tensor action_high = torch::tensor({%s});\n" % high_values - process_action += "result = action_low + (0.5 * (action + 1.0) * (action_high - action_low));\n" else: - process_action += "result = action;\n" + # Clipping not squashed action + process_action += "result = torch::clip(action, action_low, action_high);\n" elif isinstance(action_space, spaces.Discrete) or isinstance(action_space, spaces.MultiDiscrete): # Keeping input as it is process_action += "result = action;\n" @@ -203,25 +205,58 @@ def get_fname(suffix): obs = preprocess_obs(obs, self.model.env.observation_space) if isinstance(policy, TD3Policy): - model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.mu) - traced["actor"] = th.jit.trace(model, obs) + # Actor extract features and apply mu + actor_model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.mu) + traced["actor"] = th.jit.trace(actor_model, obs) + + # Value function is a combination of actor and Q + class TD3PolicyValue(th.nn.Module): + def __init__(self, policy : TD3Policy, actor_model: th.nn.Module): + super(TD3PolicyValue, self).__init__() + + self.actor = actor_model + self.critic = policy.critic + + def forward(self, obs): + action = self.actor_model(obs) + critic_features = self.critic.features_extractor(obs) + return self.critic.q_networks[0](th.cat([critic_features, action], dim=1)) action = policy.actor.mu(policy.actor.extract_features(obs)) - q_model = th.nn.Sequential(policy.critic.features_extractor, policy.critic.q_networks[0]) - traced["q"] = th.jit.trace(q_model, th.cat([obs, action], dim=1)) - self.vars["POLICY_TYPE"] = "ACTOR_Q" + v_model = TD3PolicyValue(policy, actor_model) + traced["v"] = th.jit.trace(v_model, obs) + self.vars["POLICY_TYPE"] = "ACTOR_VALUE" elif isinstance(policy, SACPolicy): - model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu) - traced["actor"] = th.jit.trace(model, obs) + # Feature extractor, latent pi and mu + if self.model.use_sde: + # XXX: Check for bijector ? 
+ actor_model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu) + else: + actor_model = th.nn.Sequential(policy.actor.features_extractor, + policy.actor.latent_pi, policy.actor.mu, th.nn.Tanh()) + traced["actor"] = th.jit.trace(actor_model, obs) + + class SACPolicyValue(th.nn.Module): + def __init__(self, policy : SACPolicy, actor_model: th.nn.Module): + super(SACPolicyValue, self).__init__() + + self.actor_model = actor_model + self.critic = policy.critic + + def forward(self, obs): + action = self.actor_model(obs) + critic_features = self.critic.features_extractor(obs) + return self.critic.q_networks[0](th.cat([critic_features, action], dim=1)) - action = model(obs) - q_model = th.nn.Sequential(policy.critic.features_extractor, policy.critic.q_networks[0]) - traced["q"] = th.jit.trace(q_model, th.cat([obs, action], dim=1)) - self.vars["POLICY_TYPE"] = "ACTOR_Q" + v_model = SACPolicyValue(policy, actor_model) + traced["v"] = th.jit.trace(v_model, obs) + self.vars["POLICY_TYPE"] = "ACTOR_VALUE" elif isinstance(policy, ActorCriticPolicy): + # Actor is feature extractor, mpl and action net actor_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.policy_net, policy.action_net) traced["actor"] = th.jit.trace(actor_model, obs) + # The value network is computed directly in ActorCriticPolicy (and not the Q network) value_model = th.nn.Sequential(policy.features_extractor, policy.mlp_extractor.value_net, policy.value_net) traced["v"] = th.jit.trace(value_model, obs) @@ -230,6 +265,8 @@ def get_fname(suffix): else: self.vars["POLICY_TYPE"] = "ACTOR_VALUE" elif isinstance(policy, DQNPolicy): + # For DQN, we only use one Q network that outputs Q(s,a) for all possible actions, it is then + # both used for action prediction using argmax and for value prediction q_model = th.nn.Sequential(policy.q_net.features_extractor, policy.q_net.q_net) traced["q"] = th.jit.trace(q_model, obs) self.vars["POLICY_TYPE"] = "QNET_ALL" From b4195ff82859e90957467840f88a72f3ecae8f5f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 8 Apr 2022 18:12:07 +0200 Subject: [PATCH 13/14] Reformat and minor cleanup --- enjoy.py | 2 +- utils/cpp_exporter.py | 44 ++++++++++++++++++++++++------------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/enjoy.py b/enjoy.py index 773aad4c4..74d347e48 100644 --- a/enjoy.py +++ b/enjoy.py @@ -11,8 +11,8 @@ import utils.import_envs # noqa: F401 pylint: disable=unused-import from utils import ALGOS, create_test_env, get_latest_run_id, get_saved_hyperparams -from utils.exp_manager import ExperimentManager from utils.cpp_exporter import CppExporter +from utils.exp_manager import ExperimentManager from utils.utils import StoreDict diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py index cc04f5327..d67d754bd 100644 --- a/utils/cpp_exporter.py +++ b/utils/cpp_exporter.py @@ -1,6 +1,9 @@ import os import re import shutil +from typing import List +from pathlib import Path + import torch as th from gym import spaces @@ -17,28 +20,28 @@ def __init__(self, model: BaseAlgorithm, directory: str, name: str): """ C++ module exporter - :param BaseAlgorithm model: The algorithm that should be exported - :param str directory: Output directory - :param str name: The module name + :param model: The algorithm that should be exported + :param directory: Output directory + :param name: The module name """ self.model = model self.directory = directory self.name = name.replace("-", "_") # Templates directory is found 
relatively to this script (../cpp) - self.template_directory = "/".join(__file__.split("/")[:-1] + ["..", "cpp"]) + self.template_directory = str(Path(__file__).parent.parent / Path("cpp")) self.vars = {} # Name of the asset (.pt file) self.asset_fnames = [] self.cpp_fname = None - def generate_directory(self): + def generate_directory(self) -> None: """ Generates the target directory if it doesn't exists """ - def ignore(directory, files): + def ignore(directory: str, files: List[str]) -> List[str]: if directory == self.template_directory: return [".gitignore", "model_template.h", "model_template.cpp"] @@ -71,8 +74,8 @@ def generate_observation_preprocessing(self): elif isinstance(observation_space, spaces.MultiDiscrete): # Applying multiple one hot representation (using C++ function) classes = ",".join(map(str, observation_space.nvec)) - preprocess_observation += "torch::Tensor classes = torch::tensor({%s});\n" % classes - preprocess_observation += f"result = multi_one_hot(observation, classes);\n" + preprocess_observation += f"torch::Tensor classes = torch::tensor({classes});\n" + preprocess_observation += "result = multi_one_hot(observation, classes);\n" else: raise NotImplementedError(f"C++ exporting does not support observation {observation_space}") @@ -118,9 +121,10 @@ def export_code(self): fname = self.name.lower() self.vars["FILE_NAME"] = fname - self.cpp_fname = f"src/baselines3_models/{fname}.cpp" - target_header = self.directory + f"/include/baselines3_models/{fname}.h" - target_cpp = self.directory + "/" + self.cpp_fname + self.cpp_fname = os.path.join("src", "baselines3_models", f"{fname}.cpp") + include_fname = os.path.join("include", "baselines3_models", f"{fname}.h") + target_header = os.path.join(self.directory, include_fname) + target_cpp = os.path.join(self.directory, self.cpp_fname) self.generate_observation_preprocessing() self.generate_action_processing() @@ -191,9 +195,9 @@ def export_model(self): policy = self.model.policy obs = th.Tensor(self.model.env.reset()) - def get_fname(suffix): - asset_fname = f"assets/{self.name.lower()}_{suffix}.pt" - fname = self.directory + "/" + asset_fname + def get_fname(suffix: str): + asset_fname = os.path.join("assets", f"{self.name.lower()}_{suffix}.pt") + fname = os.path.join(self.directory, asset_fname) return asset_fname, fname traced = { @@ -211,7 +215,7 @@ def get_fname(suffix): # Value function is a combination of actor and Q class TD3PolicyValue(th.nn.Module): - def __init__(self, policy : TD3Policy, actor_model: th.nn.Module): + def __init__(self, policy: TD3Policy, actor_model: th.nn.Module): super(TD3PolicyValue, self).__init__() self.actor = actor_model @@ -222,6 +226,7 @@ def forward(self, obs): critic_features = self.critic.features_extractor(obs) return self.critic.q_networks[0](th.cat([critic_features, action], dim=1)) + # Note(antonin): unused variable action action = policy.actor.mu(policy.actor.extract_features(obs)) v_model = TD3PolicyValue(policy, actor_model) traced["v"] = th.jit.trace(v_model, obs) @@ -232,12 +237,13 @@ def forward(self, obs): # XXX: Check for bijector ? 
                actor_model = th.nn.Sequential(policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu)
             else:
-                actor_model = th.nn.Sequential(policy.actor.features_extractor,
-                    policy.actor.latent_pi, policy.actor.mu, th.nn.Tanh())
+                actor_model = th.nn.Sequential(
+                    policy.actor.features_extractor, policy.actor.latent_pi, policy.actor.mu, th.nn.Tanh()
+                )
             traced["actor"] = th.jit.trace(actor_model, obs)
 
             class SACPolicyValue(th.nn.Module):
-                def __init__(self, policy : SACPolicy, actor_model: th.nn.Module):
+                def __init__(self, policy: SACPolicy, actor_model: th.nn.Module):
                     super(SACPolicyValue, self).__init__()
 
                     self.actor_model = actor_model
@@ -284,7 +290,7 @@ def forward(self, obs):
             self.asset_fnames.append(asset_fname)
             self.vars[var] = asset_fname
 
-    def export(self):
+    def export(self) -> None:
         self.generate_directory()
         self.export_model()
         self.export_code()

From f7b8c1577d2a5480646d259584a6be7ffdcda8aa Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Fri, 8 Apr 2022 18:20:47 +0200
Subject: [PATCH 14/14] Fix import order

---
 utils/cpp_exporter.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/utils/cpp_exporter.py b/utils/cpp_exporter.py
index d67d754bd..3a1cfa093 100644
--- a/utils/cpp_exporter.py
+++ b/utils/cpp_exporter.py
@@ -1,9 +1,8 @@
 import os
 import re
 import shutil
-from typing import List
 from pathlib import Path
-
+from typing import List
 import torch as th
 from gym import spaces
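Editor's note: taken together, the series leaves utils/cpp_exporter.py usable on its
own. A sketch of typical usage, mirroring the hook added to enjoy.py; the model,
output directory and name below are placeholders, not taken from the patches:

from stable_baselines3 import PPO
from utils.cpp_exporter import CppExporter

model = PPO("MlpPolicy", "CartPole-v1").learn(total_timesteps=1_000)
exporter = CppExporter(model, directory="/tmp/cartpole_export", name="CartPole-v1")
# Copies the cpp/ template, traces the networks, renders the C++ class and
# registers the generated files in the target CMakeLists.txt:
exporter.export()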