diff --git a/apps/benchmark/adreno/adreno_gpu_bench_collage.py b/apps/benchmark/adreno/adreno_gpu_bench_collage.py new file mode 100755 index 000000000000..9f09253a84a5 --- /dev/null +++ b/apps/benchmark/adreno/adreno_gpu_bench_collage.py @@ -0,0 +1,384 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Compares Collage with various other baselines.""" +import argparse +import tvm +from tvm import relay +import logging +import os +import sys +import numpy as np +from tvm.relay import testing +from tvm.contrib.utils import tempdir +from tvm import rpc +from tvm.relay.build_module import bind_params_by_name +from tvm import autotvm +from tvm.runtime.vm import VirtualMachine +import tvm.contrib.graph_executor as runtime +from tvm.contrib import utils, ndk +from tvm.relay.collage.collage import * +from tvm.relay.op.contrib import clml + +logging.basicConfig(level=logging.INFO) + + +### +### How aggressively to look for candidates? +### +TVM_MAX_DEPTH = 8 +BYOC_MAX_DEPTH = 8 + +## +## Default config definition +## +HOST = tvm.target.Target("llvm -mtriple=arm64-linux-android") +OPENCL = tvm.target.Target("opencl -device=adreno", HOST) +NDK_CC = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") + + +def print_progress(msg): + """print progress message + + Parameters + ---------- + msg: str + The message to print + """ + sys.stdout.write(msg + "\r") + sys.stdout.flush() + + +def tune_tasks( + tasks, + measure_option, + tuner="xgb", + n_trial=1024, + early_stopping=None, + log_filename="tuning.log", +): + from tvm.autotvm.tuner import XGBTuner + + tmp_log_file = log_filename + ".tmp" + + for i, tsk in enumerate(reversed(tasks)): + print("Task: ", tsk) + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + + # create tuner + if tuner == "xgb": + tuner_obj = XGBTuner(tsk, loss_type="reg") + elif tuner == "xgb_knob": + tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="knob") + elif tuner == "xgb_itervar": + tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="itervar") + elif tuner == "xgb_curve": + tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="curve") + elif tuner == "xgb_rank": + tuner_obj = XGBTuner(tsk, loss_type="rank") + elif tuner == "xgb_rank_knob": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob") + elif tuner == "xgb_rank_itervar": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar") + elif tuner == "xgb_rank_curve": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve") + elif tuner == "xgb_rank_binary": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary") + elif tuner == "xgb_rank_binary_knob": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="knob") + elif tuner == "xgb_rank_binary_itervar": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="itervar") + elif tuner == "xgb_rank_binary_curve": + tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="curve") + elif tuner == "ga": + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == "random": + tuner_obj = RandomTuner(tsk) + elif tuner == "gridsearch": + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + autotvm.callback.log_to_file(tmp_log_file), + ], + ) + + autotvm.record.pick_best(tmp_log_file, log_filename) + + +########### Collage Drivers ########### + + +def compile_and_run(label, model, targets, inputs): + """Compile model for target and run it with profiling.""" + logging.info(f"Compiling {model['name']} using {label} with {targets}...") + mod = model["mod"] + exe = tvm.relay.vm.compile(mod, target=targets, params=model["params"]) + lib = exe.mod + temp = utils.tempdir() + dso_binary = "dev_lib_cl.so" + dso_binary_path = temp.relpath(dso_binary) + logging.info(f"Exporting library to {dso_binary_path}...") + lib.export_library(dso_binary_path, cc=NDK_CC) + tracker = rpc.connect_tracker(args.host, args.port) + remote = tracker.request(args.rpc_key, priority=0, session_timeout=600) + ctx = remote.cl(0) + remote.upload(dso_binary_path) + rlib = remote.load_module(dso_binary) + vm_factory = tvm.runtime.vm.VirtualMachine(rlib, ctx, "naive") + func_name = "main" + main_args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod[func_name].params} + profile = vm_factory.benchmark( + ctx, repeat=5, number=20, min_repeat_ms=0, func_name=func_name, **main_args + ) + return profile.mean + + +def collage(model, input_data, tune_log=""): + """Run the Collage partitioner for a set of Opencl Adreno related targets and profile the result""" + logging.info(f"collage | {model['name']}") + logging.info("-------------- BEGIN ORIGINAL --------------") + logging.info(model["mod"]) + logging.info("-------------- END ORIGINAL ----------------") + with autotvm.apply_history_best(tune_log): + targets = [] + targets.append(OPENCL) + use_fp16 = model["main_dtype"] == "float16" + targets.append(tvm.target.Target("clml", HOST)) + + # Register byoc fusion style for compiler with available + # options [compiler.NoFusion | compiler.TVMFusion | compiler.MaxDepthFusion] + config = { + "relay.collage.tvm_max_depth": TVM_MAX_DEPTH, + "relay.collage.byoc_max_depth": BYOC_MAX_DEPTH, + "relay.collage.byoc_fusion_style": ["clml.NoFusion"], + } + logging.info(f"Using PassContext(config={config}") + ctxt = tvm.transform.PassContext(config=config) + config = tvm.target.make_compilation_config(ctxt, targets) + with ctxt: + mod = model["mod"] + """Collage partition with tvm opencl and clml target on rpc device""" + mod = tvm.relay.transform.CollagePartition( + config, + cost_estimator=CostEstimator( + host=args.host, port=args.port, rpc_key=args.rpc_key, ndk_cc=NDK_CC + ), + )(mod) + partitioned_model = model.copy() + partitioned_model["mod"] = mod + logging.info("-------------- BEGIN PARTITIONED --------------") + logging.info(partitioned_model["mod"]) + logging.info("-------------- END PARTITIONED ----------------") + return compile_and_run("collage", partitioned_model, targets, input_data) + + +def just_clml(model, input_data, tune_log=""): + """Run partition_for_clml, complete the compilation with TVM, and profile the result.""" + logging.info(f"just_clml | {model['name']}") + logging.info("-------------- BEGIN ORIGINAL --------------") + logging.info(model["mod"]) + logging.info("-------------- END ORIGINAL ----------------") + with autotvm.apply_history_best(tune_log): + with tvm.transform.PassContext(opt_level=3): + logging.info("Partitioning for CLML...") + mod = tvm.relay.op.contrib.clml.partition_for_clml(model["mod"], model["params"]) + partitioned_model = model.copy() + partitioned_model["mod"] = mod + logging.info("-------------- BEGIN PARTITIONED --------------") + logging.info(partitioned_model["mod"]) + logging.info("-------------- END PARTITIONED ----------------") + targets = [] + targets.append(OPENCL) + targets.append(tvm.target.Target("clml", HOST)) + return compile_and_run("just_clml", partitioned_model, OPENCL, input_data) + + +def just_tvm(model, input_data, tune_log=""): + """Compile and profile using vanilla TVM.""" + logging.info(f"just_tvm | {model['name']}") + logging.info("-------------- BEGIN ORIGINAL --------------") + logging.info(model["mod"]) + logging.info("-------------- END ORIGINAL ----------------") + with autotvm.apply_history_best(tune_log): + with tvm.transform.PassContext(opt_level=3): + return compile_and_run("just_tvm", model, OPENCL, input_data) + + +def get_model(model_name, dtype): + + if "mobilenet" in model_name: + mod, params = testing.mobilenet.get_workload(batch_size=1, dtype=dtype) + elif "resnet" in model_name: + n_layer = int(model_name.split("-")[1]) + mod, params = testing.resnet.get_workload(num_layers=n_layer, batch_size=1, dtype=dtype) + elif model_name == "inception_v3": + input_shape = (1, 3, 299, 299) + mod, params = testing.inception_v3.get_workload(batch_size=1, dtype=dtype) + elif "vgg" in model_name: + n_layer = int(model_name.split("-")[1]) + mod, params = testing.vgg.get_workload(num_layers=n_layer, batch_size=1, dtype=dtype) + elif "densenet" in model_name: + n_layer = int(model_name.split("-")[1]) + mod, params = testing.densenet.get_workload( + densenet_size=n_layer, batch_size=1, dtype=dtype + ) + elif "squeezenet" in model_name: + version = model_name.split("_v")[1] + mod, params = testing.squeezenet.get_workload(batch_size=1, version=version, dtype=dtype) + + initializer = tvm.relay.testing.init.Xavier() + for param_name in list(params.keys()): + filter_data = np.zeros(params[param_name].shape).astype(params[param_name].dtype) + if len(filter_data.shape) > 1: + initializer("weight", filter_data) + else: + initializer("bias", filter_data) + params[param_name] = tvm.nd.array(filter_data) + + if params: + mod["main"] = bind_params_by_name(mod["main"], params) + mod = tvm.relay.transform.FoldConstant()(mod) + return { + "name": model_name, + "input_shapes": {"data": [1, 3, 224, 224]}, + "input_dtypes": {"data": dtype}, + "mod": mod, + "params": params, + "main_dtype": dtype, + } + + +########### Runners ########### +def evaluate_network(model_name, dtype): + print("Network evaluating .. " + model_name + " " + dtype) + np.random.seed(0) + model = get_model(model_name, dtype) + tune_log = "adreno_v0.01.log" + if args.tune: + # Auto Tuning + tune_log = "adreno-" + model_name + "-" + dtype + ".log" + tuning_options = { + "log_filename": tune_log, + "early_stopping": None, + "measure_option": autotvm.measure_option( + builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15), + runner=autotvm.RPCRunner( + args.rpc_key, + host=args.host, + port=args.port, + number=3, + timeout=600, + ), + ), + } + tasks = autotvm.task.extract_from_program( + net, target=OPENCL, target_host=HOST, params=params + ) + tune_tasks(tasks, **tuning_options) + + print_progress("%-20s building..." % network) + input_data = {} + for name, shape in model["input_shapes"].items(): + input_data[name] = np.random.uniform(-1.0, 1.0, shape).astype(model["input_dtypes"][name]) + clml_time = just_clml(model, input_data, tune_log) + tvm_time = just_tvm(model, input_data, tune_log) + + """Run Collage for tvm and clml compiler target.""" + collage_time = collage(model, input_data, tune_log) + return (tvm_time, clml_time, collage_time) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--network", + type=str, + choices=[ + "resnet-18", + "resnet-34", + "resnet-50", + "vgg-16", + "vgg-19", + "densenet-121", + "inception_v3", + "mobilenet", + "squeezenet_v1.0", + "squeezenet_v1.1", + ], + help="The name of neural network", + ) + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-key", type=str, default="android") + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16"], + help="The data type of neural network", + ) + parser.add_argument("--tune", type=bool, default=False) + args = parser.parse_args() + + if args.network is None: + networks = [ + "resnet-18", + "resnet-34", + "resnet-50", + # "vgg-16", + # "vgg-19", + "densenet-121", + "inception_v3", + "mobilenet", + "squeezenet_v1.0", + "squeezenet_v1.1", + ] + else: + networks = [args.network] + + if args.dtype is None: + dtypes = ["float32", "float16"] + else: + dtypes = [args.dtype] + + results = {} + net_results = [] + for network in networks: + for dtype in dtypes: + ftime = evaluate_network(network, dtype) + results[network + "-" + dtype] = ftime + # net_results.append([network + "-" + dtype] + list(ftime)) + # np.savetxt("results.txt", np.array(net_results), fmt="%s") + + print("----------------------------------------------------------------------") + print( + "%-30s %-20s %-20s %-20s" + % ("Network Name", "TVM Opencl Time", "CLML Time", "Collage - TVM/CLML Time") + ) + print("----------------------------------------------------------------------") + for key, val in results.items(): + print( + "%-30s %-20s %-20s %-20s" + % (key, "%.2f ms" % val[0], "%.2f ms" % val[1], "%.2f ms" % val[2]) + ) diff --git a/python/tvm/relay/collage/collage.py b/python/tvm/relay/collage/collage.py index cfc527c2b977..ab3b57887a38 100644 --- a/python/tvm/relay/collage/collage.py +++ b/python/tvm/relay/collage/collage.py @@ -28,19 +28,30 @@ import tvm from tvm._ffi.registry import register_func, register_object from tvm.runtime import Object +from tvm import rpc from . import _ffi_api # Parameters to use when estimating latency (of both partitions and overall models). MEASURE_NUMBER = 20 MEASURE_REPEAT = 5 -WARMUP_MIN_REPEAT_MS = 250 +WARMUP_MIN_REPEAT_MS = 10 @register_object("relay.collage.CostEstimator") class CostEstimator(Object): """CostEstimator class""" - def __init__(self): + TRACKER_HOST = None + TRACKER_PORT = None + DEVICE_KEY = None + NDK_CC = None + + def __init__(self, host=None, port=None, rpc_key="", ndk_cc="nvcc"): + """RPC device config settings""" + CostEstimator.TRACKER_HOST = host + CostEstimator.TRACKER_PORT = port + CostEstimator.DEVICE_KEY = rpc_key + CostEstimator.NDK_CC = ndk_cc self.__init_handle_by_constructor__(_ffi_api.CostEstimator) @@ -105,15 +116,22 @@ def estimate_seconds(mod, target): # Finalize compilation tmp_dir = tempfile.mkdtemp() - code, lib = exe.save() + lib = exe.mod lib_path = os.path.join(tmp_dir, "library.so") - # TODO(mbs): Avoid nvcc dependency? - lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc") - lib = tvm.runtime.load_module(lib_path) - exe = tvm.runtime.vm.Executable.load_exec(code, lib) + + lib.export_library(lib_path, workspace_dir=tmp_dir, cc=CostEstimator.NDK_CC) + if CostEstimator.TRACKER_PORT: + # Prepare the lib for RPC device + tracker = rpc.connect_tracker(CostEstimator.TRACKER_HOST, CostEstimator.TRACKER_PORT) + remote = tracker.request(CostEstimator.DEVICE_KEY, priority=0, session_timeout=600) + device = remote.cl(0) + remote.upload(lib_path, target="library.so") + lib = remote.load_module("library.so") + else: + lib = tvm.runtime.load_module(lib_path) # Benchmark the module. - the_vm = tvm.runtime.vm.VirtualMachine(exe, device) + the_vm = tvm.runtime.vm.VirtualMachine(lib, device) func_name = "main" main_args = {v.name_hint: arg_for(v.checked_type, device) for v in mod[func_name].params} logging.info("Benchmarking module to estimate") diff --git a/src/relay/collage/collage_partitioner.cc b/src/relay/collage/collage_partitioner.cc index 54fc6c45ca70..e0da77f91928 100644 --- a/src/relay/collage/collage_partitioner.cc +++ b/src/relay/collage/collage_partitioner.cc @@ -319,6 +319,9 @@ transform::Pass CollagePartition(CompilationConfig config, CostEstimator cost_es IRModule mod, transform::PassContext ctxt) { VLOG(1) << "CollagePartition input:" << std::endl << PrettyPrint(mod); + // Applying span indexes to graph + mod = transform::CapturePostDfsIndexInSpans()(mod); + Array partition_specs = GatherPartitionSpecs(config); VLOG(1) << "Gathered " << partition_specs.size() << " partition specs"; diff --git a/src/relay/collage/sub_graph.cc b/src/relay/collage/sub_graph.cc index a6559ff5fdb5..9d4ba48f5be1 100644 --- a/src/relay/collage/sub_graph.cc +++ b/src/relay/collage/sub_graph.cc @@ -130,6 +130,21 @@ class Extractor : public ExprMutator { // Sweep backwards through the body, rewriting to account for each nested sub-graph. body = NestedSubGraph::ParallelRewrite(body_dataflow_graph, body, std::move(nested_sub_graphs)); + // Invoke the compiler target preprocessing function define under "relay.ext.compiler.optimize" + if (opt_attrs_.defined() && (opt_attrs_.find("Compiler") != opt_attrs_.end())) { + DictAttrs opt_dict_attr = DictAttrs(opt_attrs_); + std::string spec_name = opt_dict_attr.GetAttr("Compiler", Optional()).value(); + std::string ext_opt = "relay.ext." + spec_name + ".optimize"; + auto pf = tvm::runtime::Registry::Get(ext_opt); + if (pf != nullptr) { + auto mod = IRModule::FromExpr(body); + mod = transform::InferType()(mod); + mod = (*pf)(mod); + mod = transform::InferType()(mod); + body = Downcast(mod->Lookup("main"))->body; + } + } + if (for_function) { // Rewrite so all input nodes are now conveyed via call arguments to a new function. Array arg_types; @@ -245,7 +260,7 @@ class Extractor : public ExprMutator { } else if (CanInline(expr)) { // Implicitly include inlinable input sub-expressions. return expr; - } else if (opt_attrs_.defined()) { + } else if (opt_attrs_.defined() && (expr.as() == nullptr)) { // Map to a function parameter. return VarFor(expr); } else { diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py index b8ce236cdda9..26fc112fa9ed 100644 --- a/tests/python/contrib/test_clml/infrastructure.py +++ b/tests/python/contrib/test_clml/infrastructure.py @@ -48,8 +48,7 @@ from tvm.runtime.vm import VirtualMachine import json - -NDK_CROSS_COMPILER = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") +from tvm.relay.collage.collage import * def get_cpu_op_count(mod): @@ -172,7 +171,8 @@ def build_and_run_vm( dso_binary = "dev_lib_cl.so" dso_binary_path = temp.relpath(dso_binary) dev = remote.cl(0) - vmc.mod.export_library(dso_binary_path, cc=NDK_CROSS_COMPILER) + ndk_cc = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") + vmc.mod.export_library(dso_binary_path, cc=ndk_cc) remote.upload(dso_binary_path) rlib = remote.load_module(dso_binary) vm = VirtualMachine(rlib, dev, "naive") @@ -248,3 +248,33 @@ def verify_codegen( f"Actual={codegen_str} \n" f"Expected={known_good_codegen_str}" ) + + +########### Collage Drivers ########### + + +def compile_and_run(remote, label, model, targets, inputs): + """Compile model for target and run it with profiling.""" + logging.info(f"Compiling {model['name']} using {label} with {targets}...") + mod = model["mod"] + exe = tvm.relay.vm.compile(mod, target=targets, params=model["params"]) + lib = exe.mod + temp = utils.tempdir() + dso_binary = "dev_lib_cl.so" + dso_binary_path = temp.relpath(dso_binary) + logging.info(f"Exporting library to {dso_binary_path}...") + ndk_cc = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") + lib.export_library(dso_binary_path, cc=ndk_cc) + ctx = remote.cl(0) + remote.upload(dso_binary_path) + rlib = remote.load_module(dso_binary) + vm_factory = tvm.runtime.vm.VirtualMachine(rlib, ctx, "naive") + inputs_data = {} + for key in inputs.keys(): + inputs_data[key] = tvm.nd.array(inputs[key], ctx) + for k, v in model["params"].items(): + inputs_data[k] = tvm.nd.array(v, ctx) + vm_factory.set_input("main", **inputs_data) + vm_factory.invoke_stateful("main") + out = vm_factory.get_outputs()[0] + return out.asnumpy() diff --git a/tests/python/contrib/test_clml/test_adreno_collage_targets.py b/tests/python/contrib/test_clml/test_adreno_collage_targets.py index 4cf86a0e058d..9316b0085b73 100644 --- a/tests/python/contrib/test_clml/test_adreno_collage_targets.py +++ b/tests/python/contrib/test_clml/test_adreno_collage_targets.py @@ -19,40 +19,18 @@ import tvm import logging -import tempfile import os -import shutil import numpy as np from tvm.relay import testing -from tvm import rpc -from tvm.contrib import utils, ndk from tvm.relay.build_module import bind_params_by_name - -# The following are necessary to force global functions or pattern tables to be registered +from tvm import autotvm from tvm.relay.collage.collage import * -from tvm.relay.op.contrib import clml +from test_clml.infrastructure import compile_and_run import pytest logging.basicConfig(level=logging.INFO) -########### Configuration ########### - -### -### TVM Opencl AutoTvm log file name -### -TUNING_LOG = "" - -### -### If true, run all models -### -ALL_MODELS = False - -### -### If true, run all configurations -### -ALL_CONFIGS = False - ### ### How aggressively to look for candidates? ### @@ -60,194 +38,38 @@ BYOC_MAX_DEPTH = 8 ### -### AutoTVM tuning parameters. +### TVM Opencl AutoTvm log file name ### -AUTOTVM_NUM_TRIALS = 1024 -AUTOTVM_EARLY_STOPPING = 600 -TIMEOUT = 10 -MEASURE_NUMBER = tvm.relay.collage.MEASURE_NUMBER -MEASURE_REPEAT = tvm.relay.collage.MEASURE_REPEAT -WARMUP_MIN_REPEAT_MS = tvm.relay.collage.WARMUP_MIN_REPEAT_MS +TUNING_LOG = "" ## -## RPC Build configuration +## Default Target definition ## HOST = tvm.target.Target("llvm -mtriple=arm64-linux-android") -OPENCL = tvm.target.Target("opencl", HOST) +OPENCL = tvm.target.Target("opencl -device=adreno", HOST) RPC_TRACKER_HOST = os.getenv("TVM_TRACKER_HOST", "localhost") RPC_TRACKER_PORT = int(os.getenv("TVM_TRACKER_PORT", 9090)) RPC_KEY = os.getenv("RPC_DEVICE_KEY", "android") -NDK_CROSS_COMPILER = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") - - -########### AutoTVM tuning helpers ########### - - -def extract_autotvm_tasks(mod, target): - """Returns TVM kernels to tune for mod and target.""" - return tvm.autotvm.task.extract_from_program(mod, target=target, params=None) - - -def optional_tuning_records(log_filename): - """Returns existing tuning records, if any.""" - if log_filename == "" or not os.path.exists(log_filename): - return tvm.autotvm.task.FallbackContext() - else: - return tvm.autotvm.task.ApplyHistoryBest(log_filename) - - -def is_already_tuned(task, log_filename): - """Returns True if we already have a tuning record for task in turning logs in log_filename""" - if not os.path.exists(log_filename): - return False - - dispatch_context = tvm.autotvm.task.ApplyHistoryBest(log_filename) - return dispatch_context._query_inside(task.target, task.workload) - - -def tune_autotvm_tasks(tasks, log_filename): - """Appends to log filename the best strategies for tasks""" - if len(tasks) == 0: - return +NDK_CC = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++") - measure_option = tvm.autotvm.measure_option( - builder=tvm.autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15), - runner=tvm.autotvm.RPCRunner( - RPC_KEY, host=RPC_TRACKER_HOST, port=RPC_TRACKER_PORT, number=100, timeout=15 - ), - ) - logging.info( - f"Using autotvm tuning for {len(tasks)} tasks with {AUTOTVM_NUM_TRIALS} trials, logging to {log_filename}" - ) - - # create tmp log file, starting with contents from existing log file - tmp_log_filename = log_filename + ".tmp" - if os.path.exists(tmp_log_filename): - os.remove(tmp_log_filename) - if os.path.exists(log_filename): - logging.info(f"Copying existing log {log_filename} to {tmp_log_filename}") - shutil.copy(log_filename, tmp_log_filename) - - for i, task in enumerate(reversed(tasks)): - prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) - logging.info(f"Considering task {task.name} {prefix}") - if is_already_tuned(task, tmp_log_filename): - logging.info(f"Re-using existing record for {task.name}") - continue - - logging.info(f"Using autotvm to tune {task.name}") - tuner_obj = tvm.autotvm.tuner.XGBTuner(task, loss_type="reg") - if os.path.exists(tmp_log_filename): - tuner_obj.load_history(tvm.autotvm.record.load_from_file(tmp_log_filename)) - - # do tuning - n_trial = min(AUTOTVM_NUM_TRIALS, len(task.config_space)) - tuner_obj.tune( - n_trial=n_trial, - early_stopping=AUTOTVM_EARLY_STOPPING, - measure_option=measure_option, - callbacks=[ - tvm.autotvm.callback.progress_bar(n_trial, prefix=prefix), - tvm.autotvm.callback.log_to_file(tmp_log_filename), - ], - ) - - # Pick best records and copy back to main log file - tvm.autotvm.record.pick_best(tmp_log_filename, log_filename) - os.remove(tmp_log_filename) - - logging.info("Done with autotvm tuning") - - -def autotvm_tune_module(mod, target, log_filename): - if log_filename == "": - logging.info("Not tuning with autotvm since disabled") - return - # Extract and tune any TVM kernels. BYOC partitions will have no tasks extracted. - logging.info("Extracting tasks from overall module") - tasks = extract_autotvm_tasks(mod, target) - logging.info(f"Auto-tuning {len(tasks)} tasks from overall module") - tune_autotvm_tasks(tasks, log_filename) - - -########### Drivers ########### - - -def compile_and_benchmark(label, model, targets, tmp_dir): - """Compile model for target and run it with profiling.""" - logging.info(f"Compiling {model['name']} using {label} with {targets}...") - mod = model["mod"] - mod = clml.preprocess_for_clml(mod) - exe = tvm.relay.vm.compile(mod, target=targets, params=model["params"]) - lib = exe.mod - lib_path = os.path.join(tmp_dir, "lib.so") - logging.info(f"Exporting library to {lib_path}...") - lib.export_library(lib_path, cc=NDK_CROSS_COMPILER) +def get_rpc_remote(): + """Create remote rpc tracker and connect to available remote device""" tracker = rpc.connect_tracker(RPC_TRACKER_HOST, RPC_TRACKER_PORT) remote = tracker.request(RPC_KEY, priority=0, session_timeout=600) - ctx = remote.cl(0) - remote_path = "lib.so" - remote.upload(lib_path, target=remote_path) - lib = remote.load_module(remote_path) - vm_factory = tvm.runtime.vm.VirtualMachine(lib, ctx) - args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod["main"].params} - logging.info(f"Benchmarking for {model['name']} generated by {label}...") - profile = vm_factory.benchmark( - ctx, repeat=MEASURE_REPEAT, number=MEASURE_NUMBER, min_repeat_ms=0, **args - ) - logging.info(f"Benchmarked for {model['name']} generated by {label}: {profile}") - logging.info(f"RESULT: {label} | {model['name']} | {profile.median * 1e3}ms") - + return remote -# Custom cost function for Opencl RPC targets. -@register_func("tvm.relay.collage.opencl_cost_estimator") -def opencl_cost_estimator(mod, target): - mod = clml.preprocess_for_clml(mod) if "clml" == target.kind.name else mod - try: - # Build the module. - logging.info("Compiling module to estimate") - exe = tvm.relay.vm.compile(mod, target) - except RuntimeError as err: - # A build failure indicates the partition is not supported. - # eg trying to build an nn.batch_norm on GPU, which has no schedule since we assume it - # is only ever used with a tuple projection which is rewritten away. - logging.info("Assigning module infinite cost since unable to build: %s", err) - return math.inf - lib = exe.mod - tracker = rpc.connect_tracker(RPC_TRACKER_HOST, RPC_TRACKER_PORT) - remote = tracker.request(RPC_KEY, priority=0, session_timeout=600) - temp = utils.tempdir() - dso_binary = "dev_lib_cl.so" - dso_binary_path = temp.relpath(dso_binary) - ctx = remote.cl(0) - lib.export_library(dso_binary_path, cc=NDK_CROSS_COMPILER) - remote_path = dso_binary - remote.upload(dso_binary_path, target=remote_path) - lib = remote.load_module(remote_path) - - vm_factory = tvm.runtime.vm.VirtualMachine(lib, ctx) - func_name = "main" - main_args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod[func_name].params} - cost = vm_factory.benchmark( - ctx, repeat=5, number=20, min_repeat_ms=0, func_name=func_name, **main_args - ) - return cost.mean - - -def collage(model): +def collage(model, input_data): """Run the Collage partitioner for a set of Opencl Adreno related targets and profile the result""" logging.info(f"collage | {model['name']}") logging.info("-------------- BEGIN ORIGINAL --------------") logging.info(model["mod"]) logging.info("-------------- END ORIGINAL ----------------") - autotvm_tune_module(model["mod"], OPENCL, TUNING_LOG) - with optional_tuning_records(TUNING_LOG): + with autotvm.apply_history_best(TUNING_LOG): targets = [] targets.append(OPENCL) use_fp16 = model["main_dtype"] == "float16" - tmp_dir = tempfile.mkdtemp() targets.append(tvm.target.Target("clml", HOST)) # Register byoc fusion style for compiler with available @@ -262,32 +84,30 @@ def collage(model): config = tvm.target.make_compilation_config(ctxt, targets) with ctxt: mod = model["mod"] - mod = tvm.relay.transform.CapturePostDfsIndexInSpans()(mod) - logging.info("-------------- BEGIN INDEXED --------------") - logging.info(mod) - logging.info("-------------- END INDEXED ----------------") - # Register python custom cost function for targets in - # custom cost estimator module. - cost_estimator = CustomCostEstimator( - py_fn_estimator="tvm.relay.collage.opencl_cost_estimator" - ) - mod = tvm.relay.transform.CollagePartition(config, cost_estimator=cost_estimator)(mod) + """Collage partition with tvm opencl and clml target on rpc device""" + mod = tvm.relay.transform.CollagePartition( + config, + cost_estimator=CostEstimator( + host=RPC_TRACKER_HOST, port=RPC_TRACKER_PORT, rpc_key=RPC_KEY, ndk_cc=NDK_CC + ), + )(mod) partitioned_model = model.copy() partitioned_model["mod"] = mod logging.info("-------------- BEGIN PARTITIONED --------------") logging.info(partitioned_model["mod"]) logging.info("-------------- END PARTITIONED ----------------") - compile_and_benchmark("collage", partitioned_model, targets, tmp_dir) + return compile_and_run( + get_rpc_remote(), "collage", partitioned_model, targets, input_data + ) -def just_clml(model): +def just_clml(model, input_data): """Run partition_for_clml, complete the compilation with TVM, and profile the result.""" logging.info(f"just_clml | {model['name']}") logging.info("-------------- BEGIN ORIGINAL --------------") logging.info(model["mod"]) logging.info("-------------- END ORIGINAL ----------------") - tmp_dir = tempfile.mkdtemp() - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): + with tvm.transform.PassContext(opt_level=3): logging.info("Partitioning for CLML...") mod = tvm.relay.op.contrib.clml.partition_for_clml(model["mod"], model["params"]) partitioned_model = model.copy() @@ -298,20 +118,18 @@ def just_clml(model): targets = [] targets.append(OPENCL) targets.append(tvm.target.Target("clml", HOST)) - compile_and_benchmark("just_clml", partitioned_model, targets, tmp_dir) + return compile_and_run(get_rpc_remote(), "just_clml", partitioned_model, OPENCL, input_data) -def just_tvm(model): +def just_tvm(model, input_data): """Compile and profile using vanilla TVM.""" logging.info(f"just_tvm | {model['name']}") logging.info("-------------- BEGIN ORIGINAL --------------") logging.info(model["mod"]) logging.info("-------------- END ORIGINAL ----------------") - tmp_dir = tempfile.mkdtemp() - autotvm_tune_module(model["mod"], OPENCL, TUNING_LOG) - with optional_tuning_records(TUNING_LOG): - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - compile_and_benchmark("just_tvm", model, OPENCL, tmp_dir) + with autotvm.apply_history_best(TUNING_LOG): + with tvm.transform.PassContext(opt_level=3): + return compile_and_run(get_rpc_remote(), "just_tvm", model, OPENCL, input_data) def get_model(model_name, dtype): @@ -320,6 +138,16 @@ def get_model(model_name, dtype): mod, params = testing.mobilenet.get_workload(batch_size=1, dtype=dtype) elif "resnet" in model_name: mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype=dtype) + + initializer = tvm.relay.testing.init.Xavier() + for param_name in list(params.keys()): + filter_data = np.zeros(params[param_name].shape).astype(params[param_name].dtype) + if len(filter_data.shape) > 1: + initializer("weight", filter_data) + else: + initializer("bias", filter_data) + params[param_name] = tvm.nd.array(filter_data) + if params: mod["main"] = bind_params_by_name(mod["main"], params) mod = tvm.relay.transform.FoldConstant()(mod) @@ -334,21 +162,32 @@ def get_model(model_name, dtype): ########### Runners ########### -@pytest.mark.parametrize("dtype", ["float32"]) +@pytest.mark.parametrize("dtype", ["float32", "float16"]) +@pytest.mark.parametrize("model_name", ["mobilenet", "resnet-50"]) @tvm.testing.requires_openclml -def run_resnet50(dtype): +def test_network_collage(model_name, dtype): + print("Network evaluating .. " + model_name + " " + dtype) + np.random.seed(0) + model = get_model(model_name, dtype) + input_data = {} + for name, shape in model["input_shapes"].items(): + input_data[name] = np.random.uniform(-1.0, 1.0, shape).astype(model["input_dtypes"][name]) + + clml_out = just_clml(model, input_data) + tvm_out = just_tvm(model, input_data) + """Check tvm and clml output correctness.""" + tvm_sort = np.argsort(tvm_out).flatten() + clml_sort = np.argsort(clml_out).flatten() + tvm.testing.assert_allclose(tvm_sort[-5:], clml_sort[-5:], rtol=0, atol=0) + logging.info("-------- TVM and CLML execution test passed ---------") - just_clml(get_model("resnet-50", dtype)) - just_tvm(get_model("resnet-50", dtype)) """Run Collage for tvm and clml compiler target.""" - collage(get_model("resnet-50", dtype)) + collage_out = collage(model, input_data) + collage_sort = np.argsort(collage_out).flatten() + """Check tvm and collage(tvm+clml) output correctness.""" + tvm.testing.assert_allclose(tvm_sort[-5:], collage_sort[-5:], rtol=0, atol=0) + logging.info("-------- Collage execution test passed ---------") -@pytest.mark.parametrize("dtype", ["float32"]) -@tvm.testing.requires_openclml -def run_mobilenetv1(dtype): - - just_clml(get_model("mobilenet", dtype)) - just_tvm(get_model("mobilenet", dtype)) - """Run Collage for tvm and clml compiler target.""" - collage(get_model("mobilenet", dtype)) +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relay/collage/demo_collage_partitioner.py b/tests/python/relay/collage/demo_collage_partitioner.py index 0b7c815a8806..1d2d9a095b73 100644 --- a/tests/python/relay/collage/demo_collage_partitioner.py +++ b/tests/python/relay/collage/demo_collage_partitioner.py @@ -276,10 +276,6 @@ def collage(model): config = tvm.target.make_compilation_config(ctxt, targets) with ctxt: mod = model["mod"] - mod = tvm.relay.transform.CapturePostDfsIndexInSpans()(mod) - logging.info("-------------- BEGIN INDEXED --------------") - logging.info(mod) - logging.info("-------------- END INDEXED ----------------") mod = tvm.relay.transform.CollagePartition(config)(mod) partitioned_model = model.copy() partitioned_model["mod"] = mod diff --git a/tests/python/relay/test_pass_collage_partition.py b/tests/python/relay/test_pass_collage_partition.py index f40631628ea5..20fc4e421d13 100644 --- a/tests/python/relay/test_pass_collage_partition.py +++ b/tests/python/relay/test_pass_collage_partition.py @@ -67,8 +67,6 @@ def run_collage( with pass_ctxt: config = make_compilation_config(pass_ctxt, targets) actual_mod = InferType()(input_mod) - # Capture indexes only to help debug failing tests - actual_mod = CapturePostDfsIndexInSpans()(actual_mod) actual_mod = CollagePartition(config, cost_estimator)(actual_mod) if not tvm.ir.structural_equal(actual_mod, expected_mod, map_free_vars=True):