Skip to content

Commit

Permalink
[aot] [llvm] LLVM AOT Field #2: Updated LLVM AOTModuleLoader & AOTMod…
Browse files Browse the repository at this point in the history
…uleBuilder to support Fields (#5120)

* [aot] [llvm] Implemented FieldCacheData and refactored initialize_llvm_runtime_snodes()

* Addressed compilation erros

* [aot] [llvm] LLVM AOT Field #1: Adjust serialization/deserialization logics for FieldCacheData

* [llvm] [aot] Added Field support for LLVM AOT

* [aot] [llvm] LLVM AOT Field #2: Updated LLVM AOTModuleLoader & AOTModuleBuilder to support Fields

* Fixed merge issues

* Stopped abusing Program*
  • Loading branch information
jim19930609 authored Jun 13, 2022
1 parent 6349d60 commit fae94a2
Show file tree
Hide file tree
Showing 13 changed files with 190 additions and 46 deletions.
5 changes: 5 additions & 0 deletions taichi/backends/cpu/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ namespace lang {
namespace cpu {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
};
Expand Down
5 changes: 0 additions & 5 deletions taichi/backends/cpu/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule {
TI_NOT_IMPLEMENTED;
return nullptr;
}

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
TI_NOT_IMPLEMENTED;
return nullptr;
}
};

} // namespace
Expand Down
5 changes: 5 additions & 0 deletions taichi/backends/cuda/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@ namespace lang {
namespace cuda {

class AotModuleBuilderImpl : public LlvmAotModuleBuilder {
public:
explicit AotModuleBuilderImpl(LlvmProgramImpl *prog)
: LlvmAotModuleBuilder(prog) {
}

private:
CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) override;
};
Expand Down
5 changes: 0 additions & 5 deletions taichi/backends/cuda/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class AotModuleImpl : public LlvmAotModule {
TI_NOT_IMPLEMENTED;
return nullptr;
}

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override {
TI_NOT_IMPLEMENTED;
return nullptr;
}
};

} // namespace
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/snode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ void SNode::set_snode_tree_id(int id) {
snode_tree_id_ = id;
}

int SNode::get_snode_tree_id() {
int SNode::get_snode_tree_id() const {
return snode_tree_id_;
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/snode.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class SNode {

void set_snode_tree_id(int id);

int get_snode_tree_id();
int get_snode_tree_id() const;

static void reset_counter() {
counter = 0;
Expand Down
33 changes: 33 additions & 0 deletions taichi/llvm/llvm_aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <algorithm>
#include "taichi/llvm/launch_arg_info.h"
#include "taichi/llvm/llvm_program.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -34,5 +35,37 @@ void LlvmAotModuleBuilder::add_per_backend(const std::string &identifier,
cache_.kernels[identifier] = std::move(kcache);
}

void LlvmAotModuleBuilder::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
int row_num,
int column_num) {
// Field refers to a leaf node(Place SNode) in a SNodeTree.
// It makes no sense to just serialize the leaf node or its corresponding
// branch. Instead, the minimal unit we have to serialize is the entire
// SNodeTree. Note that SNodeTree's uses snode_tree_id as its identifier,
// rather than the field's name. (multiple fields may end up referring to the
// same SNodeTree)

// 1. Find snode_tree_id
int snode_tree_id = rep_snode->get_snode_tree_id();

// 2. Fetch Cache from the Program
// Kernel compilation is not allowed until all the Fields are finalized,
// so we finished SNodeTree compilation during AOTModuleBuilder construction.
//
// By the time "add_field_per_backend()" is called,
// SNodeTrees should have already been finalized,
// with compiled info stored in LlvmProgramImpl::cache_data_.
TI_ASSERT(prog_ != nullptr);
LlvmOfflineCache::FieldCacheData field_cache =
prog_->get_cached_field(snode_tree_id);

// 3. Update AOT Cache
cache_.fields[snode_tree_id] = std::move(field_cache);
}

} // namespace lang
} // namespace taichi
12 changes: 12 additions & 0 deletions taichi/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,27 @@ namespace lang {

class LlvmAotModuleBuilder : public AotModuleBuilder {
public:
explicit LlvmAotModuleBuilder(LlvmProgramImpl *prog) : prog_(prog) {
}

void dump(const std::string &output_dir,
const std::string &filename) const override;

protected:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;
virtual CodeGenLLVM::CompiledData compile_kernel(Kernel *kernel) = 0;

void add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
DataType dt,
std::vector<int> shape,
int row_num,
int column_num) override;

private:
mutable LlvmOfflineCache cache_;
LlvmProgramImpl *prog_ = nullptr;
};

} // namespace lang
Expand Down
55 changes: 55 additions & 0 deletions taichi/llvm/llvm_aot_module_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,24 @@ class KernelImpl : public aot::Kernel {
FunctionType fn_;
};

