Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix linalg dialect bindings are incomplete #54 #55

Merged
merged 6 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 75 additions & 55 deletions bindings/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,85 +7,90 @@ import BinaryBuilderBase: PkgSpec, Prefix, temp_prefix, setup_dependencies, clea
using Clang.Generators

function mlir_dialects(version::VersionNumber)
dialects = Tuple{String,String}[
("Builtin.jl", "../IR/BuiltinOps.td"),
# construct set of dialects to generate bindings for
# 1. dialect name
# 2. bindings file name
# 3. tablegen files
dialects = Tuple{String,String,Vector{String}}[
("builtin", "Builtin.jl", ["../IR/BuiltinOps.td"]),
]

if version >= v"14"
append!(dialects, [
("AMX.jl", "AMX/AMX.td"),
("Affine.jl", "Affine/IR/AffineOps.td"),
("ArmNeon.jl", "ArmNeon/ArmNeon.td"),
("ArmSVE.jl", "ArmSVE/ArmSVE.td"),
("Async.jl", "Async/IR/AsyncOps.td"),
("Bufferization.jl", "Bufferization/IR/BufferizationOps.td"),
("Complex.jl", "Complex/IR/ComplexOps.td"),
# ("DLTI.jl", "DLTI/DLTI.td"), fails on v15
("EmitC.jl", "EmitC/IR/EmitC.td"),
("LLVMIR.jl", "LLVMIR/LLVMOps.td"),
("Linalg.jl", "Linalg/IR/LinalgOps.td"), # TODO include LinalgStructuredOps.td
("Math.jl", "Math/IR/MathOps.td"),
("MemRef.jl", "MemRef/IR/MemRefOps.td"),
("OpenACC.jl", "OpenACC/OpenACCOps.td"),
("OpenMP.jl", "OpenMP/OpenMPOps.td"),
("PDL.jl", "PDL/IR/PDLOps.td"),
("PDLInterp.jl", "PDLInterp/IR/PDLInterpOps.td"),
("Quant.jl", "Quant/QuantOps.td"),
("SPIRV.jl", "SPIRV/IR/SPIRVOps.td"),
("Shape.jl", "Shape/IR/ShapeOps.td"),
("SparseTensor.jl", "SparseTensor/IR/SparseTensorOps.td"),
("Tensor.jl", "Tensor/IR/TensorOps.td"),
("Tosa.jl", "Tosa/IR/TosaOps.td"),
("Vector.jl", "Vector/IR/VectorOps.td"),
("X86Vector.jl", "X86Vector/X86Vector.td"),
("amx", "AMX.jl", ["AMX/AMX.td"]),
("affine", "Affine.jl", ["Affine/IR/AffineOps.td"]),
("arm_neon", "ArmNeon.jl", ["ArmNeon/ArmNeon.td"]),
("arm_sve", "ArmSVE.jl", ["ArmSVE/ArmSVE.td"]),
("async", "Async.jl", ["Async/IR/AsyncOps.td"]),
("bufferization", "Bufferization.jl", ["Bufferization/IR/BufferizationOps.td"]),
("complex", "Complex.jl", ["Complex/IR/ComplexOps.td"]),
# ("dlti", "DLTI.jl"[, "DLTI/DLTI.td"]), fails on v15
("emitc", "EmitC.jl", ["EmitC/IR/EmitC.td"]),
("llvm", "LLVMIR.jl", ["LLVMIR/LLVMOps.td"]),
("linalg", "Linalg.jl", ["Linalg/IR/LinalgOps.td", "Linalg/IR/LinalgStructuredOps.td"]),
("math", "Math.jl", ["Math/IR/MathOps.td"]),
("memref", "MemRef.jl", ["MemRef/IR/MemRefOps.td"]),
("acc", "OpenACC.jl", ["OpenACC/OpenACCOps.td"]),
("omp", "OpenMP.jl", ["OpenMP/OpenMPOps.td"]),
("pdl", "PDL.jl", ["PDL/IR/PDLOps.td"]),
("pdl_interp", "PDLInterp.jl", ["PDLInterp/IR/PDLInterpOps.td"]),
("quant", "Quant.jl", ["Quant/QuantOps.td"]),
("shape", "Shape.jl", ["Shape/IR/ShapeOps.td"]),
("sparse_tensor", "SparseTensor.jl", ["SparseTensor/IR/SparseTensorOps.td"]),
("tensor", "Tensor.jl", ["Tensor/IR/TensorOps.td"]),
("tosa", "Tosa.jl", ["Tosa/IR/TosaOps.td"]),
("vector", "Vector.jl", ["Vector/IR/VectorOps.td"]),
("x86vector", "X86Vector.jl", ["X86Vector/X86Vector.td"]),
])
end

if v"14" <= version < v"15"
append!(dialects, [
("GPU.jl", "GPU/GPUOps.td"),
("SCF.jl", "SCF/SCFOps.td"),
("StandardOps.jl", "StandardOps/IR/Ops.td"),
("gpu", "GPU.jl", ["GPU/GPUOps.td"]), # moved to IR subfolder in v15
("scf", "SCF.jl", ["SCF/SCFOps.td"]), # moved to IR subfolder in v15
("std", "StandardOps.jl", ["StandardOps/IR/Ops.td"]),
])
end

if v"14" <= version < v"16"
append!(dialects, [
("Arithmetic.jl", "Arithmetic/IR/ArithmeticOps.td"), # renamed to 'Arith' in v16
("arith", "Arithmetic.jl", ["Arithmetic/IR/ArithmeticOps.td"]), # folder renamed to 'Arith' in v16
("spv", "SPIRV.jl", ["SPIRV/IR/SPIRVOps.td"]), # dialect name renamed to 'spirv' in v16
])
end

if version >= v"15"
append!(dialects, [
("GPU.jl", "GPU/IR/GPUOps.td"),
("SCF.jl", "SCF/IR/SCFOps.td"),
("AMDGPU.jl", "AMDGPU/AMDGPU.td"),
("ControlFlow.jl", "ControlFlow/IR/ControlFlowOps.td"),
("Func.jl", "Func/IR/FuncOps.td"),
("MLProgram.jl", "MLProgram/IR/MLProgramOps.td"),
("NVGPU.jl", "NVGPU/IR/NVGPU.td"),
("Transform.jl", "Transform/IR/TransformOps.td"),
("gpu", "GPU.jl", ["GPU/IR/GPUOps.td"]),
("scf", "SCF.jl", ["SCF/IR/SCFOps.td"]),
("amdgpu", "AMDGPU.jl", ["AMDGPU/AMDGPU.td"]),
("cf", "ControlFlow.jl", ["ControlFlow/IR/ControlFlowOps.td"]),
("func", "Func.jl", ["Func/IR/FuncOps.td"]),
("ml_program", "MLProgram.jl", ["MLProgram/IR/MLProgramOps.td"]),
("nvgpu", "NVGPU.jl", ["NVGPU/IR/NVGPU.td"]),
("transform", "Transform.jl", ["Transform/IR/TransformOps.td"]),
])
end

if version >= v"16"
append!(dialects, [
("Arith.jl", "Arith/IR/ArithOps.td"),
("Index.jl", "Index/IR/IndexOps.td"),
("arith", "Arith.jl", ["Arith/IR/ArithOps.td"]),
("index", "Index.jl", ["Index/IR/IndexOps.td"]),
("spirv", "SPIRV.jl", ["SPIRV/IR/SPIRVOps.td"]),
])
end

if version >= v"17"
append!(dialects, [
("ArmSME.jl", "ArmSME/IR/ArmSME.td"),
("IRDL.jl", "IRDL/IR/IRDLOps.td"),
("UB.jl", "UB/IR/UBOps.td"),
("arm_sme", "ArmSME.jl", ["ArmSME/IR/ArmSME.td"]),
("irdl", "IRDL.jl", ["IRDL/IR/IRDLOps.td"]),
("ub", "UB.jl", ["UB/IR/UBOps.td"]),
])
end

if version >= v"18"
append!(dialects, [
("Mesh.jl", "Mesh/IR/MeshOps.td"),
("mesh", "Mesh.jl", ["Mesh/IR/MeshOps.td"]),
])
end

Expand All @@ -112,7 +117,7 @@ for (julia_version, llvm_version) in julia_llvm
dependencies = PkgSpec[
PkgSpec(; name="LLVM_full_jll", version=llvm_version),
PkgSpec(; name="mlir_jl_tblgen_jll")
]
]

artifact_paths = setup_dependencies(prefix, dependencies, platform; verbose=true)

Expand Down Expand Up @@ -146,15 +151,30 @@ for (julia_version, llvm_version) in julia_llvm
# generate MLIR dialect bindings
mkpath(joinpath(@__DIR__, "..", "src", "Dialects", string(llvm_version.major)))

for (binding, td) in mlir_dialects(llvm_version)
flags = [
"--generator=jl-op-defs",
joinpath(include_dir, "mlir", "Dialect", td),
"-I", include_dir,
"-o", joinpath(@__DIR__, "..", "src", "Dialects", string(llvm_version.major), binding),
]
run(`$mlir_jl_tblgen $flags`)
println("- Generated \"$binding\" from \"$td\"")
for (dialect_name, binding, tds) in mlir_dialects(llvm_version)
tempfiles = map(tds) do td
tempfile, _ = mktemp()
flags = [
"--generator=jl-op-defs",
"--disable-module-wrap",
joinpath(include_dir, "mlir", "Dialect", td),
"-I", include_dir,
"-o", tempfile,
]
run(`$mlir_jl_tblgen $flags`)
return tempfile
end

output = joinpath(@__DIR__, "..", "src", "Dialects", string(llvm_version.major), binding)
open(output, "w") do io
println(io, "module $dialect_name")
for tempfile in tempfiles
println(io, read(tempfile, String))
end
println(io, "end")
end