class FieldImpl : public aot::Field {
public:
explicit FieldImpl(const LlvmOfflineCache::FieldCacheData &field)
: field_(field) {
}

explicit FieldImpl(LlvmOfflineCache::FieldCacheData &&field)
: field_(std::move(field)) {
}

LlvmOfflineCache::FieldCacheData get_field() const {
return field_;
}

private:
LlvmOfflineCache::FieldCacheData field_;
};

} // namespace

LlvmOfflineCache::KernelCacheData LlvmAotModule::load_kernel_from_cache(
Expand All @@ -37,5 +55,42 @@ std::unique_ptr<aot::Kernel> LlvmAotModule::make_new_kernel(
return std::make_unique<KernelImpl>(fn);
}

std::unique_ptr<aot::Field> LlvmAotModule::make_new_field(
const std::string &name) {
// Check if "name" represents snode_tree_id.
// Avoid using std::atoi due to its poor error handling.
char *end;
int snode_tree_id = static_cast<int>(strtol(name.c_str(), &end, 10 /*base*/));

TI_ASSERT(end != name.c_str());
TI_ASSERT(*end == '\0');

// Load FieldCache
LlvmOfflineCache::FieldCacheData loaded;
auto ok = cache_reader_->get_field_cache(loaded, snode_tree_id);
TI_ERROR_IF(!ok, "Failed to load field with id={}", snode_tree_id);

return std::make_unique<FieldImpl>(std::move(loaded));
}

void finalize_aot_field(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer) {
auto *llvm_aot_module = dynamic_cast<LlvmAotModule *>(aot_module);
auto *aot_field_impl = dynamic_cast<FieldImpl *>(aot_field);

TI_ASSERT(llvm_aot_module != nullptr);
TI_ASSERT(aot_field_impl != nullptr);

auto *llvm_prog = llvm_aot_module->get_program();
const auto &field_cache = aot_field_impl->get_field();

int snode_tree_id = field_cache.tree_id;
if (!llvm_aot_module->is_snode_tree_initialized(snode_tree_id)) {
llvm_prog->initialize_llvm_runtime_snodes(field_cache, result_buffer);
llvm_aot_module->set_initialized_snode_tree(snode_tree_id);
}
}

} // namespace lang
} // namespace taichi
21 changes: 21 additions & 0 deletions taichi/llvm/llvm_aot_module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
namespace taichi {
namespace lang {

TI_DLL_EXPORT void finalize_aot_field(aot::Module *aot_module,
aot::Field *aot_field,
uint64 *result_buffer);

class LlvmAotModule : public aot::Module {
public:
explicit LlvmAotModule(const std::string &module_path,
Expand All @@ -27,6 +31,18 @@ class LlvmAotModule : public aot::Module {
return 0;
}

LlvmProgramImpl *const get_program() {
return program_;
}

void set_initialized_snode_tree(int snode_tree_id) {
initialized_snode_tree_ids.insert(snode_tree_id);
}

bool is_snode_tree_initialized(int snode_tree_id) {
return initialized_snode_tree_ids.count(snode_tree_id);
}

protected:
virtual FunctionType convert_module_to_function(
const std::string &name,
Expand All @@ -38,8 +54,13 @@ class LlvmAotModule : public aot::Module {
std::unique_ptr<aot::Kernel> make_new_kernel(
const std::string &name) override;

std::unique_ptr<aot::Field> make_new_field(const std::string &name) override;

LlvmProgramImpl *const program_{nullptr};
std::unique_ptr<LlvmOfflineCacheFileReader> cache_reader_{nullptr};

// To prevent repeated SNodeTree initialization
std::unordered_set<int> initialized_snode_tree_ids;
};

} // namespace lang
Expand Down
2 changes: 1 addition & 1 deletion taichi/llvm/llvm_offline_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct LlvmOfflineCache {
std::unordered_map<std::string, KernelCacheData>
kernels; // key = kernel_name

TI_IO_DEF(kernels);
TI_IO_DEF(fields, kernels);
};

class LlvmOfflineCacheFileReader {
Expand Down
66 changes: 39 additions & 27 deletions taichi/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,37 +273,22 @@ std::unique_ptr<StructCompiler> LlvmProgramImpl::compile_snode_tree_types_impl(
}

void LlvmProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
compile_snode_tree_types_impl(tree);
}

static LlvmOfflineCache::FieldCacheData construct_filed_cache_data(
const SNodeTree &tree,
const StructCompiler &struct_compiler) {
LlvmOfflineCache::FieldCacheData ret;
ret.tree_id = tree.id();
ret.root_id = tree.root()->id;
ret.root_size = struct_compiler.root_size;

const auto &snodes = struct_compiler.snodes;
for (size_t i = 0; i < snodes.size(); i++) {
LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
snode_cache_data.id = snodes[i]->id;
snode_cache_data.type = snodes[i]->type;
snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes;
snode_cache_data.chunk_size = snodes[i]->chunk_size;

ret.snode_metas.emplace_back(std::move(snode_cache_data));
}
auto struct_compiler = compile_snode_tree_types_impl(tree);
int snode_tree_id = tree->id();
int root_id = tree->root()->id;

return ret;
// Add compiled result to Cache
cache_field(snode_tree_id, root_id, *struct_compiler);
}

void LlvmProgramImpl::materialize_snode_tree(SNodeTree *tree,
uint64 *result_buffer) {
auto struct_compiler = compile_snode_tree_types_impl(tree);
compile_snode_tree_types(tree);
int snode_tree_id = tree->id();

auto field_cache_data = construct_filed_cache_data(*tree, *struct_compiler);
initialize_llvm_runtime_snodes(field_cache_data, result_buffer);
TI_ASSERT(cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end());
initialize_llvm_runtime_snodes(cache_data_.fields.at(snode_tree_id),
result_buffer);
}

uint64 LlvmProgramImpl::fetch_result_uint64(int i, uint64 *result_buffer) {
Expand Down Expand Up @@ -365,12 +350,12 @@ void LlvmProgramImpl::print_list_manager_info(void *list_manager,

std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder() {
if (config->arch == Arch::x64 || config->arch == Arch::arm64) {
return std::make_unique<cpu::AotModuleBuilderImpl>();
return std::make_unique<cpu::AotModuleBuilderImpl>(this);
}

#if defined(TI_WITH_CUDA)
if (config->arch == Arch::cuda) {
return std::make_unique<cuda::AotModuleBuilderImpl>();
return std::make_unique<cuda::AotModuleBuilderImpl>(this);
}
#endif

Expand Down Expand Up @@ -701,6 +686,33 @@ void LlvmProgramImpl::cache_kernel(
kernel_cache.offloaded_task_list = std::move(offloaded_task_list);
}

void LlvmProgramImpl::cache_field(int snode_tree_id,
int root_id,
const StructCompiler &struct_compiler) {
if (cache_data_.fields.find(snode_tree_id) != cache_data_.fields.end()) {
// [TODO] check and update the Cache, instead of simply return.
return;
}

LlvmOfflineCache::FieldCacheData ret;
ret.tree_id = snode_tree_id;
ret.root_id = root_id;
ret.root_size = struct_compiler.root_size;

const auto &snodes = struct_compiler.snodes;
for (size_t i = 0; i < snodes.size(); i++) {
LlvmOfflineCache::FieldCacheData::SNodeCacheData snode_cache_data;
snode_cache_data.id = snodes[i]->id;
snode_cache_data.type = snodes[i]->type;
snode_cache_data.cell_size_bytes = snodes[i]->cell_size_bytes;
snode_cache_data.chunk_size = snodes[i]->chunk_size;

ret.snode_metas.emplace_back(std::move(snode_cache_data));
}

cache_data_.fields[snode_tree_id] = std::move(ret);
}

void LlvmProgramImpl::dump_cache_data_to_disk() {
if (config->offline_cache && !cache_data_.kernels.empty()) {
LlvmOfflineCacheFileWriter writer{};
Expand Down
23 changes: 17 additions & 6 deletions taichi/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,23 +118,34 @@ class LlvmProgramImpl : public ProgramImpl {
std::vector<LlvmOfflineCache::OffloadedTaskCacheData>
&&offloaded_task_list);

void cache_field(int snode_tree_id,
int root_id,
const StructCompiler &struct_compiler);

LlvmOfflineCache::FieldCacheData get_cached_field(int snode_tree_id) const {
TI_ASSERT(cache_data_.fields.find(snode_tree_id) !=
cache_data_.fields.end());
return cache_data_.fields.at(snode_tree_id);
}

Device *get_compute_device() override {
return device_.get();
}

/**
* Initializes the SNodes for LLVM based backends.
*/
void initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer);

private:
std::unique_ptr<llvm::Module> clone_struct_compiler_initial_context(
bool has_multiple_snode_trees,
TaichiLLVMContext *tlctx);

std::unique_ptr<StructCompiler> compile_snode_tree_types_impl(
SNodeTree *tree);
/**
* Initializes the SNodes for LLVM based backends.
*/
void initialize_llvm_runtime_snodes(
const LlvmOfflineCache::FieldCacheData &field_cache_data,
uint64 *result_buffer);

uint64 fetch_result_uint64(int i, uint64 *result_buffer);

Expand Down

0 comments on commit fae94a2

Please sign in to comment.