println("- Generated \"$binding\" from $(join(tds, ",", " and "))")
end
end
end
26 changes: 24 additions & 2 deletions deps/tblgen/jl-generators.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ namespace

} // namespace

extern bool disableModuleWrap;

bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper,
llvm::raw_ostream &os)
{
Expand All @@ -162,7 +164,19 @@ bool emitOpTableDefs(const llvm::RecordKeeper &recordKeeper,
std::vector<llvm::Record *> opdefs = recordKeeper.getAllDerivedDefinitions("Op");
#endif

const char *moduleTemplate = R"(module {0}
const char *moduleTemplate;
if (disableModuleWrap)
{
moduleTemplate = R"(import ...IR: NamedAttribute, MLIRType, Value, Location, Block, Region, Attribute, create_operation, context, IndexType
import ..Dialects: namedattribute, operandsegmentsizes
import ...API

{0}
)";
}
else
{
moduleTemplate = R"(module {0}

import ...IR: NamedAttribute, MLIRType, Value, Location, Block, Region, Attribute, create_operation, context, IndexType
import ..Dialects: namedattribute, operandsegmentsizes
Expand All @@ -171,6 +185,7 @@ import ...API
{1}
end # {0}
)";
}

const char *functiontemplate = R"(
{3}
Expand Down Expand Up @@ -418,7 +433,14 @@ end
modulecontents += llvm::formatv(functiontemplate, functionname, arguments, functionbody, description);
}

os << llvm::formatv(moduleTemplate, modulename, modulecontents);
if (disableModuleWrap)
{
os << llvm::formatv(moduleTemplate, modulecontents);
}
else
{
os << llvm::formatv(moduleTemplate, modulename, modulecontents);
}

return false;
}
3 changes: 3 additions & 0 deletions deps/tblgen/mlir-jl-tblgen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ static std::array<GeneratorInfo, 1> generators {{
}};

generator_function* generator;
bool disableModuleWrap;

int main(int argc, char **argv) {
llvm::InitLLVM y(argc, argv);
llvm::cl::opt<std::string> generatorOpt("generator", llvm::cl::desc("Generator to run"), cl::Required);
llvm::cl::opt<bool> disableModuleWrapOpt("disable-module-wrap", llvm::cl::desc("Disable module wrap"), cl::init(false));
cl::ParseCommandLineOptions(argc, argv);
for (const auto& spec : generators) {
if (generatorOpt == spec.name) {
Expand All @@ -57,6 +59,7 @@ int main(int argc, char **argv) {
llvm::errs() << "Invalid generator type\n";
abort();
}
disableModuleWrap = disableModuleWrapOpt;

return TableGenMain(argv[0], [](raw_ostream& os, RecordKeeper &records) {
return generator(records, os);
Expand Down