diff --git a/benchmarks/tensorexpr/benchmark.py b/benchmarks/tensorexpr/benchmark.py
new file mode 100644
index 0000000000000..e466abb3a54b5
--- /dev/null
+++ b/benchmarks/tensorexpr/benchmark.py
@@ -0,0 +1,136 @@
+import argparse
+import itertools
+import framework
+import os
+import types
+import tensor_engine
+#import normalization
+import broadcast
+#import reduction
+import elementwise
+#import softmax
+#import pooling
+#import conv
+#import matmul
+
+
+def main():
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
+        description=
+'''Benchmark operators in specific shapes.
+Works only with Python3.\n A few examples:
+  * benchmark.py: runs all the default configs with all the benchmarks.
+  * benchmark.py reduce: runs all the default configs for every benchmark with the prefix 'reduce'.
+  * benchmark.py layernorm_fwd_cpu_128_32_128_128: runs a particular benchmark in that config.''')
+    parser.add_argument('benchmark_names', type=str, default=None, nargs='*',
+        help='name of the benchmark to run')
+    parser.add_argument('--device', type=str, default='cpu,cuda',
+        help='a comma-separated list of device names')
+    parser.add_argument('--mode', type=str, default='fwd,both',
+        help='a comma-separated list of running modes')
+    parser.add_argument('--engine', type=str, default='pt',
+        help='the underlying tensor engine; only pt for now')
+    parser.add_argument('--jit_mode', type=str, default='trace',
+        help='the jit mode to use: one of {trace, none}')
+    parser.add_argument('--cuda_pointwise_loop_levels', type=int, default=None,
+        help='number of loop levels for CUDA pointwise operations: 2 or 3')
+    parser.add_argument('--cuda_pointwise_block_count', type=int, default=None,
+        help='number of blocks for CUDA pointwise operations')
+    parser.add_argument('--cuda_pointwise_block_size', type=int, default=None,
+        help='number of threads per block for CUDA pointwise operations')
+    parser.add_argument('--cuda_fuser', type=str, default='te',
+        help='the CUDA fuser backend to use: one of {te, old, none}')
+
+    args = parser.parse_args()
+
+    def set_global_threads(num_threads):
+        os.environ['OMP_NUM_THREADS'] = str(num_threads)
+        os.environ['MKL_NUM_THREADS'] = str(num_threads)
+        os.environ['TVM_NUM_THREADS'] = str(num_threads)
+        os.environ['NNC_NUM_THREADS'] = str(num_threads)
+
+    devices = args.device.split(',')
+    # accept 'gpu' as an alias for the 'cuda' device
+    devices = ['cuda' if device == 'gpu' else device for device in devices]
+    cpu_count = 0
+    for index, device in enumerate(devices):
+        if device.startswith('cpu'):
+            cpu_count += 1
+            if cpu_count > 1:
+                raise ValueError('more than one CPU device is not allowed: %d' % (cpu_count))
+            if device == 'cpu':
+                continue
+            num_threads_str = device[3:]
+            try:
+                # see if the device is in 'cpu1' or 'cpu4' format
+                num_threads = int(num_threads_str)
+                set_global_threads(num_threads)
+                devices[index] = 'cpu'
+            except ValueError:
+                continue
+
+    modes = args.mode.split(',')
+
+    tensor_engine.set_engine_mode(args.engine)
+
+    def run_default_configs(bench_cls, allow_skip=True):
+        for mode, device, config in itertools.product(modes, devices, bench_cls.default_configs()):
+            benchmark = bench_cls(mode, device, *config)
+            benchmark.jit_mode = args.jit_mode
+            if not benchmark.is_supported():
+                if allow_skip:
+                    continue
+                else:
+                    raise ValueError('attempted to run an unsupported benchmark: %s' % (benchmark.desc()))
+            framework.run_benchmark(benchmark, args)
+
+    benchmark_classes = framework.benchmark_classes
+    if not args.benchmark_names:
+        # by default,
run all the benchmarks + for benchmark_cls in benchmark_classes: + run_default_configs(benchmark_cls, allow_skip=True) + else: + for name in args.benchmark_names: + # if the name is the prefix of a benchmark class, run all the benchmarks for that class + match_class_name = False + for bench_cls in benchmark_classes: + if name in bench_cls.module(): + match_class_name = True + run_default_configs(bench_cls, allow_skip=True) + + if match_class_name: + continue + + # if not a class module, parse the config and call it that way + match_class_name = False + for bench_cls in benchmark_classes: + cls_module = bench_cls.module() + if name.startswith(cls_module): + match_class_name = True + if name[len(cls_module)] != '_': + raise ValueError('invalid name: %s' % (name)) + config_str = name[(len(cls_module) + 1):] + config = config_str.split('_') + if len(config) < 2: + raise ValueError('invalid config: %s' % config) + mode, device = config[0:2] + #TODO: make sure virtual devices such as 'cpu1' and 'cpu4' are supported. + if mode not in ['fwd', 'both']: + raise ValueError('invalid mode: %s' % (mode)) + for i, entry in enumerate(config): + try: + value = int(entry) + config[i] = value + except ValueError: + pass + benchmark = bench_cls(*config) + benchmark.jit_mode = args.jit_mode + framework.run_benchmark(benchmark, args) + + if not match_class_name: + available_classes = ', '.join([bench_cls.module() for bench_cls in benchmark_classes]) + raise ValueError('invalid name: %s\nAvailable benchmark classes:\n%s' % (name, available_classes)) + + +if __name__== '__main__': + main() diff --git a/benchmarks/tensorexpr/broadcast.py b/benchmarks/tensorexpr/broadcast.py new file mode 100644 index 0000000000000..4816524c6928f --- /dev/null +++ b/benchmarks/tensorexpr/broadcast.py @@ -0,0 +1,264 @@ +import framework +import itertools +import numpy as np +import torch + + +class BroadcastMulBench(framework.Benchmark): + def __init__(self, mode, device, case, M, N, K): + super().__init__(mode, device) + self.case = case + self.M = M + self.N = N + self.K = K + + if case == 'row': + self.d1 = self.rand([M, N, 1], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([M, 1, K], device=device, requires_grad=self.requires_grad) + elif case == 'mid': + self.d1 = self.rand([M, N, 1], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([1, N, K], device=device, requires_grad=self.requires_grad) + elif case == 'col': + self.d1 = self.rand([M, 1, K], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([1, N, K], device=device, requires_grad=self.requires_grad) + else: + raise ValueError('invalid case: %s' % (case)) + + self.inputs = [self.d1, self.d2] + + def forward(self, d1, d2): + y = d1 + d2 + return y + + def reference(self): + return self.numpy(self.d1) + self.numpy(self.d2) + + def config(self): + return [self.M, self.N, self.K] + + @staticmethod + def default_configs(): + return [[128, 256, 128]] + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + algorithmic_count = 1 + else: + sol_count = (1) + (1) + algorithmic_count = 1 + (1 + 1) + + buffer_size = self.M * self.N * self.K * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + +class BroadcastRowBench(BroadcastMulBench): + def __init__(self, mode, device, M, N, K): + super(BroadcastRowBench, self).__init__(mode, device, 'row', M, N, K) + + @staticmethod + def module(): + return 'broadcast_row' + + +class BroadcastMidBench(BroadcastMulBench): + 
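As a quick illustration of the name-parsing logic in benchmark.py above, here is a minimal, hypothetical sketch of how a fully qualified benchmark name is decomposed into a module prefix, mode, device, and shape config. The example name and values are illustrative only, not taken from this patch:

```python
# Hypothetical example of the accepted name format:
#   <module>_<mode>_<device>_<dim0>_<dim1>_...
name = "broadcast_row_fwd_cpu_128_256_128"      # illustrative name, not from the patch
cls_module = "broadcast_row"                    # what bench_cls.module() would return
config = name[len(cls_module) + 1:].split("_")  # ['fwd', 'cpu', '128', '256', '128']
mode, device = config[0:2]                      # 'fwd', 'cpu'
dims = [int(x) for x in config[2:]]             # [128, 256, 128] -> M, N, K
```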
def __init__(self, mode, device, M, N, K): + super(BroadcastMidBench, self).__init__(mode, device, 'mid', M, N, K) + + @staticmethod + def module(): + return 'broadcast_mid' + + +class BroadcastColBench(BroadcastMulBench): + def __init__(self, mode, device, M, N, K): + super(BroadcastColBench, self).__init__(mode, device, 'col', M, N, K) + + @staticmethod + def module(): + return 'broadcast_col' + + +class BroadcastThreeArgs(framework.Benchmark): + def __init__(self, mode, device, M, N, K, L): + super().__init__(mode, device) + self.M = M + self.N = N + self.K = K + self.L = L + + self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad) + self.d3 = self.rand([L, K, 1, 1], device=device, requires_grad=self.requires_grad) + + self.inputs = [self.d1, self.d2, self.d3] + + def forward(self, d1, d2, d3): + y = d1 + d2 + d3 + return y + + def reference(self): + return self.numpy(self.d1) + self.numpy(self.d2) + self.numpy(self.d3) + + def config(self): + return [self.M, self.N, self.K, self.L] + + @staticmethod + def default_configs(): + return [[32, 16, 64, 128]] + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + algorithmic_count = 1 + else: + sol_count = (1) + (1) + algorithmic_count = 1 + (1 + 1 + 1) + + buffer_size = self.M * self.N * self.K * self.L * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + @staticmethod + def module(): + return 'broadcast_3args' + + +#framework.register_benchmark_class(BroadcastRowBench) +#framework.register_benchmark_class(BroadcastMidBench) +#framework.register_benchmark_class(BroadcastColBench) +#framework.register_benchmark_class(BroadcastThreeArgs) + +# TODO: merge this with elementwise bench +# A template class for elementwise operations. +# A derived class will override the class instance to customize its behavior. +class BroadcastBench(framework.Benchmark): + # List of customization class variables. 
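For reference, a small NumPy shape check (not part of the patch) confirming how the three operands in BroadcastThreeArgs above broadcast together under its default config:

```python
import numpy as np

# Shapes follow BroadcastThreeArgs with its default config M, N, K, L = 32, 16, 64, 128.
M, N, K, L = 32, 16, 64, 128
d1 = np.zeros([M, N])
d2 = np.zeros([K, M, 1])
d3 = np.zeros([L, K, 1, 1])
print((d1 + d2 + d3).shape)  # (128, 64, 32, 16), i.e. [L, K, M, N]
```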
+ op_str = None + binary_op_pt_func = None + binary_op_np_func = None + unary_op_pt_func = None + unary_op_np_func = None + split_input = True + def __init__(self, mode, device, M, N, K): + super().__init__(mode, device) + self.M = M + self.N = N + self.K = K + self.d1 = self.rand([M, N], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([K, 1, N], device=device, requires_grad=self.requires_grad) + self.d3 = self.rand([M, N], device=device, requires_grad=self.requires_grad) + self.d4 = self.rand([K, M, 1], device=device, requires_grad=self.requires_grad) + self.inputs = [self.d1, self.d2, self.d3, self.d4] + + def _eval(self, d1, d2, d3, d4, binary_op, unary_op): + if not binary_op: + binary_op = lambda x, y: x + y + if not unary_op: + unary_op = lambda x: x + if self.split_input: + d1 = unary_op(d1) + d2 = unary_op(d2) + d3 = unary_op(d3) + d4 = unary_op(d4) + else: + d1, d2, d3, d4 = unary_op(d1), unary_op(d2), unary_op(d1 + 0.001), unary_op(d4) + a = binary_op(d1, d2) + b = binary_op(d3, d4) + c = a + b + return c + + def forward(self, d1, d2, d3, d4): + binary_op = self.__class__.binary_op_pt_func + unary_op = self.__class__.unary_op_pt_func + return self._eval(d1, d2, d3, d4, binary_op, unary_op) + + def reference(self): + binary_op = self.__class__.binary_op_np_func + unary_op = self.__class__.unary_op_np_func + [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]] + return self._eval(d1, d2, d3, d4, binary_op, unary_op) + + def config(self): + return [self.M, self.N, self.K] + + @classmethod + def module(cls): + return 'broadcast_' + cls.op_str + + def memory_workload(self): + input_count = len(self.inputs) + if self.mode == 'fwd': + if self.split_input: + sol_count = 1 + algorithmic_count = 1 + else: + sol_count = 1 + algorithmic_count = 1 + else: + if self.split_input: + sol_count = 1 + algorithmic_count = input_count + else: + sol_count = 1 + algorithmic_count = input_count + + buffer_size = self.M * self.N * self.K * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + @staticmethod + def default_configs(): + return [[1 << 8, 1 << 7, 1 << 9]] + + +def register_broadcast_ops(): + binary_op_list = [ + ["mul", lambda a, b: a * b], + ["add", lambda a, b: a + b], + ["sub", lambda a, b: a - b], + ["div", lambda a, b: a / (b + 1e-4)], + ["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)], # no fuson triggered + ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)], + ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)], + ] + + unary_op_list = [ + ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)], + ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)], + ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)], + ] + + for split_input, binary_op in itertools.product([True, False], binary_op_list): + # Make a copy of BroadcastBench + if len(binary_op) == 2: + [op_str, op_pt_func] = binary_op + op_np_func = op_pt_func + elif len(binary_op) == 3: + [op_str, op_pt_func, op_np_func] = binary_op + split_str = 'split' if split_input else 'shared' + op_str = split_str + '_' + op_str + bm_cls = type('BroadcastBench_' + op_str, (BroadcastBench,), {}) + bm_cls.op_str = op_str + bm_cls.binary_op_pt_func = op_pt_func + bm_cls.binary_op_np_func = op_np_func + bm_cls.split_input = split_input + framework.register_benchmark_class(bm_cls) + + for split_input, unary_op in itertools.product([True, False], unary_op_list): + # Make a copy of BroadcastBench + if 
len(unary_op) == 2: + [op_str, op_pt_func] = unary_op + op_np_func = op_pt_func + elif len(unary_op) == 3: + [op_str, op_pt_func, op_np_func] = unary_op + split_str = 'split' if split_input else 'shared' + op_str = split_str + '_' + op_str + bm_cls = type('BroadcastBench_' + op_str, (BroadcastBench,), {}) + bm_cls.op_str = op_str + bm_cls.unary_op_pt_func = op_pt_func + bm_cls.unary_op_np_func = op_np_func + bm_cls.split_input = split_input + framework.register_benchmark_class(bm_cls) + + +register_broadcast_ops() + diff --git a/benchmarks/tensorexpr/conv.py b/benchmarks/tensorexpr/conv.py new file mode 100644 index 0000000000000..a9a318e76400c --- /dev/null +++ b/benchmarks/tensorexpr/conv.py @@ -0,0 +1,103 @@ +import framework + + +class ConvImplBench(framework.Benchmark): + def __init__(self, case, mode, device, kernel_size, N, iC, H, W, oC): + super().__init__(mode, device) + self.case = case + self.kernel_size = kernel_size + self.N = N + self.iC = iC + self.H = H + self.W = W + self.oC = oC + self.data = self.rand([N, iC, H, W], device=device, requires_grad=self.requires_grad) + if case == 'conv': + self.groups = 1 + elif case == 'depthwise_conv': + self.groups = iC + else: + raise ValueError('invalid case: %s' % (case)) + + self.conv = self.conv2d_layer(iC, oC, kernel_size, groups=self.groups) + if device != 'cpu': + self.to_device(self.conv, device) + + def forward(self): + y = self.conv(self.data) + return y + + def config(self): + return [self.kernel_size, self.N, self.iC, self.H, self.W, self.oC] + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = {'i': 1, 'o': 1, 'k': 1} + algorithmic_count = {'i': 1, 'o': 1, 'k': 1} + else: + sol_count = { + 'i': 1 + 1, + 'o': 1 + 1, + 'k': 1 + 1 + } + algorithmic_count = { + 'i': 1 + (1 + 1), + 'o': 1 + (1 + 1), + 'k': 1 + (1 + 1) + } + + buffer_size = { + 'i': self.N * self.iC * self.H * self.W * 4, + 'o': self.N * self.oC * self.H * self.W * 4, + 'k': self.oC * (self.iC / self.groups) * self.kernel_size * self.kernel_size * 4, + } + sol_size = 0 + algorithmic_size = 0 + for key in sol_count: + sol_size += buffer_size[key] * sol_count[key] + algorithmic_size += buffer_size[key] * algorithmic_count[key] + return { + 'sol': sol_size, + 'algorithmic': algorithmic_size + } + + def compute_workload(self): + if self.mode == 'fwd': + count = 1 + elif self.mode == 'both': + count = 1 + (1 + 1) + else: + raise ValueError('invalid mode: %s' % (self.mode)) + + op_count = self.N * self.iC / self.groups * self.oC * self.kernel_size * self.kernel_size * self.H * self.W + op_count *= 2 + + return op_count * count + + @staticmethod + def default_configs(): + return [ + [3, 64, 32, 128, 128, 64], + ] + + +class ConvBench(ConvImplBench): + def __init__(self, *args): + super().__init__('conv', *args) + + @staticmethod + def module(): + return 'conv' + + +class DepthwiseConvBench(ConvImplBench): + def __init__(self, *args): + super().__init__('depthwise_conv', *args) + + @staticmethod + def module(): + return 'depthwise_conv' + + +framework.register_benchmark_class(ConvBench) +framework.register_benchmark_class(DepthwiseConvBench) diff --git a/benchmarks/tensorexpr/elementwise.py b/benchmarks/tensorexpr/elementwise.py new file mode 100644 index 0000000000000..79db2608e0082 --- /dev/null +++ b/benchmarks/tensorexpr/elementwise.py @@ -0,0 +1,148 @@ +import framework +import itertools +import numpy as np +import torch + +# A template class for elementwise operations. +# A derived class will override the class instance to customize its behavior. 
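As a sanity check on ConvImplBench.compute_workload() above, a small worked example (plain arithmetic, not part of the patch) with the default config and groups == 1:

```python
# Default config of ConvImplBench: kernel_size, N, iC, H, W, oC = 3, 64, 32, 128, 128, 64
kernel_size, N, iC, H, W, oC, groups = 3, 64, 32, 128, 128, 64, 1
op_count = N * iC / groups * oC * kernel_size * kernel_size * H * W
op_count *= 2                 # one multiply and one add per MAC
print(op_count / 1e9)         # ~38.7 GFLOPs for a single forward pass
```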
+class ElementBench(framework.Benchmark): + # List of customization class variables. + op_str = None + binary_op_pt_func = None + binary_op_np_func = None + unary_op_pt_func = None + unary_op_np_func = None + split_input = True + def __init__(self, mode, device, N): + super().__init__(mode, device) + self.N = N + self.d1 = self.rand([N], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([N], device=device, requires_grad=self.requires_grad) + self.d3 = self.rand([N], device=device, requires_grad=self.requires_grad) + self.d4 = self.rand([N], device=device, requires_grad=self.requires_grad) + self.inputs = [self.d1, self.d2, self.d3, self.d4] + self.deterministic = ('rand' not in self.op_str) + + def _eval(self, d1, d2, d3, d4, binary_op, unary_op): + if not binary_op: + binary_op = lambda x, y: x + y + if not unary_op: + unary_op = lambda x: x + if self.split_input: + d1 = unary_op(d1) + d2 = unary_op(d2) + d3 = unary_op(d3) + d4 = unary_op(d4) + else: + d2 = unary_op(d1 + 0.001) + d3 = unary_op(d1 + 0.002) + d4 = unary_op(d1 + 0.003) + d1 = unary_op(d1) + a = binary_op(d1, d2) + b = binary_op(d3, d4) + c = a + b + return c + + def forward(self, d1, d2, d3, d4): + binary_op = self.__class__.binary_op_pt_func + unary_op = self.__class__.unary_op_pt_func + return self._eval(d1, d2, d3, d4, binary_op, unary_op) + + def reference(self): + binary_op = self.__class__.binary_op_np_func + unary_op = self.__class__.unary_op_np_func + [d1, d2, d3, d4] = [self.numpy(d) for d in [self.d1, self.d2, self.d3, self.d4]] + return self._eval(d1, d2, d3, d4, binary_op, unary_op) + + def config(self): + return [self.N] + + @classmethod + def module(cls): + return 'element_' + cls.op_str + + def memory_workload(self): + input_count = len(self.inputs) + if self.mode == 'fwd': + if self.split_input: + sol_count = input_count + 1 + algorithmic_count = input_count + 1 + else: + sol_count = 1 + 1 + algorithmic_count = 1 + 1 + if 'rand' in self.op_str: + sol_count = 1 + algorithmic_count = 1 + else: + if self.split_input: + sol_count = (input_count + 1) + (1 + input_count) + algorithmic_count = (input_count + 1) + ((2 + 1) * input_count) + else: + sol_count = 1 + 1 + algorithmic_count = 1 + 1 + if 'rand' in self.op_str: + sol_count = 1 + algorithmic_count = 1 + + buffer_size = self.N * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + @staticmethod + def default_configs(): + return [[1 << 27]] + + +def register_element_ops(): + binary_op_list = [ + ["mul", lambda a, b: a * b], + ["add", lambda a, b: a + b], + ["sub", lambda a, b: a - b], + ["div", lambda a, b: a / (b + 1e-4)], + ["pow", lambda a, b: torch.pow(a, b), lambda a, b: np.power(a, b)], # no fuson triggered + ["max", lambda a, b: torch.max(a, b), lambda a, b: np.maximum(a, b)], + ["min", lambda a, b: torch.min(a, b), lambda a, b: np.minimum(a, b)], + ] + + unary_op_list = [ + ["exp", lambda x: torch.exp(x), lambda x: np.exp(x)], + ["sin", lambda x: torch.sin(x), lambda x: np.sin(x)], + ["cos", lambda x: torch.cos(x), lambda x: np.cos(x)], + ["rand_like", lambda x: torch.rand_like(x), lambda x: np.random.rand(*x.shape)], + ] + + for split_input, binary_op in itertools.product([True, False], binary_op_list): + # Make a copy of ElementBench + if len(binary_op) == 2: + [op_str, op_pt_func] = binary_op + op_np_func = op_pt_func + elif len(binary_op) == 3: + [op_str, op_pt_func, op_np_func] = binary_op + split_str = 'split' if split_input else 'shared' + op_str = split_str + '_' + op_str + bm_cls = 
type('ElementBench_' + op_str, (ElementBench,), {}) + bm_cls.op_str = op_str + bm_cls.binary_op_pt_func = op_pt_func + bm_cls.binary_op_np_func = op_np_func + bm_cls.split_input = split_input + framework.register_benchmark_class(bm_cls) + + for split_input, unary_op in itertools.product([True, False], unary_op_list): + # Make a copy of ElementBench + if len(unary_op) == 2: + [op_str, op_pt_func] = unary_op + op_np_func = op_pt_func + elif len(unary_op) == 3: + [op_str, op_pt_func, op_np_func] = unary_op + split_str = 'split' if split_input else 'shared' + op_str = split_str + '_' + op_str + bm_cls = type('ElementBench_' + op_str, (ElementBench,), {}) + bm_cls.op_str = op_str + bm_cls.unary_op_pt_func = op_pt_func + bm_cls.unary_op_np_func = op_np_func + bm_cls.split_input = split_input + framework.register_benchmark_class(bm_cls) + + +#framework.register_benchmark_class(ElementMulBench) +register_element_ops() + diff --git a/benchmarks/tensorexpr/framework.py b/benchmarks/tensorexpr/framework.py new file mode 100644 index 0000000000000..9acf671e7db6d --- /dev/null +++ b/benchmarks/tensorexpr/framework.py @@ -0,0 +1,187 @@ +import contextlib +import numpy as np +import os +import time +import tensor_engine +import torch + +class BenchmarkBase(object): + def __init__(self, mode, device): + self.mode = mode + self.device = device + if mode == 'both': + self.requires_grad = True + elif mode == 'fwd': + self.requires_grad = False + else: + raise ValueError('invalid mode: %s' % (mode)) + self.result_grad = None + self.grad_variables = [] + + def forward(self): + '''do one step worth of computation + ''' + raise ValueError('this method should be reimplemented by subclass') + + def check(self): + if not self.deterministic: + return + np.testing.assert_allclose( + self.reference(), self.numpy(self.compute()), atol=1e-2) + + def config(self): + '''returns an array for the current benchmark configs + ''' + raise ValueError('this method should be reimplemented by subclass') + + def desc(self): + '''return the description of the current benchmark + ''' + config = self.config() + config_str = '_'.join([str(x) for x in config]) + device = self.device + if 'NNC_NUM_THREADS' in os.environ: + num_threads_str = os.environ['NNC_NUM_THREADS'] + device += num_threads_str + return '%s: %s_%s_%s_%s' % (self.engine.mode, self.module(), self.mode, device, config_str) + + @staticmethod + def module(): + raise ValueError('this method should be reimplemented by subclass') + + def memory_workload(self): + raise ValueError('this method should be reimplemented by subclass') + + def compute_workload(self): + '''return the number of scalar operations it takes to finish the tensor op''' + return None + + @staticmethod + def default_configs(): + '''return a list of defualt configs for this benchmark''' + raise ValueError('this method should be reimplemented by subclass') + + def is_supported(self): + return True + + +class Benchmark(BenchmarkBase): + def __init__(self, mode, device): + super().__init__(mode, device) + self.engine = tensor_engine.get_engine() + self.engine.reset(device) + + # forward all member functions in self.engine to self + for method in dir(self.engine): + if not callable(getattr(self.engine, method)): + continue + # don't forward if this function is overriden here + if hasattr(self, method): + continue + # don't forward if it is a internal function + if method.startswith('_'): + continue + method_engine = getattr(self.engine, method) + setattr(self, method, method_engine) + + def rand(self, shape, 
device=None, requires_grad=False): + v = self.engine.rand(shape, device=device, requires_grad=requires_grad) + if requires_grad: + self.grad_variables.append(v) + return v + + def nchw_rand(self, shape, device=None, requires_grad=False): + v = self.engine.nchw_rand(shape, device=device, requires_grad=requires_grad) + if requires_grad: + self.grad_variables.append(v) + return v + + def compute(self): + if self.bm_jit: + return self.bm_jit(*self.inputs) + else: + return self.forward(*self.inputs) + + +@contextlib.contextmanager +def cuda_pointwise_context(loop_levels, block_count, block_size): + if loop_levels: + old_loop_levels = torch._C._jit_get_te_cuda_pointwise_loop_levels() + torch._C._jit_set_te_cuda_pointwise_loop_levels(loop_levels) + if block_count: + old_block_count = torch._C._jit_get_te_cuda_pointwise_block_count() + torch._C._jit_set_te_cuda_pointwise_block_count(block_count) + if block_size: + old_block_size = torch._C._jit_get_te_cuda_pointwise_block_size() + torch._C._jit_set_te_cuda_pointwise_block_size(block_size) + + yield + + if loop_levels: + torch._C._jit_set_te_cuda_pointwise_loop_levels(old_loop_levels) + if block_count: + torch._C._jit_set_te_cuda_pointwise_block_count(old_block_count) + if block_size: + torch._C._jit_set_te_cuda_pointwise_block_size(old_block_size) + + +def run_benchmark(benchmark, args): + torch._C._jit_override_can_fuse_on_gpu(args.cuda_fuser == 'old'); + torch._C._jit_set_texpr_fuser_enabled(args.cuda_fuser == 'te'); + with cuda_pointwise_context(args.cuda_pointwise_loop_levels, + args.cuda_pointwise_block_count, + args.cuda_pointwise_block_size): + run_benchmark_impl(benchmark) + + +def run_benchmark_impl(benchmark): + warmups = 10 + if benchmark.device == 'cuda': + iters = 1000 + else: + iters = 10 + engine = tensor_engine.get_engine() + + benchmark.bm_jit = None + for i in range(warmups + iters): + if i == warmups: + if benchmark.device == 'cuda': + engine.sync_cuda() + time_start = time.time() + + if i == 0: + if benchmark.jit_mode == 'trace': + benchmark.bm_jit = torch.jit.trace(benchmark.forward, + example_inputs=benchmark.inputs, check_trace=False) + if callable(getattr(benchmark, 'reference', None)): + benchmark.check() + else: + print(f"Warning: no reference result for {benchmark.module()}") + z = benchmark.compute() + if benchmark.mode == 'both': + if benchmark.result_grad is None: + benchmark.result_grad = engine.rand_like(z) + engine.backward([z], [benchmark.result_grad], benchmark.grad_variables) + + if benchmark.device == 'cuda': + engine.sync_cuda() + + duration = time.time() - time_start + iter_time = duration / iters + memory_workload = benchmark.memory_workload() + compute_workload = benchmark.compute_workload() + + msg = '%s: %.2f us, SOL %.2f GB/s, algorithmic %.2f GB/s' % ( + benchmark.desc(), iter_time * 1e6, + memory_workload['sol'] / iter_time / 1e9, + memory_workload['algorithmic'] / iter_time / 1e9, + ) + if compute_workload is not None: + msg += ', compute %.2f Gops/s' % (compute_workload / iter_time / 1e9) + print(msg) + + +benchmark_classes = [] + +def register_benchmark_class(benchmark_cls): + benchmark_classes.append(benchmark_cls) diff --git a/benchmarks/tensorexpr/matmul.py b/benchmarks/tensorexpr/matmul.py new file mode 100644 index 0000000000000..8469565e56c35 --- /dev/null +++ b/benchmarks/tensorexpr/matmul.py @@ -0,0 +1,57 @@ +import framework +import numpy as np + + +class MatMulBench(framework.Benchmark): + def __init__(self, mode, device, B, M, N, K): + super().__init__(mode, device) + self.B = B + self.M 
= M + self.N = N + self.K = K + self.d1 = self.rand([B, M, N], device=device, requires_grad=self.requires_grad) + self.d2 = self.rand([B, N, K], device=device, requires_grad=self.requires_grad) + + def forward(self): + y = self.matmul(self.d1, self.d2) + return y + + def reference(self): + return np.matmul(self.numpy(self.d1), self.numpy(self.d2)) + + def config(self): + return [self.B, self.M, self.N, self.K] + + @staticmethod + def module(): + return 'batch_matmul' + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + algorithmic_count = 1 + else: + sol_count = 1 + 1 + algorithmic_count = 1 + (1 + 1) + + buffer_size = self.B * self.M * self.N + self.B * self.M * self.N + self.B * self.N * self.K + buffer_size *= 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + def compute_workload(self): + if self.mode == 'fwd': + count = 1 + else: + count = 1 + (1 + 1) + + op_count = 2 * self.B * self.M * self.N * self.K + + return op_count * count + + + @staticmethod + def default_configs(): + return [[128, 64, 128, 256]] + + +framework.register_benchmark_class(MatMulBench) diff --git a/benchmarks/tensorexpr/normalization.py b/benchmarks/tensorexpr/normalization.py new file mode 100644 index 0000000000000..4cef570da983b --- /dev/null +++ b/benchmarks/tensorexpr/normalization.py @@ -0,0 +1,71 @@ +import framework +import tensor_engine + +class NormalizationBench(framework.Benchmark): + def __init__(self, mode, device, N, C, H, W): + super().__init__(mode, device) + self.N = N + self.C = C + self.H = H + self.W = W + + self.data = self.nchw_rand([self.N, self.C, self.H, self.W], device=device, requires_grad=self.requires_grad) + self.running_mean = self.rand([self.C], device=device) + self.running_var = self.rand([self.C], device=device) + self.training = (self.mode == 'both') + + def config(self): + return [self.N, self.C, self.H, self.W] + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + 1 + algorithmic_count = 2 + 1 + else: + sol_count = (1 + 1) + (1 + 1) + algorithmic_count = (2 + 1) + (3 + 1) + + buffer_size = self.N * self.C * self.H * self.W * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + @staticmethod + def default_configs(): + return [[128, 32, 128, 128]] + + +class BatchNormBench(NormalizationBench): + def forward(self): + y = self.batch_norm(self.data, self.running_mean, self.running_var, training=self.training) + return y + + @staticmethod + def module(): + return 'batchnorm' + + +class InstanceNormBench(NormalizationBench): + def forward(self): + y = self.instance_norm(self.data) + return y + + @staticmethod + def module(): + return 'instance_norm' + + def is_supported(self): + return tensor_engine.is_supported(self.instance_norm) + + +class LayerNormBench(NormalizationBench): + def forward(self): + y = self.layer_norm(self.data, [self.H, self.W]) + return y + + @staticmethod + def module(): + return 'layernorm' + + +framework.register_benchmark_class(BatchNormBench) +framework.register_benchmark_class(InstanceNormBench) +framework.register_benchmark_class(LayerNormBench) diff --git a/benchmarks/tensorexpr/pooling.py b/benchmarks/tensorexpr/pooling.py new file mode 100644 index 0000000000000..8d852d5b545d6 --- /dev/null +++ b/benchmarks/tensorexpr/pooling.py @@ -0,0 +1,60 @@ +import framework + + +class PoolingBench(framework.Benchmark): + def __init__(self, case, mode, device, kernel_size, N, C, H, W): + super().__init__(mode, device) + self.case = case 
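To make the reporting in framework.run_benchmark_impl above concrete, here is a minimal sketch (with made-up timing numbers) of how a benchmark's memory_workload() and compute_workload() turn into the printed GB/s and Gops/s figures:

```python
# Hypothetical numbers; the real values come from the benchmark and the timed loop.
iter_time = 1.25e-4                                            # seconds per iteration
memory_workload = {"sol": 16 << 20, "algorithmic": 48 << 20}   # bytes moved
compute_workload = 2 * 128 * 64 * 128 * 256                    # flops, as in MatMulBench fwd

sol_gbps = memory_workload["sol"] / iter_time / 1e9
alg_gbps = memory_workload["algorithmic"] / iter_time / 1e9
gops = compute_workload / iter_time / 1e9
print("SOL %.2f GB/s, algorithmic %.2f GB/s, compute %.2f Gops/s" % (sol_gbps, alg_gbps, gops))
```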
+ self.kernel_size = kernel_size + self.N = N + self.C = C + self.H = H + self.W = W + self.data = self.rand([N, C, H, W], device=device, requires_grad=self.requires_grad) + + def forward(self): + if self.case == 'maxpool': + y = self.max_pool2d(self.data, self.kernel_size, stride=1) + elif self.case == 'avgpool': + y = self.avg_pool2d(self.data, self.kernel_size, stride=1) + return y + + def config(self): + return [self.kernel_size, self.N, self.C, self.H, self.W] + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + 1 + algorithmic_count = 1 + 1 + else: + sol_count = (1 + 1) + (1 + 1) + algorithmic_count = (1 + 1) + (2 + 1) + + buffer_size = self.N * self.C * self.H * self.W * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + @staticmethod + def default_configs(): + return [[3, 16, 32, 256, 256]] + + +class MaxPoolBench(PoolingBench): + def __init__(self, *args): + super().__init__('maxpool', *args) + + @staticmethod + def module(): + return 'maxpool' + + +class AvgPoolBench(PoolingBench): + def __init__(self, *args): + super().__init__('avgpool', *args) + + @staticmethod + def module(): + return 'avgpool' + + +framework.register_benchmark_class(MaxPoolBench) +framework.register_benchmark_class(AvgPoolBench) diff --git a/benchmarks/tensorexpr/pt_engine.py b/benchmarks/tensorexpr/pt_engine.py new file mode 100644 index 0000000000000..e71e62bdb6c25 --- /dev/null +++ b/benchmarks/tensorexpr/pt_engine.py @@ -0,0 +1,60 @@ +import torch + + +class TorchTensorEngine(object): + def rand(self, shape, device=None, requires_grad=False): + return torch.rand(shape, device=device, requires_grad=requires_grad) + + def nchw_rand(self, shape, device=None, requires_grad=False): + return self.rand(shape, device=device, requires_grad=requires_grad) + + def reset(self, _): + pass + + def rand_like(self, v): + return torch.rand_like(v) + + def numpy(self, t): + return t.cpu().numpy() + + def mul(self, t1, t2): + return t1 * t2 + + def add(self, t1, t2): + return t1 + t2 + + def batch_norm(self, data, mean, var, training): + return torch.nn.functional.batch_norm(data, mean, var, training=training) + + def instance_norm(self, data): + return torch.nn.functional.instance_norm(data) + + def layer_norm(self, data, shape): + return torch.nn.functional.layer_norm(data, shape) + + def sync_cuda(self): + torch.cuda.synchronize() + + def backward(self, tensors, grad_tensors, _): + torch.autograd.backward(tensors, grad_tensors=grad_tensors) + + def sum(self, data, dims): + return torch.sum(data, dims) + + def softmax(self, data, dim=None): + return torch.nn.functional.softmax(data, dim) + + def max_pool2d(self, data, kernel_size, stride=1): + return torch.nn.functional.max_pool2d(data, kernel_size, stride=stride) + + def avg_pool2d(self, data, kernel_size, stride=1): + return torch.nn.functional.avg_pool2d(data, kernel_size, stride=stride) + + def conv2d_layer(self, ic, oc, kernel_size, groups=1): + return torch.nn.Conv2d(ic, oc, kernel_size, groups=groups) + + def matmul(self, t1, t2): + return torch.matmul(t1, t2) + + def to_device(self, module, device): + return module.to(device) diff --git a/benchmarks/tensorexpr/reduction.py b/benchmarks/tensorexpr/reduction.py new file mode 100644 index 0000000000000..a9243893b2e6d --- /dev/null +++ b/benchmarks/tensorexpr/reduction.py @@ -0,0 +1,81 @@ +import framework + + +class ReduceBench(framework.Benchmark): + def __init__(self, mode, device, case, M, N, K): + super().__init__(mode, device) + self.case = case 
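A minimal usage sketch of the TorchTensorEngine defined above in pt_engine.py; the import assumes benchmarks/tensorexpr is the working directory, as when running the benchmark scripts from this directory:

```python
from pt_engine import TorchTensorEngine  # assumes benchmarks/tensorexpr is on sys.path

engine = TorchTensorEngine()
x = engine.rand([4, 8], device="cpu", requires_grad=False)
y = engine.softmax(x, dim=1)
print(engine.numpy(y).sum(axis=1))        # each row of the softmax sums to ~1.0
```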
+ self.M = M + self.N = N + self.K = K + + self.data = self.rand([M, N, K], device=device, requires_grad=self.requires_grad) + if case == 'row': + self.dims = [1, 2] + elif case == 'mid': + self.dims = [0, 2] + elif case == 'col': + self.dims = [0, 1] + else: + raise ValueError('invalid case: %s' % case) + + def forward(self): + y = self.sum(self.data, self.dims) + return y + + def config(self): + return [self.M, self.N, self.K] + + @staticmethod + def default_configs(): + return [ + #[512, 512, 512], + [512, 64, 512], + ] + + @staticmethod + def module(): + return 'reduce' + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + algorithmic_count = 1 + else: + sol_count = (1) + (1) + algorithmic_count = 1 + 1 + + buffer_size = self.M * self.N * self.K * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + +class ReduceRowBench(ReduceBench): + def __init__(self, mode, device, M, N, K): + super(ReduceRowBench, self).__init__(mode, device, 'row', M, N, K) + + @staticmethod + def module(): + return 'reduce_row' + + +class ReduceMidBench(ReduceBench): + def __init__(self, mode, device, M, N, K): + super(ReduceMidBench, self).__init__(mode, device, 'mid', M, N, K) + + @staticmethod + def module(): + return 'reduce_mid' + + +class ReduceColBench(ReduceBench): + def __init__(self, mode, device, M, N, K): + super(ReduceColBench, self).__init__(mode, device, 'col', M, N, K) + + @staticmethod + def module(): + return 'reduce_col' + + +framework.register_benchmark_class(ReduceRowBench) +framework.register_benchmark_class(ReduceMidBench) +framework.register_benchmark_class(ReduceColBench) diff --git a/benchmarks/tensorexpr/softmax.py b/benchmarks/tensorexpr/softmax.py new file mode 100644 index 0000000000000..d9915365dc816 --- /dev/null +++ b/benchmarks/tensorexpr/softmax.py @@ -0,0 +1,42 @@ +import framework +import scipy.special + + +class SoftmaxBench(framework.Benchmark): + def __init__(self, mode, device, M, N): + super().__init__(mode, device) + self.M = M + self.N = N + self.data = self.rand([M, N], device=device, requires_grad=self.requires_grad) + + def forward(self): + y = self.softmax(self.data, dim=1) + return y + + def reference(self): + return scipy.special.softmax(self.numpy(self.data), axis=1) + + def config(self): + return [self.M, self.N] + + @staticmethod + def module(): + return 'softmax' + + def memory_workload(self): + if self.mode == 'fwd': + sol_count = 1 + 1 + algorithmic_count = 3 + 1 + else: + sol_count = (1 + 1) + (1 + 1) + algorithmic_count = (3 + 1) + (3 + 1) + + buffer_size = self.M * self.N * 4 + return {'sol': buffer_size * sol_count, 'algorithmic': buffer_size * algorithmic_count} + + @staticmethod + def default_configs(): + return [[128, 1<<16]] + + +framework.register_benchmark_class(SoftmaxBench) diff --git a/benchmarks/tensorexpr/tensor_engine.py b/benchmarks/tensorexpr/tensor_engine.py new file mode 100644 index 0000000000000..d27158cfbabec --- /dev/null +++ b/benchmarks/tensorexpr/tensor_engine.py @@ -0,0 +1,42 @@ +tensor_engine = None + +def unsupported(func): + def wrapper(self): + return func(self) + + wrapper.is_supported = False + return wrapper + + +def is_supported(method): + if hasattr(method, 'is_supported'): + return method.is_supported + return True + + +def set_engine_mode(mode): + global tensor_engine + if mode == 'tf': + import tf_engine + tensor_engine = tf_engine.TensorFlowEngine() + elif mode == 'pt': + import pt_engine + tensor_engine = pt_engine.TorchTensorEngine() + elif mode == 
'topi': + import topi_engine + tensor_engine = topi_engine.TopiEngine() + elif mode == 'relay': + import relay_engine + tensor_engine = relay_engine.RelayEngine() + elif mode == 'nnc': + import nnc_engine + tensor_engine = nnc_engine.NncEngine() + else: + raise ValueError('invalid tensor engine mode: %s' % (mode)) + tensor_engine.mode = mode + + +def get_engine(): + if tensor_engine is None: + raise ValueError('use of get_engine, before calling set_engine_mode is illegal') + return tensor_engine diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4850c0dd8842a..96581a5002fd5 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -411,6 +411,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_autogradzero.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/subgraph_rewrite.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/tensorexpr_fuser.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp @@ -454,8 +455,42 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) ${TORCH_SRC_DIR}/csrc/jit/fuser/fallback.cpp ${TORCH_SRC_DIR}/csrc/jit/function.cpp ${TORCH_SRC_DIR}/csrc/jit/vararg_functions.cpp + + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/mem_arena.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/codegen.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/expr.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/eval.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/function.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_visitor.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/kernel.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_codegen.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_jit.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/native.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/types.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_printer.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/ir_mutator.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/schedule.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/tensor.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/unique_name_manager.cpp ) + if (USE_LLVM) + message(STATUS "Looking for LLVM in ${USE_LLVM}") + find_package(LLVM QUIET PATHS ${USE_LLVM} NO_DEFAULT_PATH) + + if (LLVM_FOUND) + message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") + message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + + include_directories(${LLVM_INCLUDE_DIRS}) + add_definitions(-DENABLE_LLVM ${LLVM_DEFINITIONS}) + endif (LLVM_FOUND) + endif (USE_LLVM) + + set_source_files_properties(${TORCH_SRC_DIR}/csrc/jit/tensorexpr/llvm_jit.cpp PROPERTIES COMPILE_FLAGS -fno-rtti) + + if (NOT INTERN_DISABLE_MOBILE_INTERP) set (MOBILE_SRCS ${TORCH_SRC_DIR}/csrc/jit/mobile/function.cpp @@ -521,10 +556,11 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE) if (USE_CUDA) list(APPEND Caffe2_GPU_SRCS - ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp ${TORCH_SRC_DIR}/csrc/autograd/profiler_cuda.cpp ${TORCH_SRC_DIR}/csrc/autograd/functions/comm.cpp ${TORCH_SRC_DIR}/csrc/cuda/comm.cpp + ${TORCH_SRC_DIR}/csrc/jit/fuser/cuda/fused_kernel.cpp + ${TORCH_SRC_DIR}/csrc/jit/tensorexpr/cuda_codegen.cpp ) add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) target_link_libraries(caffe2_nvrtc ${CUDA_NVRTC} ${CUDA_CUDA_LIB} ${CUDA_NVRTC_LIB}) @@ -622,6 +658,13 @@ endif() add_library(torch_cpu ${Caffe2_CPU_SRCS}) torch_compile_options(torch_cpu) # see 
cmake/public/utils.cmake +if (LLVM_FOUND) + llvm_map_components_to_libnames(LLVM_LINK_LIBS + support core analysis executionengine instcombine + scalaropts transformutils native orcjit) + target_link_libraries(torch_cpu PRIVATE ${LLVM_LINK_LIBS}) +endif (LLVM_FOUND) + # This is required for older versions of CMake, which don't allow # specifying add_library() without a list of source files set(DUMMY_EMPTY_FILE ${CMAKE_BINARY_DIR}/empty.cpp) @@ -755,6 +798,7 @@ ENDIF() if (BUILD_TEST AND NOT MSVC AND NOT USE_ROCM) add_subdirectory(${TORCH_ROOT}/test/cpp/jit ${CMAKE_BINARY_DIR}/test_jit) + add_subdirectory(${TORCH_ROOT}/test/cpp/tensorexpr ${CMAKE_BINARY_DIR}/test_tensorexpr) if (USE_DISTRIBUTED) add_subdirectory(${TORCH_ROOT}/test/cpp/rpc ${CMAKE_BINARY_DIR}/test_cpp_rpc) endif() diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt new file mode 100644 index 0000000000000..2413631bfb5f2 --- /dev/null +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -0,0 +1,40 @@ +set(TENSOREXPR_TEST_ROOT ${TORCH_ROOT}/test/cpp/tensorexpr) + +file(GLOB TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_ROOT}/test_*.cpp) +set(TENSOREXPR_TEST_SRCS ${TENSOREXPR_TEST_SRCS} PARENT_SCOPE) + +add_executable(test_tensorexpr + ${TORCH_ROOT}/test/cpp/common/main.cpp + ${TENSOREXPR_TEST_ROOT}/gtest.cpp + ${TENSOREXPR_TEST_ROOT}/padded_buffer.cpp + ${TENSOREXPR_TEST_SRCS}) + +target_link_libraries(test_tensorexpr PRIVATE torch gtest) +target_include_directories(test_tensorexpr PRIVATE ${ATen_CPU_INCLUDE}) + +if (USE_CUDA) + target_link_libraries(test_tensorexpr PRIVATE + ${CUDA_LIBRARIES} + ${CUDA_NVRTC_LIB} + ${CUDA_CUDA_LIB} + ${TORCH_CUDA_LIBRARIES}) + + target_compile_definitions(test_tensorexpr PRIVATE USE_CUDA) +elseif (USE_ROCM) + target_link_libraries(test_tensorexpr PRIVATE + ${ROCM_HIPRTC_LIB} + ${PYTORCH_HIP_HCC_LIBRARIES} + ${TORCH_CUDA_LIBRARIES}) + + target_link_libraries(test_tensorexpr PRIVATE caffe2_gpu) + + target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) +endif() + +if (INSTALL_TEST) + install(TARGETS test_tensorexpr DESTINATION bin) + # Install PDB files for MSVC builds + if (MSVC AND BUILD_SHARED_LIBS) + install(FILES $ DESTINATION bin OPTIONAL) + endif() +endif() diff --git a/test/cpp/tensorexpr/README.md b/test/cpp/tensorexpr/README.md new file mode 100644 index 0000000000000..055d2201b009d --- /dev/null +++ b/test/cpp/tensorexpr/README.md @@ -0,0 +1,55 @@ +# TensorExpr C++ Tests + +## How to add a new test +First, create a new test file. Test files should have be placed in this +directory, with a name that starts with `test_`, like `test_foo.cpp`. + +Here is an example test file you can copy-paste. +```cpp +#include + +// Tests go in torch::jit +namespace torch { +namespace jit { + +// 1. Test cases are void() functions. +// 2. They start with the prefix `test` +void testCaseOne() { + // ... +} + +void testCaseTwo() { + // ... +} +} +} +``` + +Then, register your test in `tests.h`: +```cpp +// Add to TH_FORALL_TESTS_CUDA instead for CUDA-requiring tests +#define TH_FORALL_TESTS(_) \ + _(ADFormulas) \ + _(Attributes) \ + ... + _(CaseOne) // note that the `test` prefix is omitted. + _(CaseTwo) +``` + +We glob all the test files together in `CMakeLists.txt` so that you don't +have to edit it every time you add a test. Unfortunately, this means that in +order to get the build to pick up your new test file, you need to re-run +cmake: +``` +python setup.py build --cmake +``` + +## How do I run the tests? +The following commands assume you are in PyTorch root. 
+ + ```bash + # (re)build the test binary + ninja build/bin/test_tensorexpr + # run + build/bin/test_tensorexpr --gtest_filter='glob_style_filter*' + ``` diff --git a/test/cpp/tensorexpr/__init__.py b/test/cpp/tensorexpr/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/test/cpp/tensorexpr/gtest.cpp b/test/cpp/tensorexpr/gtest.cpp new file mode 100644 index 0000000000000..507c43337b021 --- /dev/null +++ b/test/cpp/tensorexpr/gtest.cpp @@ -0,0 +1,34 @@ +#include + +#include + +namespace torch { +namespace jit { + +#define TENSOREXPR_GTEST(name) \ + TEST(TensorExprTest, name) { \ + test##name(); \ + } +TH_FORALL_TESTS(TENSOREXPR_GTEST) +#undef TENSOREXPR_GTEST + +#ifdef ENABLE_LLVM +#define TENSOREXPR_GTEST_LLVM(name) \ + TEST(TensorExprTest, name##_LLVM) { \ + test##name(); \ + } +TH_FORALL_TESTS_LLVM(TENSOREXPR_GTEST_LLVM) +#undef TENSOREXPR_GTEST_LLVM +#endif + +#ifdef USE_CUDA +#define TENSOREXPR_GTEST_CUDA(name) \ + TEST(TensorExprTest, name##_CUDA) { \ + test##name(); \ + } +TH_FORALL_TESTS_CUDA(TENSOREXPR_GTEST_CUDA) +#undef TENSOREXPR_GTEST_CUDA +#endif + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.cpp b/test/cpp/tensorexpr/padded_buffer.cpp new file mode 100644 index 0000000000000..c903aa68223af --- /dev/null +++ b/test/cpp/tensorexpr/padded_buffer.cpp @@ -0,0 +1,38 @@ +#include "test/cpp/tensorexpr/padded_buffer.h" + +#include + +#include + +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +int PaddedBufferBase::Index(const std::vector& indices) const { + DCHECK_EQ(dims_.size(), indices.size()); + int total_index = 0; + for (int i = 0; i < dims_.size(); i++) { + total_index += indices[i] * strides_[i]; + } + return total_index; +} + +PaddedBufferBase::PaddedBufferBase( + const std::vector& dims, + const std::string& name) + : dims_(dims), name_(name), strides_(dims.size()) { + for (int i = dims.size() - 1; i >= 0; --i) { + if (i == dims.size() - 1) { + strides_[i] = 1; + } else { + strides_[i] = strides_[i + 1] * dims[i + 1]; + } + } + total_size_ = strides_[0] * dims[0]; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/padded_buffer.h b/test/cpp/tensorexpr/padded_buffer.h new file mode 100644 index 0000000000000..a602ed2f4dc51 --- /dev/null +++ b/test/cpp/tensorexpr/padded_buffer.h @@ -0,0 +1,248 @@ +#pragma once + +#include +#include + +#include "torch/csrc/jit/tensorexpr/eval.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +struct DefaultPaddedValue; + +template <> +struct DefaultPaddedValue { + static const int kValue = static_cast(0xDEADBEEF); +}; + +template <> +struct DefaultPaddedValue { + static const int8_t kValue = static_cast(0xBE); +}; + +template <> +struct DefaultPaddedValue { + static const uint8_t kValue = static_cast(0xBE); +}; + +template <> +struct DefaultPaddedValue { + static const int16_t kValue = static_cast(0xBEEF); +}; + +template <> +struct DefaultPaddedValue { + static const int64_t kValue = static_cast(0xDEADBEEF); +}; + +template <> +struct DefaultPaddedValue { + static constexpr float kValue = 0.1357; +}; + +template <> +struct DefaultPaddedValue { + // at::Half ctor isn't constexpr, so just fill it with bits. + static constexpr uint16_t kValue = 1357; +}; + +template <> +struct DefaultPaddedValue { + static constexpr double kValue = 0.1357; +}; + +// A concrete base to be used in PaddedBase. 
+class PaddedBufferBase { + public: + const std::string& name() const { + return name_; + } + + int size() const { + return total_size_; + } + + int raw_size() const { + return total_size_ + 2 * kPaddingSize; + } + + virtual ~PaddedBufferBase() {} + + protected: + explicit PaddedBufferBase( + const std::vector& dims, + const std::string& name); + int Index(const std::vector& indices) const; + + std::vector dims_; + std::string name_; + std::vector strides_; + int total_size_; // total number of useful element, does not include the + // paddings + static constexpr int kPaddingSize = 64; +}; + +// A padded buffer with wartermarks for testing. +// The buffer carries padded watermarks on both sides to catch potential +// out-of-bounds writes. For read-only data that are not supposed to change, it +// can also make a backup and be compared later. +template +class PaddedBuffer : public PaddedBufferBase { + public: + PaddedBuffer(int d0, const std::string& name = "") + : PaddedBuffer(std::vector({d0}), name) {} + PaddedBuffer(int d0, int d1, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1}), name) {} + PaddedBuffer(int d0, int d1, int d2, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1, d2}), name) {} + PaddedBuffer(int d0, int d1, int d2, int d3, const std::string& name = "") + : PaddedBuffer(std::vector({d0, d1, d2, d3}), name) {} + PaddedBuffer(const std::vector& dims, const std::string& name = "") + : PaddedBufferBase(dims, name) { + data_.resize(total_size_ + 2 * kPaddingSize, kPaddingValue); + } + PaddedBuffer(const PaddedBuffer& other, const std::string& name) + : PaddedBuffer(other) { + this->name_ = name; + } + + T* data() { + return data_.data() + kPaddingSize; + } + const T* data() const { + return const_cast(this)->data(); + } + T* raw_data() { + return data_.data(); + } + const T* raw_data() const { + return const_cast(this)->raw_data(); + } + T& operator()(int i0) { + // There is a bit performance impact with forming a vector here. But this + // data structure is for testing only, and not performance critical. + return this->operator()(std::vector({i0})); + } + const T& operator()(int i0) const { + return const_cast(this)->operator()(i0); + } + T& operator()(int i0, int i1) { + return this->operator()(std::vector({i0, i1})); + } + const T& operator()(int i0, int i1) const { + return const_cast(this)->operator()(i0, i1); + } + T& operator()(int i0, int i1, int i2) { + return this->operator()(std::vector({i0, i1, i2})); + } + const T& operator()(int i0, int i1, int i2) const { + return const_cast(this)->operator()(i0, i1, i2); + } + T& operator()(int i0, int i1, int i2, int i3) { + return this->operator()(std::vector({i0, i1, i2, i3})); + } + const T& operator()(int i0, int i1, int i2, int i3) const { + return const_cast(this)->operator()(i0, i1, i2, i3); + } + T& operator()(const std::vector& indices) { + return data_[kPaddingSize + Index(indices)]; + } + const T& operator()(const std::vector& indices) const { + return const_cast(this)->operator()(indices); + } + + template + friend void ExpectAllNear( + const PaddedBuffer& v1, + const PaddedBuffer& v2, + float abs_error); + template + friend void ExpectAllEqual( + const PaddedBuffer& v1, + const PaddedBuffer& v2); + void Backup() { + backup_data_ = data_; + } + + // Verify the watermarks in the paddings are intact. 
+ void ValidateWatermark() const { + for (int i = 0; i < kPaddingSize; i++) { + EXPECT_EQ(data_[i], kPaddingValue) + << "left-side watermark broken: " + << "index: " << i << ", name: " << name(); + EXPECT_EQ(data_[i + total_size_ + kPaddingSize], kPaddingValue) + << "right-side watermark broken: " + << "index: " << i << ", name: " << name(); + } + } + + void CheckBackup() const { + ValidateWatermark(); + DCHECK(backup_data_.size() == data_.size()) + << "Please make sure you have call Backup() before calling CheckBackup()"; + for (int i = 0; i < total_size_; i++) { + EXPECT_EQ(data_[i + kPaddingSize], backup_data_[i + kPaddingSize]) + << "mismatch against backup, " + << "index: " << i << ", name: " << name(); + } + } + + private: + std::vector data_; + std::vector backup_data_; + T kPaddingValue = DefaultPaddedValue::kValue; +}; + +template +inline CodeGen::CallArg::CallArg(const PaddedBuffer& buffer) + : ptr_(const_cast(buffer.data())) {} + +template +std::string CompareErrorMsg( + const PaddedBuffer& v1, + const PaddedBuffer& v2, + int index) { + std::ostringstream oss; + oss << "index: " << index << ", names: " << v1.name() << ", " << v2.name(); + return oss.str(); +} + +template +void ExpectAllEqual(const PaddedBuffer& f1, const PaddedBuffer& f2) { + const std::vector& v1 = f1.data_; + const std::vector& v2 = f2.data_; + const int kPaddingSize = f1.kPaddingSize; + const int total_size = f1.total_size_; + ASSERT_EQ(v1.size(), v2.size()); + f1.ValidateWatermark(); + f2.ValidateWatermark(); + for (int i = 0; i < total_size; i++) { + EXPECT_EQ(v1[kPaddingSize + i], v2[kPaddingSize + i]) + << CompareErrorMsg(f1, f2, i); + } +} + +template +void ExpectAllNear( + const PaddedBuffer& f1, + const PaddedBuffer& f2, + float abs_error) { + const std::vector& v1 = f1.data_; + const std::vector& v2 = f2.data_; + const int kPaddingSize = f1.kPaddingSize; + const int total_size = f1.total_size_; + ASSERT_EQ(v1.size(), v2.size()); + f1.ValidateWatermark(); + f2.ValidateWatermark(); + for (int i = 0; i < total_size; i++) { + ASSERT_NEAR(v1[kPaddingSize + i], v2[kPaddingSize + i], abs_error); + // << CompareErrorMsg(f1, f2, i); + } +} + + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_aten.cpp b/test/cpp/tensorexpr/test_aten.cpp new file mode 100644 index 0000000000000..7de638e37000d --- /dev/null +++ b/test/cpp/tensorexpr/test_aten.cpp @@ -0,0 +1,1095 @@ +#include "test/cpp/tensorexpr/test_base.h" +#include +#include +#include + +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "test/cpp/tensorexpr/padded_buffer.h" + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +void testATen_cast_Float() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle to_float = Cast::make(kFloat, load_a); + Stmt* store_b = Store::make(b_buf, index, to_float, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), static_cast(i)) << "index: " << i; + } +} + 
+void testATennegInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle to_float = Sub::make(0, load_a); + Stmt* store_b = Store::make(b_buf, index, to_float, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), -static_cast(i)) << "index: " << i; + } +} + +void testATennegFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle to_float = Sub::make(0, load_a); + Stmt* store_b = Store::make(b_buf, index, to_float, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), -i) << "index: " << i; + } +} + +void testATenaddInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + Stmt* store_d = Store::make(d_buf, index, load_a + load_b * load_c, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + ir_eval(a_v, b_v, c_v, d_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)) << "index: " << i; + } +} + +void testATenaddFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + 
Stmt* store_d = Store::make(d_buf, index, load_a + load_b * load_c, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + ir_eval(a_v, b_v, c_v, d_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), a_v(i) + b_v(i) * c_v(i)) << "index: " << i; + } +} + +void testATensubInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + Stmt* store_d = Store::make(d_buf, index, load_a - load_b * load_c, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + ir_eval(a_v, b_v, c_v, d_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)) << "index: " << i; + } +} + +void testATensubFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + Stmt* store_d = Store::make(d_buf, index, load_a - load_b * load_c, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + ir_eval(a_v, b_v, c_v, d_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), a_v(i) - b_v(i) * c_v(i)) << "index: " << i; + } +} + +void testATenlerp() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); 
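+ // Inputs for lerp: a is the start, b the end, c the weight; d receives a + c * (b - a).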
+ Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + Stmt* store_d = + Store::make(d_buf, index, load_a + load_c * (load_b - load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_d); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf); + ir_eval(a_v, b_v, c_v, d_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), a_v(i) + c_v(i) * (b_v(i) - a_v(i))) << "index: " << i; + } +} + +void testATenaddcmulInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer e_buf(VarHandle("E", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + ExprHandle load_d = Load::make(d_buf, index, 1); + Stmt* store_e = + Store::make(e_buf, index, load_a + load_b * load_c * load_d, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_e); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + PaddedBuffer e_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + d_v(i) = 5 * i + 3; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf, e_buf); + ir_eval(a_v, b_v, c_v, d_v, e_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), 5 * i + 3) << "index: " << i; + EXPECT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)) << "index: " << i; + } +} + +void testATenaddcmulFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer d_buf(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer e_buf(VarHandle("E", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + ExprHandle load_c = Load::make(c_buf, index, 1); + ExprHandle load_d = Load::make(d_buf, index, 1); + Stmt* store_e = + Store::make(e_buf, index, load_a + load_b * load_c * load_d, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_e); + + PaddedBuffer a_v(kTotalSize); + 
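+ // Host-side mirrors of the four inputs and of the addcmul output e = a + b * c * d.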
PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer d_v(kTotalSize); + PaddedBuffer e_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + c_v(i) = 3 * i + 2; + d_v(i) = 5 * i + 3; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf, d_buf, e_buf); + ir_eval(a_v, b_v, c_v, d_v, e_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), 3 * i + 2) << "index: " << i; + EXPECT_EQ(d_v(i), 5 * i + 3) << "index: " << i; + EXPECT_EQ(e_v(i), a_v(i) + b_v(i) * c_v(i) * d_v(i)) << "index: " << i; + } +} + +void testATenmulInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, load_a * load_b, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), a_v(i) * b_v(i)) << "index: " << i; + } +} + +void testATenmulFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, load_a * load_b, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), a_v(i) * b_v(i)) << "index: " << i; + } +} + +void testATendivInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, load_a / load_b, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = 2 * i + 1; + b_v(i) = i + 1; + } + + SimpleIREvaluator 
ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(b_v(i), i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), a_v(i) / b_v(i)) << "index: " << i; + } +} + +void testATendivFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, load_a / load_b, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = 2 * i + 1; + b_v(i) = i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(b_v(i), i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), a_v(i) / b_v(i)) << "index: " << i; + } +} + +void testATenmaxInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, Max::make(load_a, load_b, true), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), std::max(a_v(i), b_v(i))) << "index: " << i; + } +} + +void testATenmaxFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, Max::make(load_a, load_b, true), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), std::fmax(a_v(i), b_v(i))) << "index: " << i; + } +} + +void testATenminInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer 
a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, Min::make(load_a, load_b, true), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), std::min(a_v(i), b_v(i))) << "index: " << i; + } +} + +void testATenminFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make(c_buf, index, Min::make(load_a, load_b, true), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), std::fmin(a_v(i), b_v(i))) << "index: " << i; + } +} + +void testATen_sigmoid_backward() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make( + c_buf, index, load_a * load_b * (FloatImm::make(1.0f) - load_b), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), a_v(i) * b_v(i) * (1.0f - b_v(i))) << "index: " << i; + } +} + +void testATen_tanh_backward() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = 
Load::make(a_buf, index, 1); + ExprHandle load_b = Load::make(b_buf, index, 1); + Stmt* store_c = Store::make( + c_buf, index, load_a * (FloatImm::make(1.0f) - (load_b * load_b)), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_c); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + b_v(i) = 2 * i + 1; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 2 * i + 1) << "index: " << i; + EXPECT_EQ(c_v(i), a_v(i) * (1.0f - (b_v(i) * b_v(i)))) << "index: " << i; + } +} + +void testATenreciprocal() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, FloatImm::make(1.0f) / load_a, 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i) << "index: " << i; + EXPECT_EQ(b_v(i), 1.0f / i) << "index: " << i; + } +} + +void testATenreluInt() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kInt, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kInt, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, Max::make(load_a, 0, false), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i - 64; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i - 64) << "index: " << i; + EXPECT_EQ(b_v(i), std::max(a_v(i), 0)) << "index: " << i; + } +} + +void testATenreluFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make( + b_buf, + index, + Max::make(load_a, 0, false), // relu does not propagate nans + 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i - 64; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i - 64) << "index: " << i; + EXPECT_EQ(b_v(i), std::fmax(a_v(i), 0)) << "index: " << i; + } +} + +void testATenlogFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = 
Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, log(load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i + 10) << "index: " << i; + EXPECT_EQ(b_v(i), std::log(a_v(i))) << "index: " << i; + } +} + +void testATenlog10Float() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, log10(load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i + 10) << "index: " << i; + EXPECT_EQ(b_v(i), std::log10(a_v(i))) << "index: " << i; + } +} + +void testATenlog2Float() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, log2(load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i + 10; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i + 10) << "index: " << i; + EXPECT_EQ(b_v(i), std::log2(a_v(i))) << "index: " << i; + } +} + +void testATenexpFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, exp(load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i; + EXPECT_EQ(b_v(i), std::exp(a_v(i))) << "index: " << i; + } +} + +void testATenerfFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, erf(load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i / 10.0f; + } + + 
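+ // Interpret the loop nest and compare each element against std::erf below.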
SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i; + EXPECT_EQ(b_v(i), std::erf(a_v(i))) << "index: " << i; + } +} + +void testATencosFloat() { + KernelScope kernel_scope; + const int kTotalSize = 128; + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make(a_buf, index, 1); + Stmt* store_b = Store::make(b_buf, index, cos(load_a), 1); + Stmt* stmt = For::make(index, 0, kTotalSize, store_b); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + + for (int i = 0; i < kTotalSize; ++i) { + a_v(i) = i / 10.0f; + } + + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf); + ir_eval(a_v, b_v); + + for (int i = 0; i < kTotalSize; ++i) { + EXPECT_EQ(a_v(i), i / 10.0f) << "index: " << i; + EXPECT_EQ(b_v(i), std::cos(a_v(i))) << "index: " << i; + } +} + +void testATeneqInt() { + KernelScope kernel_scope; + constexpr int N = 128; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kEQ), + mask)); + + SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +void testATengeInt() { + KernelScope kernel_scope; + constexpr int N = 128; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kGE), + mask)); + + SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +void testATengtInt() { + KernelScope kernel_scope; + constexpr int N = 128; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 6); + std::vector b_buffer(N, 3); + std::vector c_buffer(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kGT), + mask)); + + SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +void testATenleInt() { + KernelScope kernel_scope; + constexpr int N = 128; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + Store::make( + c, 
+ i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kLE), + mask)); + + SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 1); +} + +void testATenltInt() { + KernelScope kernel_scope; + constexpr int N = 128; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 5); + std::vector b_buffer(N, 5); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kLT), + mask)); + + SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + ir_eval(a_buffer, b_buffer, c_buffer); + + assertAllEqual(c_buffer, 0); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_base.h b/test/cpp/tensorexpr/test_base.h new file mode 100644 index 0000000000000..69e60ec2e81fc --- /dev/null +++ b/test/cpp/tensorexpr/test_base.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +void ExpectAllNear( + const std::vector& v1, + const std::vector& v2, + V threshold, + const std::string& name = "") { + ASSERT_EQ(v1.size(), v2.size()); + for (int i = 0; i < v1.size(); i++) { + EXPECT_NEAR(v1[i], v2[i], threshold) + << "element index: " << i << ", name: " << name; + } +} + +template +static void assertAllEqual(const std::vector& vec, const T& val) { + for (auto const& elt : vec) { + ASSERT_EQ(elt, val); + } +} +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_cuda.cpp b/test/cpp/tensorexpr/test_cuda.cpp new file mode 100644 index 0000000000000..c612150f0ccb5 --- /dev/null +++ b/test/cpp/tensorexpr/test_cuda.cpp @@ -0,0 +1,330 @@ +#ifdef USE_CUDA + +#include +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include + +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" +#include "torch/csrc/jit/tensorexpr/schedule.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using namespace torch::jit::tensorexpr::schedule; + +template +void testCudaTestVectorAdd01_impl() { + KernelScope kernel_scope; + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Dtype dtype = ToDtype(); + Buffer a_buf("a", dtype, {num_iter, block_count, block_size}); + Buffer b_buf("b", dtype, {num_iter, block_count, block_size}); + Tensor* c = Compute( + "c", + { + {num_iter, "n"}, + {block_count, "b_id"}, + {block_size, "t_id"}, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return a_buf(n, b_id, t_id) + b_buf(n, b_id, t_id); + }); + Schedule sch({c}); + VarHandle b_id(c->arg(1)); + VarHandle t_id(c->arg(2)); + c->GPUExecConfig({b_id}, {t_id}); + Stmt* stmt = sch.Lower(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + const int N = block_count * block_size * num_iter; + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (int i = 0; i < N; i++) { + a_v(i) = ctype(i); + b_v(i) = ctype(i * 3 + 7); + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support 
into PaddedBuffer + ctype* a_dev = nullptr; + cudaMalloc(&a_dev, N * sizeof(ctype)); + ctype* b_dev = nullptr; + cudaMalloc(&b_dev, N * sizeof(ctype)); + ctype* c_dev = nullptr; + cudaMalloc(&c_dev, N * sizeof(ctype)); + cudaMemcpy(a_dev, a_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice); + cudaMemcpy(b_dev, b_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice); + cudaMemcpy(c_dev, c_v.data(), N * sizeof(ctype), cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cuda_cg(c_dev, a_dev, b_dev); + + cudaDeviceSynchronize(); + cudaMemcpy(c_v.data(), c_dev, N * sizeof(ctype), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(c_v, c_ref, 1e-5); + + cudaFree(a_dev); + cudaFree(b_dev); + cudaFree(c_dev); +} + +void testCudaTestVectorAdd01() { + // floating types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + + // integer types. + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); + testCudaTestVectorAdd01_impl(); +} + +static void testCudaTestVectorAdd02_impl(int N, int block_size) { + KernelScope kernel_scope; + Buffer a_buf("a", kFloat, {N}); + Buffer b_buf("b", kFloat, {N}); + Tensor* c = Compute( + "c", + { + {N, "N"}, + }, + [&](const VarHandle& n) { return a_buf(n) + b_buf(n); }); + Schedule sch({c}); + VarHandle n(c->function()->arg(0)); + VarHandle n_outer; + VarHandle n_inner; + c->SplitWithMask(n, block_size, true, &n_outer, &n_inner); + c->GPUExecConfig({n_outer}, {n_inner}); + Stmt* stmt = sch.Lower(); + CudaCodeGen cuda_cg(stmt, c, a_buf, b_buf); + PaddedBuffer a_v(N); + PaddedBuffer b_v(N); + PaddedBuffer c_v(N); + PaddedBuffer c_ref(N); + + for (int i = 0; i < N; i++) { + a_v(i) = i; + b_v(i) = i * 3 + 7; + c_ref(i) = a_v(i) + b_v(i); + } + + // TODO: move gpu support into PaddedBuffer + float* a_dev = nullptr; + cudaMalloc(&a_dev, N * sizeof(float)); + float* b_dev = nullptr; + cudaMalloc(&b_dev, N * sizeof(float)); + float* c_dev = nullptr; + cudaMalloc(&c_dev, N * sizeof(float)); + cudaMemcpy(a_dev, a_v.data(), N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(b_dev, b_v.data(), N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(c_dev, c_v.data(), N * sizeof(float), cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cuda_cg(c_dev, a_dev, b_dev); + + cudaDeviceSynchronize(); + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(c_v, c_ref, 1e-5); + + cudaFree(a_dev); + cudaFree(b_dev); + cudaFree(c_dev); +} + +void testCudaTestVectorAdd02() { + testCudaTestVectorAdd02_impl(1024, 128); + testCudaTestVectorAdd02_impl(1030, 128); +} + +void testCudaDynamicShape2D() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {m, n}); + Buffer b(VarHandle("b", kHandle), kFloat, {m, n}); + Tensor* c = + Compute("c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { + return a(i, j) + b(i, j); + }); + auto sch = Schedule::make({c}); + Stmt* s = sch.Lower(); + CudaCodeGen cg(s, {a, b, c, m, n}); + + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + float* aDev = nullptr; + float* bDev = nullptr; + float* cDev = nullptr; + cudaMalloc(&aDev, aData.size() * sizeof(aData[0])); + cudaMalloc(&bDev, bData.size() * sizeof(bData[0])); + cudaMalloc(&cDev, 
cData.size() * sizeof(cData[0])); + cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice); + cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(bData[0]), + cudaMemcpyHostToDevice); + cudaMemcpy( + cDev, + cData.data(), + cData.size() * sizeof(cData[0]), + cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cg.call({aDev, bDev, cDev, M, N}); + cudaDeviceSynchronize(); + + cudaMemcpy( + cData.data(), + cDev, + cData.size() * sizeof(cData[0]), + cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + + cudaFree(aDev); + cudaFree(bDev); + cudaFree(cDev); + }; + testWithSize(32, 32); + testWithSize(1, 16); + testWithSize(27, 13); +} + +void testCudaTestRand01() { + KernelScope kernel_scope; + const int num_iter = 3; + const int block_count = 16; + const int block_size = 128; + Tensor* c = Compute( + "c", + { + {num_iter, "n"}, + {block_count, "b_id"}, + {block_size, "t_id"}, + }, + [&](const VarHandle& n, const VarHandle& b_id, const VarHandle& t_id) { + return Intrinsics::make(IntrinsicsOp::kRand, kFloat); + }); + Schedule sch({c}); + VarHandle b_id(c->function()->arg(1)); + VarHandle t_id(c->function()->arg(2)); + c->GPUExecConfig({b_id}, {t_id}); + Stmt* stmt = sch.Lower(); + CudaCodeGen cuda_cg(stmt, c); + const int N = block_count * block_size * num_iter; + PaddedBuffer c_v(N); + + // TODO: move gpu support into PaddedBuffer + float* c_dev = nullptr; + cudaMalloc(&c_dev, N * sizeof(float)); + cudaDeviceSynchronize(); + + cuda_cg(c_dev); + + cudaDeviceSynchronize(); + cudaMemcpy(c_v.data(), c_dev, N * sizeof(float), cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + float sum1 = 0; + float sum2 = 0; + float sum3 = 0; + for (int i = 0; i < N; i++) { + float v = c_v.data()[i]; + sum1 += v; + sum2 += v * v; + sum3 += v * v * v; + EXPECT_TRUE(v >= 0 && v < 1) << "invalid value: " << i << ", " << v; + } + sum1 /= N; + sum2 /= N; + sum3 /= N; + float sum1_mean = 1.f / 2; + float sum2_mean = 1.f / 3; + float sum3_mean = 1.f / 4; + + EXPECT_NEAR(sum1, sum1_mean, 2e-2); + EXPECT_NEAR(sum2, sum2_mean, 2e-2); + EXPECT_NEAR(sum3, sum3_mean, 2e-2); + cudaFree(c_dev); +} + +void testCudaDynamicShapeSplit() { + KernelScope ks; + constexpr int N = 4096; + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {n}); + Tensor* b = + Compute("b", {{n, "n"}}, [&](const VarHandle& i) { return a(i) * 2.0f; }); + auto sch = Schedule::make({b}); + VarHandle outer; + VarHandle inner; + b->SplitWithMask(VarHandle(b->function()->arg(0)), 1024, true, &outer, &inner); + b->GPUExecConfig({outer}, {inner}); + Stmt* s = sch.Lower(); + CudaCodeGen cg(s, {a, b, n}); + + std::vector aData(N, 1.0f); + std::vector bData(N, 1.0f); + float* aDev = nullptr; + float* bDev = nullptr; + cudaMalloc(&aDev, aData.size() * sizeof(aData[0])); + cudaMalloc(&bDev, bData.size() * sizeof(bData[0])); + cudaMemcpy( + aDev, + aData.data(), + aData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice); + cudaMemcpy( + bDev, + bData.data(), + bData.size() * sizeof(aData[0]), + cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + + cg.call({aDev, bDev, N}); + cudaDeviceSynchronize(); + + cudaMemcpy( + bData.data(), + bDev, + bData.size() * sizeof(aData[0]), + cudaMemcpyDeviceToHost); + cudaDeviceSynchronize(); + + ExpectAllNear(bData, std::vector(N, 2.0f), 1e-7); + + cudaFree(aDev); + cudaFree(bDev); +} + +} // namespace jit +} // namespace torch + +#endif diff --git a/test/cpp/tensorexpr/test_expr.cpp 
b/test/cpp/tensorexpr/test_expr.cpp new file mode 100644 index 0000000000000..bc17e14a8de00 --- /dev/null +++ b/test/cpp/tensorexpr/test_expr.cpp @@ -0,0 +1,456 @@ +#include "test/cpp/tensorexpr/test_base.h" + +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "test/cpp/tensorexpr/test_utils.h" +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/function.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/schedule.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +#include +#include +#include +#include +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +using SimpleIRExprEval = ExprEval; + +void testExprBasicValueTest() { + KernelScope kernel_scope; + ExprHandle a = IntImm::make(2), b = IntImm::make(3); + ExprHandle c = Add::make(a, b); + SimpleIRExprEval eval(c); + EXPECT_EQ(eval.value(), 5); +} + +void testExprBasicValueTest02() { + KernelScope kernel_scope; + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + SimpleIRExprEval eval(f); + EXPECT_EQ(eval.value(), -4.0f); +} + +void testExprLetTest01() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + ExprHandle result = Let::make(x, ExprHandle(3.f), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprLetTest02() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); + ExprHandle e1 = Let::make(x, ExprHandle(3.f), body); + ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1); + SimpleIRExprEval eval(e2); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4 * 6)); +} + +void testExprLetStmtTest01() { + KernelScope kernel_scope; + Buffer a_buf("a", kFloat, {1}); + Buffer b_buf("b", kFloat, {1}); + + ExprHandle load_a = Load::make(a_buf, 0, 1); + VarHandle var = VarHandle("v", kFloat); + Stmt* store_b = Store::make(b_buf, 0, var, 1); + Stmt* let_store = LetStmt::make(var, load_a, store_b); + SimpleIREvaluator eval(let_store, a_buf, b_buf); + + PaddedBuffer a_v(1); + PaddedBuffer b_v(1); + PaddedBuffer b_ref(1); + + a_v(0) = 23; + b_ref(0) = a_v(0); + eval(a_v, b_v); + + ExpectAllNear(b_v, b_ref, 1e-5); +} + +static ExprHandle test_01(const ExprHandle& expr) { + return expr; +} + +void testExprIntTest() { + KernelScope kernel_scope; + VarHandle x("x", kInt); + ExprHandle value = ExprHandle(3); + ExprHandle body = ExprHandle(2) + (x * ExprHandle(3) + ExprHandle(4)); + ExprHandle result = Let::make(x, ExprHandle(3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprFloatTest() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle value = ExprHandle((float)3); + ExprHandle body = + ExprHandle((float)2) + (x * ExprHandle((float)3) + ExprHandle((float)4)); + ExprHandle result = Let::make(x, ExprHandle((float)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprByteTest() { + KernelScope kernel_scope; + VarHandle x("x", kByte); + ExprHandle value = ExprHandle((uint8_t)3); + ExprHandle body = ExprHandle((uint8_t)2) + + (x * 
ExprHandle((uint8_t)3) + ExprHandle((uint8_t)4)); + ExprHandle result = Let::make(x, ExprHandle((uint8_t)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprCharTest() { + KernelScope kernel_scope; + VarHandle x("x", kChar); + ExprHandle value = ExprHandle((int8_t)3); + ExprHandle body = ExprHandle((int8_t)2) + + (x * ExprHandle((int8_t)3) + ExprHandle((int8_t)4)); + ExprHandle result = Let::make(x, ExprHandle((int8_t)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprShortTest() { + KernelScope kernel_scope; + VarHandle x("x", kShort); + ExprHandle value = ExprHandle((int16_t)3); + ExprHandle body = ExprHandle((int16_t)2) + + (x * ExprHandle((int16_t)3) + ExprHandle((int16_t)4)); + ExprHandle result = Let::make(x, ExprHandle((int16_t)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprLongTest() { + KernelScope kernel_scope; + VarHandle x("x", kLong); + ExprHandle value = ExprHandle((int64_t)3); + ExprHandle body = ExprHandle((int64_t)2) + + (x * ExprHandle((int64_t)3) + ExprHandle((int64_t)4)); + ExprHandle result = Let::make(x, ExprHandle((int64_t)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprHalfTest() { + KernelScope kernel_scope; + VarHandle x("x", kHalf); + ExprHandle value = ExprHandle((at::Half)3); + ExprHandle body = ExprHandle((at::Half)2) + + (x * ExprHandle((at::Half)3) + ExprHandle((at::Half)4)); + ExprHandle result = Let::make(x, ExprHandle((at::Half)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} + +void testExprDoubleTest() { + KernelScope kernel_scope; + VarHandle x("x", kDouble); + ExprHandle value = ExprHandle((double)3); + ExprHandle body = ExprHandle((double)2) + + (x * ExprHandle((double)3) + ExprHandle((double)4)); + ExprHandle result = Let::make(x, ExprHandle((double)3), body); + SimpleIRExprEval eval(result); + EXPECT_EQ(eval.value(), 2 + (3 * 3 + 4)); +} +void testExprVectorAdd01() { + KernelScope kernel_scope; + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b_buf(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c_buf(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + /* + Build the following: + for (int index = 0; index < kVectorCount; index++) { + store(c_buf, ramp(index * 8, 1, 8), + load(a_buf, ramp(index * 8, 1, 8) + + load(b_buf, ramp(index * 8, 1, 8)))) + } + */ + VarHandle index = VarHandle("index", kInt); + ExprHandle load_a = Load::make( + a_buf, + Ramp::make(index * kVectorSize, 1, kVectorSize), + Broadcast::make(1, kVectorSize)); + ExprHandle load_b = Load::make( + b_buf, + Ramp::make(index * kVectorSize, 1, kVectorSize), + Broadcast::make(1, kVectorSize)); + ExprHandle value = load_a + load_b; + Stmt* store_c = Store::make( + c_buf, + Ramp::make(index * kVectorSize, 1, kVectorSize), + value, + Broadcast::make(1, kVectorSize)); + Stmt* stmt = For::make(index, 0, kVectorCount, store_c); + + EXPECT_EQ(load_a.dtype(), Dtype(kFloat, kVectorSize)); + EXPECT_EQ(load_b.dtype(), Dtype(kFloat, kVectorSize)); + EXPECT_EQ(value.dtype(), Dtype(kFloat, kVectorSize)); + + PaddedBuffer a_v(kTotalSize); + PaddedBuffer b_v(kTotalSize); + PaddedBuffer c_v(kTotalSize); + PaddedBuffer c_ref(kTotalSize); + for (int i = 0; i < 
kTotalSize; i++) { + a_v(i) = i * i; + b_v(i) = i * i * 4; + c_ref(i) = a_v(i) + b_v(i); + } + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c_buf); + ir_eval(a_v, b_v, c_v); + ExpectAllNear(c_v, c_ref, 1e-5); +} + +void testExprCompareSelectEQ() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto memcpy_expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kEQ), + mask)); + + SimpleIREvaluator ir_eval(memcpy_expr, a, b, c); + ir_eval(a_buffer, b_buffer, c_buffer); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 1); +} + +void testExprSubstitute01() { + KernelScope kernel_scope; + ExprHandle x = Var::make("x", kFloat); + ExprHandle y = Var::make("y", kFloat); + ExprHandle e = (x - 1.0f) * (x + y + 2.0f); + + ExprHandle z = Var::make("z", kFloat); + ExprHandle e2 = Substitute(&e, {{x, z + 1.0f}}); + ExprHandle e2_ref = ((z + 1.0f) - 1.0f) * ((z + 1.0f) + y + 2.0f); + std::ostringstream oss; + oss << e2; + std::string e2_str = oss.str(); + + oss.str(""); + oss << e2_ref; + std::string e2_ref_str = oss.str(); + ASSERT_EQ(e2_str, e2_ref_str); +} + +void testExprMath01() { + KernelScope kernel_scope; + ExprHandle v = sin(ExprHandle(1.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "sin(1.f)"); + + SimpleIRExprEval eval(v); + float v_ref = std::sin(1.0f); + float res = eval.value(); + ASSERT_NEAR(res, v_ref, 1e-6); +} + +void testExprUnaryMath01() { + KernelScope kernel_scope; + struct TestConfig { + std::function func; + std::function ref_func; + }; + + std::vector test_configs = { + {[](const ExprHandle& v) { return sin(v); }, + [](float v) { return std::sin(v); }}, + {[](const ExprHandle& v) { return sin(v); }, + [](float v) { return std::sin(v); }}, + {[](const ExprHandle& v) { return tan(v); }, + [](float v) { return std::tan(v); }}, + {[](const ExprHandle& v) { return asin(v); }, + [](float v) { return std::asin(v); }}, + {[](const ExprHandle& v) { return acos(v); }, + [](float v) { return std::acos(v); }}, + {[](const ExprHandle& v) { return atan(v); }, + [](float v) { return std::atan(v); }}, + {[](const ExprHandle& v) { return sinh(v); }, + [](float v) { return std::sinh(v); }}, + {[](const ExprHandle& v) { return cosh(v); }, + [](float v) { return std::cosh(v); }}, + {[](const ExprHandle& v) { return tanh(v); }, + [](float v) { return std::tanh(v); }}, + {[](const ExprHandle& v) { return exp(v); }, + [](float v) { return std::exp(v); }}, + {[](const ExprHandle& v) { return fabs(v); }, + [](float v) { return std::fabs(v); }}, + {[](const ExprHandle& v) { return log(v); }, + [](float v) { return std::log(v); }}, + {[](const ExprHandle& v) { return log2(v); }, + [](float v) { return std::log2(v); }}, + {[](const ExprHandle& v) { return log10(v); }, + [](float v) { return std::log10(v); }}, + {[](const ExprHandle& v) { return erf(v); }, + [](float v) { return std::erf(v); }}, + {[](const ExprHandle& v) { return sqrt(v); }, + [](float v) { return std::sqrt(v); }}, + {[](const ExprHandle& v) { return 
rsqrt(v); }, + [](float v) { return 1.0f / std::sqrt(v); }}, + {[](const ExprHandle& v) { return ceil(v); }, + [](float v) { return std::ceil(v); }}, + {[](const ExprHandle& v) { return floor(v); }, + [](float v) { return std::floor(v); }}, + {[](const ExprHandle& v) { return round(v); }, + [](float v) { return std::round(v); }}, + {[](const ExprHandle& v) { return trunc(v); }, + [](float v) { return std::trunc(v); }}, + }; + + for (const TestConfig& test_config : test_configs) { + const float input_v = 0.8765f; + ExprHandle v = test_config.func(ExprHandle(input_v)); + float v_ref = test_config.ref_func(input_v); + SimpleIRExprEval eval(v); + EXPECT_NEAR(eval.value(), v_ref, 1e-6) << "fail: " << v; + } +} + +void testExprBinaryMath01() { + KernelScope kernel_scope; + struct TestConfig { + std::function func; + std::function ref_func; + }; + + std::vector test_configs = { + {[](const ExprHandle& v1, const ExprHandle& v2) { return pow(v1, v2); }, + [](float v1, float v2) { return std::pow(v1, v2); }}, + {[](const ExprHandle& v1, const ExprHandle& v2) { return fmod(v1, v2); }, + [](float v1, float v2) { return std::fmod(v1, v2); }}, + }; + + for (const TestConfig& test_config : test_configs) { + const float v1 = 0.8765f; + float v2 = 1.2345f; + ExprHandle v_expr = test_config.func(ExprHandle(v1), ExprHandle(v2)); + float v_ref = test_config.ref_func(v1, v2); + SimpleIRExprEval eval(v_expr); + EXPECT_NEAR(eval.value(), v_ref, 1e-6) << "fail: " << v_expr; + } +} + +void testExprDynamicShapeAdd() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {n}); + Buffer b(VarHandle("b", kHandle), kFloat, {n}); + Buffer c(VarHandle("c", kHandle), kFloat, {n}); + VarHandle i("i", kInt); + Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1)); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + SimpleIREvaluator(s, a, b, c, n)(aData, bData, cData, size); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +void testCond01() { + KernelScope kernel_scope; + const int N = 16; + PaddedBuffer a_v(N); + Buffer a_buf("a", kFloat, {N}); + VarHandle index = VarHandle("index", kInt); + Stmt* assign_x2 = Store::make(VarHandle(a_buf.data()), index, cast(index) * 2, 1); + Stmt* assign_x3 = Store::make(VarHandle(a_buf.data()), index, cast(index) * 3, 1); + ExprHandle even_cond = CompareSelect::make(Mod::make(index, 2), 0, kEQ); + Stmt* assign = Cond::make(even_cond, assign_x2, assign_x3); + Stmt* for_stmt = For::make(index, 0, N, assign); + SimpleIREvaluator(for_stmt, a_buf)(a_v); + + PaddedBuffer a_ref(N); + for (int i = 0; i < N; i++) { + if (i % 2 == 0) { + a_ref(i) = i * 2; + } else { + a_ref(i) = i * 3; + } + } + ExpectAllNear(a_v, a_ref, 1e-5); +} + +void testIfThenElse01() { + KernelScope kernel_scope; + ExprHandle v = ifThenElse(ExprHandle(1), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(1, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 1.0f); +} + +void testIfThenElse02() { + KernelScope kernel_scope; + ExprHandle v = ifThenElse(ExprHandle(0), ExprHandle(1.0f), ExprHandle(2.0f)); + + std::ostringstream oss; + oss << v; + ASSERT_EQ(oss.str(), "IfThenElse(0, 1.f, 2.f)"); + + SimpleIRExprEval eval(v); + ASSERT_EQ(eval.value(), 2.0f); +} + +} // namespace jit +} // namespace torch diff --git 
a/test/cpp/tensorexpr/test_ir_printer.cpp b/test/cpp/tensorexpr/test_ir_printer.cpp new file mode 100644 index 0000000000000..735e5d3f2d58f --- /dev/null +++ b/test/cpp/tensorexpr/test_ir_printer.cpp @@ -0,0 +1,80 @@ +#include "test/cpp/tensorexpr/test_base.h" +#include + +#include "torch/csrc/jit/tensorexpr/expr.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" + +#include +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; + +void testIRPrinterBasicValueTest() { + KernelScope kernel_scope; + ExprHandle a = IntImm::make(2), b = IntImm::make(3); + ExprHandle c = Add::make(a, b); + + std::stringstream ss; + ss << c; + EXPECT_EQ(ss.str(), "(2 + 3)"); +} + +void testIRPrinterBasicValueTest02() { + KernelScope kernel_scope; + ExprHandle a(2.0f); + ExprHandle b(3.0f); + ExprHandle c(4.0f); + ExprHandle d(5.0f); + ExprHandle f = (a + b) - (c + d); + + std::stringstream ss; + ss << f; + EXPECT_EQ(ss.str(), "((2.f + 3.f) - (4.f + 5.f))"); +} + +void testIRPrinterLetTest01() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + ExprHandle result = Let::make(x, ExprHandle(3.f), body); + + std::stringstream ss; + ss << result; + EXPECT_EQ(ss.str(), "(let x = 3.f in (2.f + ((x * 3.f) + 4.f)))"); +} + +void testIRPrinterLetTest02() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); + ExprHandle e1 = Let::make(x, ExprHandle(3.f), body); + ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1); + + std::stringstream ss; + ss << e2; + EXPECT_EQ( + ss.str(), "(let y = 6.f in (let x = 3.f in (2.f + ((x * 3.f) + (4.f * y)))))"); +} + +void testIRPrinterCastTest() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); + ExprHandle e1 = Let::make(x, Cast::make(kInt, ExprHandle(3.f)), body); + ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1); + + std::stringstream ss; + ss << e2; + EXPECT_EQ( + ss.str(), + "(let y = 6.f in (let x = int(3.f) in (2.f + ((x * 3.f) + (4.f * y)))))"); +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_llvm.cpp b/test/cpp/tensorexpr/test_llvm.cpp new file mode 100644 index 0000000000000..e48ea2934eb2f --- /dev/null +++ b/test/cpp/tensorexpr/test_llvm.cpp @@ -0,0 +1,1054 @@ +#ifdef ENABLE_LLVM +#include "test/cpp/tensorexpr/test_base.h" + +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "test/cpp/tensorexpr/test_utils.h" +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/function.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/llvm_codegen.h" +#include "torch/csrc/jit/tensorexpr/schedule.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +#include + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; +using namespace torch::jit::tensorexpr::schedule; + +using LLVMExprEval = ExprEval; + + +// Typed tests, can't use gtest params here due to the way we instantiate tests. 
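+// The X-macro below pairs each scalar C type with its Dtype name and a sample
+// value; the *_TEST macros then stamp out one testLLVM<Name><Op>Test per type.
+// As a rough sketch (assuming cg.value is templated on the scalar type),
+// IMM_TEST(float, Float, 0.122) expands to approximately:
+//   void testLLVMFloatImmTest() {
+//     KernelScope kernel_scope;
+//     auto a = FloatImm::make(0.122);
+//     LLVMExprEval cg(a);
+//     EXPECT_NEAR(cg.value<float>(), 0.122, 0.1);
+//   }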
+#define TEST_LLVM_SCALAR_TYPES(_) \ + _(uint8_t, Byte, 24) \ + _(int8_t, Char, -20) \ + _(int16_t, Short, 3332) \ + _(int, Int, 123456) \ + _(int64_t, Long, 2631563121321) \ + _(float, Float, 0.122) \ + _(double, Double, 0.21312) \ + _(at::Half, Half, 0.128f) + + +#define IMM_TEST(Type, Name, Val) \ + void testLLVM##Name##ImmTest() { \ + KernelScope kernel_scope; \ + auto a = Name##Imm::make(Val); \ + LLVMExprEval cg(a); \ + if (std::is_floating_point()) { \ + EXPECT_NEAR(cg.value(), Val, 0.1); \ + } else { \ + EXPECT_EQ(cg.value(), Val); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(IMM_TEST) +#undef IMM_TEST + +#define ADD_TEST(Type, Name, Val) \ + void testLLVM##Name##AddTest() { \ + KernelScope kernel_scope; \ + auto a = Name##Imm::make(Val); \ + auto b = Name##Imm::make(Val * 2); \ + auto c = Add::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + EXPECT_NEAR(cg.value(), Val * 3, 0.1); \ + } else { \ + EXPECT_EQ(cg.value(), Val * 3); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(ADD_TEST) +#undef ADD_TEST + +#define SUB_TEST(Type, Name, Val) \ + void testLLVM##Name##SubTest() { \ + KernelScope kernel_scope; \ + auto a = Name##Imm::make(Val * 2); \ + auto b = Name##Imm::make(Val); \ + auto c = Sub::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + EXPECT_NEAR(cg.value(), Val, 0.1); \ + } else { \ + EXPECT_EQ(cg.value(), Val); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(SUB_TEST) +#undef SUB_TEST + +#define MUL_TEST(Type, Name, Val) \ + void testLLVM##Name##MulTest() { \ + KernelScope kernel_scope; \ + auto a = Name##Imm::make(Val); \ + auto b = Name##Imm::make((Type)4); \ + auto c = Mul::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + EXPECT_NEAR(cg.value(), Val * 4, 0.1); \ + } else { \ + EXPECT_EQ(cg.value(), Val * 4); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(MUL_TEST) +#undef MUL_TEST + +#define DIV_TEST(Type, Name, Val) \ + void testLLVM##Name##DivTest() { \ + KernelScope kernel_scope; \ + auto a = Name##Imm::make((Type)6); \ + auto b = Name##Imm::make((Type)3); \ + auto c = Div::make(a, b); \ + LLVMExprEval cg(c); \ + if (std::is_floating_point()) { \ + EXPECT_NEAR(cg.value(), 2, 0.1); \ + } else { \ + EXPECT_EQ(cg.value(), 2); \ + } \ + } +TEST_LLVM_SCALAR_TYPES(DIV_TEST) +#undef DIV_TEST + +void testLLVMIntToFloatCastTest() { + KernelScope kernel_scope; + auto a = IntImm::make(2); + auto b = Cast::make(kFloat, a); + LLVMExprEval cg(b, {}); + EXPECT_EQ(cg.value(), 2.0); +} + +void testLLVMFloatToIntCastTest() { + KernelScope kernel_scope; + auto a = FloatImm::make(2.0); + auto b = Cast::make(kInt, a); + LLVMExprEval cg(b); + EXPECT_EQ(cg.value(), 2); +} + +void testLLVMIntToLongCastTest() { + KernelScope kernel_scope; + auto a = IntImm::make(12345); + auto b = Cast::make(kLong, a); + LLVMExprEval cg(b); + EXPECT_EQ(cg.value(), 12345); +} + +void testLLVMByteToCharCastTest() { + KernelScope kernel_scope; + auto a = ByteImm::make(250); + auto b = Cast::make(kChar, a); + LLVMExprEval cg(b); + EXPECT_EQ(cg.value(), (int8_t)250); +} + +void testLLVMHalfToLongCastTest() { + KernelScope kernel_scope; + auto a = HalfImm::make(2.0); + auto b = Cast::make(kLong, a); + LLVMExprEval cg(b); + EXPECT_EQ(cg.value(), 2); +} + +void testLLVMByteToDoubleCastTest() { + KernelScope kernel_scope; + auto a = ByteImm::make(2); + auto b = Cast::make(kDouble, a); + LLVMExprEval cg(b); + EXPECT_EQ(cg.value(), 2); +} + +void testLLVMLetTest01() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = 
ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)); + ExprHandle result = Let::make(x, ExprHandle(3.f), body); + LLVMExprEval cg(result, {}); + EXPECT_EQ(cg.value(), 2.f + (3.f * 3.f + 4.f)); +} + +void testLLVMLetTest02() { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle value = ExprHandle(3.f); + ExprHandle body = + ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f) * y); + ExprHandle e1 = Let::make(x, ExprHandle(3.f), body); + ExprHandle e2 = Let::make(y, ExprHandle(6.f), e1); + LLVMExprEval cg(e2, {}); + EXPECT_EQ(cg.value(), 2.f + (3.f * 3.f + 4.f * 6.f)); +} + +void testLLVMLetTestMultitype() { + KernelScope kernel_scope; + VarHandle x("x", kByte); + VarHandle y("y", kHalf); + ExprHandle value = ExprHandle((short)3); + ExprHandle body = ExprHandle((double)2.f) + + (x * ExprHandle(3) + ExprHandle((int64_t)4) * y); + ExprHandle e1 = Let::make(x, ExprHandle((uint8_t)3), body); + ExprHandle e2 = Let::make(y, ExprHandle((at::Half)6.f), e1); + LLVMExprEval cg(e2, {}); + EXPECT_EQ(cg.value(), 2.f + (3 * 3 + 4 * 6.f)); +} + +void testLLVMBufferTest() { + KernelScope kernel_scope; + Buffer a(VarHandle("A", kHandle), kFloat, {32}); + std::vector v(5); + std::vector args({v.data()}); + auto rv = IntImm::make(0); + LLVMExprEval cg(rv, {a}); + EXPECT_EQ(cg.value(args), 0); +} + +void testLLVMBlockTest() { + KernelScope kernel_scope; + Buffer a(VarHandle("A", kHandle), kInt, {32}); + std::vector v = {1, 2}; + std::vector args({v.data()}); + + auto block = Block::make({ + Store::make(a, IntImm::make(0), IntImm::make(3), IntImm::make(1)), + Store::make(a, IntImm::make(1), IntImm::make(4), IntImm::make(1)), + Store::make(a, IntImm::make(0), IntImm::make(4), IntImm::make(1)), + }); + + LLVMCodeGen cg(block, {a}); + EXPECT_EQ(cg.value(args), 0); + EXPECT_EQ(v[0], 4); + EXPECT_EQ(v[1], 4); +} + +void testLLVMLoadStoreTest() { + KernelScope kernel_scope; + Buffer a(VarHandle("A", kHandle), kInt, {1}); + Buffer b(VarHandle("B", kHandle), kInt, {1}); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + + auto store = Store::make( + b, + IntImm::make(0), + Load::make(a, IntImm::make(0), IntImm::make(1)), + IntImm::make(1)); + LLVMCodeGen cg(store, {a, b}); + std::vector args({a_buffer.data(), b_buffer.data()}); + EXPECT_EQ(cg.value(args), 0); + EXPECT_EQ(a_buffer[0], 42); + EXPECT_EQ(b_buffer[0], 42); +} + +void testLLVMIfThenElseTest() { + KernelScope kernel_scope; + Buffer a(VarHandle("A", kHandle), kInt, {1}); + Buffer b(VarHandle("B", kHandle), kInt, {1}); + Buffer c(VarHandle("C", kHandle), kInt, {1}); + std::vector a_buffer = {42}; + std::vector b_buffer = {-11}; + std::vector c_buffer = {1}; + + auto store = Store::make( + b, + IntImm::make(0), + IfThenElse::make( + Load::make(c, IntImm::make(0), IntImm::make(1)), // cond + Load::make(a, IntImm::make(0), IntImm::make(1)), // then + IntImm::make(0)), // else + IntImm::make(1)); + LLVMCodeGen cg(store, {a, b, c}); + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + EXPECT_EQ(cg.value(args), 0); + EXPECT_EQ(a_buffer[0], 42); + EXPECT_EQ(b_buffer[0], 42); +} + +void testLLVMVecLoadStoreTest() { + KernelScope kernel_scope; + Buffer a(VarHandle("A", kHandle), kInt, {1}); + Buffer b(VarHandle("B", kHandle), kInt, {1}); + std::vector a_buffer = {1, 1, 1, 1}; + std::vector b_buffer = {2, 2, 2, 2}; + + auto store = Store::make( + b, + Ramp::make(0, 1, 4), + Load::make(a, Ramp::make(0, 1, 4), Broadcast::make(IntImm::make(1), 4)), + Broadcast::make(IntImm::make(1), 
4)); + LLVMCodeGen cg(store, {a, b}); + std::vector args({a_buffer.data(), b_buffer.data()}); + EXPECT_EQ(cg.value(args), 0); + EXPECT_EQ(a_buffer[0], 1); + EXPECT_EQ(a_buffer[1], 1); + EXPECT_EQ(a_buffer[2], 1); + EXPECT_EQ(a_buffer[3], 1); + EXPECT_EQ(b_buffer[0], 1); + EXPECT_EQ(b_buffer[1], 1); + EXPECT_EQ(b_buffer[2], 1); + EXPECT_EQ(b_buffer[3], 1); +} + +void testLLVMMemcpyTest() { + KernelScope kernel_scope; + constexpr int N = 32; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + std::vector a_buffer(N, 42); + std::vector b_buffer(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = + For::make(i, 0, N, Store::make(b, i, Load::make(a, i, mask), mask)); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, 42); + assertAllEqual(b_buffer, 42); +} + +void testLLVMBzeroTest() { + KernelScope kernel_scope; + constexpr int N = 32; + Buffer b(VarHandle("B", kHandle), kInt, {N}); + std::vector b_buffer(N, 11); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make(i, 0, N, Store::make(b, i, IntImm::make(0), mask)); + + LLVMCodeGen cg(expr, {b}); + + std::vector args({b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(b_buffer, 0); +} + +void testLLVMElemwiseAdd() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Add::make(Load::make(a, i, mask), Load::make(b, i, mask)), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 42); +} + +void testLLVMElemwiseAddFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make(c, i, Load::make(a, i, mask) + Load::make(b, i, mask), mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 42.0f); +} + +void testLLVMElemwiseLog10Float() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + std::vector a_buffer(N, 10.0f); + std::vector b_buffer(N, 2.0f); + + auto mask = Broadcast::make(IntImm::make(1), 4); + VarHandle i("i", kInt); + auto expr = 
For::make( + i, + 0, + N / 4, + Store::make( + b, + Ramp::make(i * 4, 1, 4), + log10(Load::make(a, Ramp::make(i * 4, 1, 4), mask)), + mask)); + + LLVMCodeGen cg(expr, {a, b}); + + std::vector args({a_buffer.data(), b_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + assertAllEqual(a_buffer, 10.0f); + assertAllEqual(b_buffer, 1.0f); +} + +void testLLVMElemwiseMaxInt() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 41); +} + +void testLLVMElemwiseMinInt() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41); + assertAllEqual(b_buffer, 1); + assertAllEqual(c_buffer, 1); +} + +void testLLVMElemwiseMaxNumFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Max::make(Load::make(a, i, mask), Load::make(b, i, mask), false), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 41.0f); +} + +void testLLVMElemwiseMaxNumNaNFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Max::make(Load::make(a, i, mask), Load::make(b, i, mask), 
false), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1.0f); +} + +void testLLVMElemwiseMinNumFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1.0f); +} + +void testLLVMElemwiseMinNumNaNFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Min::make(Load::make(a, i, mask), Load::make(b, i, mask), false), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1.0f); +} + +#if 1 // LLVM doesn't currently have implementations for maximum/minimum on x86 +void testLLVMElemwiseMaximumFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Max::make(Load::make(a, i, mask), Load::make(b, i, mask), true), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 41.0f); +} + +void testLLVMElemwiseMaximumNaNFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Max::make(Load::make(a, i, mask), 
Load::make(b, i, mask), true), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + for (int i = 0; i < N; ++i) { + ASSERT_TRUE(std::isnan(a_buffer[i])); + ASSERT_TRUE(std::isnan(c_buffer[i])); + } +} + +void testLLVMElemwiseMinimumFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, 41); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + assertAllEqual(a_buffer, 41.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1.0f); +} + +void testLLVMElemwiseMinimumNaNFloat() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kFloat, {N}); + std::vector a_buffer(N, NAN); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 1); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + Min::make(Load::make(a, i, mask), Load::make(b, i, mask), true), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + for (int i = 0; i < N; ++i) { + ASSERT_TRUE(std::isnan(a_buffer[i])); + ASSERT_TRUE(std::isnan(c_buffer[i])); + } +} +#endif + +void testLLVMCompareSelectIntEQ() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kInt, {N}); + Buffer b(VarHandle("B", kHandle), kInt, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 1); + std::vector b_buffer(N, 1); + std::vector c_buffer(N, 0); + std::vector c_ref(N, 1); + + for (int i = 0; i < N / 2; i++) { + b_buffer[i] = 0; + c_ref[i] = 0; + } + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kEQ), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1); + for (int i = 0; i < N; i++) { + ASSERT_EQ(c_ref[i], c_buffer[i]); + } +} + +void testLLVMCompareSelectFloatEQ() { + KernelScope kernel_scope; + constexpr int N = 1024; + Buffer a(VarHandle("A", kHandle), kFloat, {N}); + Buffer b(VarHandle("B", kHandle), kFloat, {N}); + Buffer c(VarHandle("C", kHandle), kInt, {N}); + std::vector a_buffer(N, 1.0f); + std::vector b_buffer(N, 1.0f); + std::vector 
c_buffer(N, 0); + + auto mask = IntImm::make(1); + VarHandle i("i", kInt); + auto expr = For::make( + i, + 0, + N, + Store::make( + c, + i, + CompareSelect::make( + Load::make(a, i, mask), + Load::make(b, i, mask), + CompareSelectOperation::kEQ), + mask)); + + LLVMCodeGen cg(expr, {a, b, c}); + + std::vector args({a_buffer.data(), b_buffer.data(), c_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + + ASSERT_EQ(a_buffer.size(), N); + ASSERT_EQ(b_buffer.size(), N); + ASSERT_EQ(c_buffer.size(), N); + + assertAllEqual(a_buffer, 1.0f); + assertAllEqual(b_buffer, 1.0f); + assertAllEqual(c_buffer, 1); +} + +void testLLVMStoreFloat() { + KernelScope kernel_scope; + Buffer result(VarHandle("result", kHandle), kFloat, {1}); + std::vector result_buffer = {0.0f}; + auto expr = Store::make( + result, IntImm::make(0), FloatImm::make(3.14f), IntImm::make(1)); + LLVMCodeGen cg(expr, {result}); + std::vector args({result_buffer.data()}); + ASSERT_EQ(cg.value(args), 0); + EXPECT_EQ(result_buffer[0], 3.14f); +} + +void testLLVMSimpleMath01() { + KernelScope kernel_scope; + const int N = 1024; + Tensor* tensor = Compute( + "f", {{N, "i"}}, [](const VarHandle& i) { return cast(i * i + 1); }); + Schedule sch = Schedule::make({tensor}); + Stmt* stmt = sch.Lower(); + Buffer f_buf(VarHandle(tensor->func_var()), kFloat, {N}); + LLVMCodeGen cg(stmt, {f_buf}); + + PaddedBuffer f_v(N, "f_v"); + std::vector args({f_v.data()}); + int value = cg.value(args); + ASSERT_EQ(value, 0); + PaddedBuffer f_ref(N, "f_ref"); + for (int i = 0; i < N; i++) { + f_ref(i) = i * i + 1; + } + ExpectAllNear(f_v, f_ref, 1e-5); +} + +void testLLVMComputeMul() { + KernelScope kernel_scope; + const int N = 1024; + Buffer a(VarHandle("a", kHandle), kFloat, {N}); + Buffer b(VarHandle("b", kHandle), kFloat, {N}); + Tensor* c = Compute("c", {{N, "i"}}, [&](const VarHandle& i) { + return Load::make(a, i, 1) * Load::make(b, i, 1); + }); + + Buffer c_buf(VarHandle(c->func_var()), kFloat, {N}); + Schedule sch = Schedule::make({c}); + Stmt* s = sch.Lower(); + + LLVMCodeGen cg(s, {a, b, c_buf}); + + std::vector a_vec(N, 21.0f); + std::vector b_vec(N, 2.0f); + std::vector c_vec(N, 0.0f); + std::vector args({a_vec.data(), b_vec.data(), c_vec.data()}); + ASSERT_EQ(cg.value(args), 0); + assertAllEqual(c_vec, 42.0f); +} + +void testLLVMBroadcastAdd() { + KernelScope kernel_scope; + const int M = 32; + const int N = 1024; + Buffer a(VarHandle("a", kHandle), kFloat, {M, N}); + Buffer b(VarHandle("b", kHandle), kFloat, {N}); + Tensor* c = + Compute("c", {{M, "i"}, {N, "j"}}, [&](const VarHandle& i, const VarHandle& j) { + ExprHandle mask(1); + return Load::make(a, i * N + j, mask) + Load::make(b, j, mask); + }); + + Buffer c_buf(VarHandle(c->func_var()), kFloat, {M, N}); + Schedule sch = Schedule::make({c}); + Stmt* s = sch.Lower(); + + LLVMCodeGen cg(s, {a, b, c_buf}); + + std::vector av(M * N); + std::iota(av.begin(), av.end(), 0); + std::vector bv(N); + std::iota(bv.begin(), bv.end(), 0); + std::vector cv(M * N, 0); + std::vector args({av.data(), bv.data(), cv.data()}); + ASSERT_EQ(cg.value(args), 0); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + ASSERT_EQ(cv[i * N + j], av[i * N + j] + bv[j]); + } + } +} + +void testLLVMDynamicShapeAdd() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {n}); + Buffer b(VarHandle("b", kHandle), kFloat, {n}); + Buffer c(VarHandle("c", kHandle), kFloat, {n}); + VarHandle i("i", kInt); + Stmt* s = For::make(i, 0, 
n, Store::make(c, i, a(i) + b(i), 1)); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + LLVMCodeGen cg(s, {a, b, c, n}); + std::vector args({aData.data(), bData.data(), cData.data(), &size}); + cg.value(args); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +void testLLVMBindDynamicShapeAdd() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {n}); + Buffer b(VarHandle("b", kHandle), kFloat, {n}); + Buffer c(VarHandle("c", kHandle), kFloat, {n}); + VarHandle i("i", kInt); + Stmt* s = For::make(i, 0, n, Store::make(c, i, a(i) + b(i), 1)); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + LLVMCodeGen cg(s, {a, b, c, n}); + cg.call({aData, bData, cData, size}); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +void testLLVMTensorDynamicShapeAdd() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t size) { + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {n}); + Buffer b(VarHandle("b", kHandle), kFloat, {n}); + Tensor* c = + Compute("c", {{n, "n"}}, [&](const VarHandle& i) { return a(i) + b(i); }); + Schedule sch = Schedule::make({c}); + Stmt* s = sch.Lower(); + LLVMCodeGen cg(s, {a, b, c, n}); + std::vector aData(size, 1.0f); + std::vector bData(size, 2.0f); + std::vector cData(size, 0.0f); + cg.call({aData, bData, cData, size}); + ExpectAllNear(cData, std::vector(size, 3.0f), 1e-7); + }; + testWithSize(1); + testWithSize(16); + testWithSize(37); +} + +void testLLVMDynamicShape2D() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {m, n}); + Buffer b(VarHandle("b", kHandle), kFloat, {m, n}); + Tensor* c = + Compute("c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { + return a(i, j) + b(i, j); + }); + auto sch = torch::jit::tensorexpr::schedule::Schedule::make({c}); + Stmt* s = sch.Lower(); + LLVMCodeGen cg(s, {a, b, c, m, n}); + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + cg.call({aData, bData, cData, M, N}); + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + }; + testWithSize(1, 8); + testWithSize(16, 32); + testWithSize(37, 11); +} + +} // namespace jit +} // namespace torch + +#endif // ENABLE_LLVM diff --git a/test/cpp/tensorexpr/test_schedule.cpp b/test/cpp/tensorexpr/test_schedule.cpp new file mode 100644 index 0000000000000..1cb1136fabf08 --- /dev/null +++ b/test/cpp/tensorexpr/test_schedule.cpp @@ -0,0 +1,549 @@ +#include +#include +#include +#include +#include "test/cpp/tensorexpr/test_base.h" + +#include "test/cpp/tensorexpr/padded_buffer.h" +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/function.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/schedule.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { + +using namespace torch::jit::tensorexpr; +using namespace torch::jit::tensorexpr::schedule; + +void testExprSimple01() { + KernelScope kernel_scope; + Tensor* tensor = + Compute("f", {{16, "X"}, {5, 
"y"}}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + VarHandle x(tensor->function()->arg(0)); + VarHandle y(tensor->function()->arg(1)); + Schedule sch = Schedule::make({tensor}); + VarHandle x_outer; + VarHandle x_inner; + VarHandle x_tail; + TensorOperation* tail_op; + tensor->SplitWithTail(x, 2, true, &x_outer, &x_inner, &x_tail, &tail_op); + + VarHandle x_2; + VarHandle x_1; + VarHandle x_tail_2; + TensorOperation* tail_op_2; + tensor->SplitWithTail(x_outer, 2, true, &x_2, &x_1, &x_tail_2, &tail_op_2); +} + +void testExprLower01() { + KernelScope kernel_scope; + Tensor* tensor = + Compute("f", {{16, "x"}, {5, "y"}}, [](const VarHandle& x, const VarHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }); + VarHandle x(tensor->function()->arg(0)); + VarHandle y(tensor->function()->arg(1)); + Schedule sch = Schedule::make({tensor}); + Stmt* stmt = sch.Lower(); + std::ostringstream oss; + oss << stmt; + ASSERT_GT(oss.str().size(), 20); + ASSERT_LT(oss.str().size(), 200); +} + +void testExprSimple02() { + KernelScope kernel_scope; + auto func = [](const ExprHandle& x, const ExprHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }; + Tensor* tensor = Compute("f", {{26, "x"}, {5, "y"}}, func); + VarHandle x(tensor->function()->arg(0)); + VarHandle y(tensor->function()->arg(1)); + Schedule sch = Schedule::make({tensor}); + VarHandle x_outer; + VarHandle x_inner; + VarHandle x_tail; + TensorOperation* tail_op; + tensor->SplitWithTail(x, 4, true, &x_outer, &x_inner, &x_tail, &tail_op); + + Stmt* stmt = sch.Lower(); + std::ostringstream oss; + oss << *stmt; +// ASSERT_GT(oss.str().size(), 200); +// ASSERT_LT(oss.str().size(), 600); + + { + // Compare to a reference loop structure structure. 
+ VarHandle x_outer("x_outer", kInt); + VarHandle x_inner("x_inner", kInt); + VarHandle y("y", kInt); + VarHandle x_tail("x_tail", kInt); + VarHandle f("f", kHandle); + ExprHandle x_1 = x_outer * 4 + x_inner; + ExprHandle x_outer_end = (ExprHandle(26) - 0) / 4; + Stmt* stmt1 = For::make( + x_outer, + 0, + x_outer_end, + For::make( + x_inner, + 0, + 4, + For::make( + y, 0, 5, Store::make(f, x_1 * 5 + y * 1, func(x_1, y), 1)))); + ExprHandle x_2 = x_tail + x_outer_end * 4; + Stmt* stmt2 = For::make( + x_tail, + 0, + (ExprHandle(26) - 0) % 4, + For::make(y, 0, 5, Store::make(f, x_2 * 5 + y * 1, func(x_2, y), 1))); + Stmt* stmt = Block::make({stmt1, stmt2}); + + std::ostringstream oss_ref; + oss_ref << *stmt; + ASSERT_EQ(oss.str(), oss_ref.str()); + } + + { + PaddedBuffer f_v(26, 5, "f_v"); + PaddedBuffer f_ref(26, 5, "f_res"); + + SimpleIREvaluator ir_eval(stmt, tensor); + ir_eval(f_v); + + for (int x = 0; x < 26; x++) { + for (int y = 0; y < 5; y++) { + f_ref(x, y) = 1 + x * x + y * y; + } + } + + ExpectAllNear(f_v, f_ref, 1e-5); + } +} + +void testExprSplitWithTailNone() { + KernelScope kernel_scope; + auto func = [](const ExprHandle& x, const ExprHandle& y) { + return ExprHandle(1.0f) + cast(x) * x + cast(y) * y; + }; + Tensor* tensor = Compute("f", {{24, "x"}, {5, "y"}}, func); + VarHandle x = VarHandle(tensor->function()->arg(0)); + VarHandle y = VarHandle(tensor->function()->arg(1)); + Schedule sch = Schedule::make({tensor}); + VarHandle x_outer; + VarHandle x_inner; + VarHandle x_tail; + TensorOperation* tail_op; + tensor->SplitWithTail(x, 4, true, &x_outer, &x_inner, &x_tail, &tail_op); + + Stmt* stmt = sch.Lower(); + std::ostringstream oss; + oss << stmt; + ASSERT_GT(oss.str().size(), 200); + ASSERT_LT(oss.str().size(), 600); + + { + // Compare to a reference loop structure structure. 
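+    // Here the extent (24) is an exact multiple of the split factor (4), so
+    // the tail loop has zero iterations and only the main x_outer/x_inner
+    // nest is expected in the lowered statement.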
+ VarHandle x_outer("x_outer", kInt); + VarHandle x_inner("x_inner", kInt); + VarHandle y("y", kInt); + VarHandle x_tail("x_tail", kInt); + VarHandle f("f", kHandle); + ExprHandle x_1 = x_outer * 4 + x_inner; + ExprHandle x_outer_end = (ExprHandle(24) - 0) / 4; + Stmt* stmt = For::make( + x_outer, + 0, + x_outer_end, + For::make( + x_inner, + 0, + 4, + For::make( + y, 0, 5, Store::make(f, x_1 * 5 + y * 1, func(x_1, y), 1)))); + //Stmt stmt = Block::make({stmt1, stmt2}); + + std::ostringstream oss_ref; + oss_ref << stmt; + ASSERT_EQ(oss.str(), oss_ref.str()); + } + + { + PaddedBuffer f_v(24, 5, "f_v"); + PaddedBuffer f_ref(24, 5, "f_res"); + + SimpleIREvaluator ir_eval(stmt, tensor); + ir_eval(f_v); + + for (int x = 0; x < 24; x++) { + for (int y = 0; y < 5; y++) { + f_ref(x, y) = 1 + x * x + y * y; + } + } + + ExpectAllNear(f_v, f_ref, 1e-5); + } +} + +void testExprSplitWithMask01() { + KernelScope kernel_scope; + const int M = 26; + const int N = 5; + Buffer a_buf("a", kFloat, {M, N}); + Buffer b_buf("b", kFloat, {M, N}); + Tensor* tensor = + Compute("f", {{M, "m"}, {N, "n"}}, [&](const ExprHandle& m, const ExprHandle& n) { + return a_buf(m, n) + b_buf(m, n) + 1.0f; + }); + VarHandle m(tensor->function()->arg(0)); + VarHandle n(tensor->function()->arg(1)); + VarHandle n_outer; + VarHandle n_inner; + + Schedule sch({tensor}); + tensor->SplitWithMask(n, 4, true, &n_outer, &n_inner); + + Stmt* stmt = sch.Lower(); + + PaddedBuffer a_v(M, N, "a"); + PaddedBuffer b_v(M, N, "b"); + PaddedBuffer c_v(M, N, "c"); + PaddedBuffer c_ref(M, N, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 2 * m; + b_v(m, n) = 3 * n; + c_ref(m, n) = a_v(m, n) + b_v(m, n) + 1.0f; + } + } + + SimpleIREvaluator(stmt, a_buf, b_buf, tensor)(a_v, b_v, c_v); + + ExpectAllNear(c_v, c_ref, 1e-5); +} + +void testScheduleBroadcastAddBuffer() { + KernelScope kernel_scope; + const int M = 4; + const int N = 5; + const int K = 6; + Buffer a_buf("a", kFloat, {M, N}); + Buffer b_buf("b", kFloat, {N, K}); + Tensor* c = Compute( + "broadcast_add", + {{M, "m"}, {N, "n"}, {K, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf(m, n) + b_buf(n, k); + }); + Schedule sch({c}); + Stmt* stmt = sch.Lower(); + + PaddedBuffer a_v(M, N, "a_v"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + a_v(m, n) = 7 * m * n; + } + } + a_v.Backup(); + + PaddedBuffer b_v(N, K, "b_v"); + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + b_v(n, k) = 11 * n * k; + } + } + b_v.Backup(); + + PaddedBuffer c_v(M, N, K, "c_buf"); + SimpleIREvaluator ir_eval(stmt, a_buf, b_buf, c); + ir_eval(a_v, b_v, c_v); + + a_v.CheckBackup(); + b_v.CheckBackup(); + PaddedBuffer c_ref(M, N, K, "c_ref"); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + c_ref(m, n, k) = 7 * m * n + 11 * n * k; + } + } + } + ExpectAllNear(c_v, c_ref, 1e-5); +} + +void testScheduleFunctionCall01() { + KernelScope kernel_scope; + const int M = 4; + const int N = 5; + const int K = 6; + Buffer a_buf("a", kFloat, {M, N}); + Buffer b_buf("b", kFloat, {N, K}); + Tensor* c = Compute( + "broadcast_add", + {{M, "m"}, {N, "n"}, {K, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf(m, n) + b_buf(n, k); + }); + Tensor* d = Compute( + "d", + {{M, "m"}, {N, "n"}, {K, "k"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { return c->call(m, n, k) + 1; }); + + Schedule sch({d}); + Stmt* stmt = 
sch.Lower(); + std::ostringstream oss; + oss << stmt; + ASSERT_GT(oss.str().size(), 100); + + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N, K); + PaddedBuffer d_v(M, N, K); + PaddedBuffer d_ref(M, N, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + b_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + for (int k = 0; k < K; k++) { + d_ref(i, j, k) = a_v(i, j) + b_v(j, k) + 1; + } + } + } + + SimpleIREvaluator eval(stmt, a_buf, b_buf, d); + eval(a_v, b_v, d_v); + + ExpectAllNear(d_v, d_ref, 1e-5); +} + +static std::string remove_space(const std::string& str) { + std::string str_new = str; + str_new.erase( + remove_if(str_new.begin(), str_new.end(), isspace), str_new.end()); + return str_new; +} + +void InlineFunc01Helper(const std::vector& inline_order) { + KernelScope kernel_scope; + const int M = 4; + const int N = 5; + const int K = 6; + Buffer a_buf("a", kFloat, {M, N}); + Buffer b_buf("b", kFloat, {N, K}); + Buffer c_buf("c", kFloat, {M, N}); + Buffer d_buf("d", kFloat, {M, K}); + + Tensor* x = Compute( + "x", + {{M, "m1"}, {N, "n1"}, {K, "k1"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf(m, n) * b_buf(n, k); + }); + Tensor* y = Compute( + "y", + {{M, "m2"}, {N, "n2"}, {K, "k2"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return c_buf(m, n) * d_buf(m, k) + x->call(m, n, k); + }); + Tensor* z = Compute( + "z", + {{M, "m3"}, {N, "n3"}, {K, "k3"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return x->call(m, n, k) + y->call(m, n, k); + }); + + Schedule sch({z}); + for (const std::string& order : inline_order) { + if (order == "x") { + x->ComputeInline(); + } else if (order == "y") { + y->ComputeInline(); + } else { + throw std::runtime_error("Invalid order: " + order); + } + } + Stmt* stmt = sch.Lower(); + + std::ostringstream oss; + oss << stmt; + std::string str1 = remove_space(oss.str()); + + { + PaddedBuffer a_v(M, N); + PaddedBuffer b_v(N, K); + PaddedBuffer c_v(M, N); + PaddedBuffer d_v(M, K); + + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + a_v(i, j) = i * i; + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < K; j++) { + a_v(i, j) = j * j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < N; j++) { + c_v(i, j) = i + j; + } + } + for (int i = 0; i < M; i++) { + for (int j = 0; j < K; j++) { + d_v(i, j) = i * j; + } + } + + PaddedBuffer z_v(M, N, K); + PaddedBuffer z_ref(M, N, K); + for (int m = 0; m < M; m++) { + for (int n = 0; n < N; n++) { + for (int k = 0; k < K; k++) { + z_ref(m, n, k) = a_v(m, n) * b_v(n, k) * 2 + c_v(m, n) * d_v(m, k); + } + } + } + + SimpleIREvaluator eval(stmt, a_buf, b_buf, c_buf, d_buf, z); + eval(a_v, b_v, c_v, d_v, z_v); + ExpectAllNear(z_v, z_ref, 1e-5); + } + + if (inline_order.size() == 2) { + Tensor* z2 = Compute( + "z", + {{M, "m3"}, {N, "n3"}, {K, "k3"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf(m, n) * b_buf(n, k) + + (c_buf(m, n) * d_buf(m, k) + a_buf(m, n) * b_buf(n, k)); + }); + Schedule sch2({z2}); + Stmt* stmt2 = sch2.Lower(); + + std::ostringstream oss2; + oss2 << stmt2; + std::string str2 = remove_space(oss2.str()); + + ASSERT_EQ(str1, str2); + ASSERT_GT(str1.size(), 100); + } +} + +void testScheduleInlineFunc01() { + InlineFunc01Helper({"x", "y"}); + InlineFunc01Helper({"y", "x"}); 
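+  // Partial inlining (only "x" or only "y") and no inlining at all should
+  // still lower and evaluate to the same result; the textual comparison with
+  // the hand-inlined tensor only runs when both producers are inlined.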
+ InlineFunc01Helper({"x"}); + InlineFunc01Helper({"y"}); + InlineFunc01Helper({}); +} + +void testScheduleFuserStyle() { + KernelScope kernel_scope; + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + Buffer a_buf(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + Tensor* b = + Compute("f", {{kTotalSize, "i"}}, [&](const std::vector& axes) { + return a_buf(axes[0]) + 11.0f; + }); + + Tensor* c = + Compute("g", {{kTotalSize, "i"}}, [&](const std::vector& axes) { + return b->call(axes[0]) + 1.0f; + }); + + Schedule sch({b, c}); + Stmt* s = sch.Lower(); + + std::vector a_data(kTotalSize, 7.0f); + std::vector b_data(kTotalSize, 0.0f); + std::vector c_data(kTotalSize, 0.0f); + SimpleIREvaluator(s, a_buf, b, c)(a_data, b_data, c_data); + + for (int i = 0; i < kTotalSize; i++) { + ASSERT_EQ(b_data[i], 18.0f); + ASSERT_EQ(c_data[i], 19.0f); + } +} + +void testScheduleFuserThreeArg() { + KernelScope kernel_scope; + const int kVectorSize = 8; + const int kVectorCount = 128; + const int kTotalSize = kVectorSize * kVectorCount; + + Buffer a(VarHandle("A", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer b(VarHandle("B", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer c(VarHandle("C", kHandle), kFloat, {ExprHandle(kTotalSize)}); + Buffer d(VarHandle("D", kHandle), kFloat, {ExprHandle(kTotalSize)}); + + Tensor* e = Compute( + "e", {{kTotalSize, "i"}}, [&](const VarHandle& i) { return a(i) + b(i); }); + Tensor* f = Compute( + "f", {{kTotalSize, "i"}}, [&](const VarHandle& i) { return (*e)(i) + c(i); }); + Tensor* g = Compute( + "g", {{kTotalSize, "i"}}, [&](const VarHandle& i) { return (*f)(i) + d(i); }); + + Schedule sch({g}); + e->ComputeInline(); + f->ComputeInline(); + Stmt* s = sch.Lower(); + + std::vector a_data(kTotalSize, 1.0f); + std::vector b_data(kTotalSize, 2.0f); + std::vector c_data(kTotalSize, 3.0f); + std::vector d_data(kTotalSize, 4.0f); + std::vector g_data(kTotalSize, 0.0f); + SimpleIREvaluator(s, a, b, c, d, g)(a_data, b_data, c_data, d_data, g_data); + + for (int i = 0; i < kTotalSize; i++) { + ASSERT_EQ(g_data[i], 10.0f); + } +} + +void testScheduleDynamicShape2D() { + KernelScope kernel_scope; + auto testWithSize = [](int32_t M, int32_t N) { + VarHandle m("m", kInt); + VarHandle n("n", kInt); + Buffer a(VarHandle("a", kHandle), kFloat, {m, n}); + Buffer b(VarHandle("b", kHandle), kFloat, {m, n}); + Tensor* c = + Compute("c", {{m, "m"}, {n, "n"}}, [&](const VarHandle& i, const VarHandle& j) { + return a(i, j) + b(i, j); + }); + auto sch = Schedule::make({c}); + Stmt* s = sch.Lower(); + SimpleIREvaluator cg(s, {a, b, c, m, n}); + std::vector aData(M * N, 1.0f); + std::vector bData(M * N, 2.0f); + std::vector cData(M * N, 0.0f); + cg.call({aData, bData, cData, M, N}); + ExpectAllNear(cData, std::vector(M * N, 3.0f), 1e-7); + }; + testWithSize(1, 8); + testWithSize(16, 32); + testWithSize(37, 11); +} + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_type.cpp b/test/cpp/tensorexpr/test_type.cpp new file mode 100644 index 0000000000000..e3f892bc2211b --- /dev/null +++ b/test/cpp/tensorexpr/test_type.cpp @@ -0,0 +1,124 @@ +#include "test/cpp/tensorexpr/test_base.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +void testTypeTest01() { + KernelScope kernel_scope; + { + Dtype dt1 = kInt; + EXPECT_EQ(dt1, kInt); + } + { + Dtype dt2_a(kInt, 8); + 
Dtype dt2_b(kInt, 4); + Dtype dt2_c(ScalarType::Int, 8); + EXPECT_EQ(dt2_a, dt2_c); + EXPECT_NE(dt2_a, dt2_b); + } + { + EXPECT_EQ(kInt, ToDtype()); + EXPECT_EQ(kFloat, ToDtype()); + EXPECT_EQ(kByte, ToDtype()); + EXPECT_EQ(kChar, ToDtype()); + EXPECT_EQ(kShort, ToDtype()); + EXPECT_EQ(kLong, ToDtype()); + EXPECT_EQ(kHalf, ToDtype()); + EXPECT_EQ(kDouble, ToDtype()); + EXPECT_EQ(kBool, ToDtype()); + } + { + Dtype int32x8(kInt, 8); + Dtype float32x8(kFloat, 8); + EXPECT_NE(int32x8, float32x8); + EXPECT_EQ(float32x8, BinaryOpDtype(int32x8, float32x8)); + EXPECT_EQ(float32x8, BinaryOpDtype(float32x8, int32x8)); + EXPECT_EQ(int32x8, BinaryOpDtype(int32x8, int32x8)); + EXPECT_EQ(float32x8, BinaryOpDtype(float32x8, float32x8)); + } +} + +void testTypePropagation() { + // Same types: + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + VarHandle y("y", kFloat); + ExprHandle body = FloatImm::make(2.f) + + (x * FloatImm::make(3.f) + FloatImm::make(4.f) * y); + ExprHandle e1 = Let::make(x, FloatImm::make(3.f), body); + ExprHandle e2 = Let::make(y, FloatImm::make(6.f), e1); + EXPECT_EQ(e2.dtype(), kFloat); + } + // Int to bigger int: + { + KernelScope kernel_scope; + VarHandle x("x", kShort); + VarHandle y("y", kLong); + ExprHandle body = + ShortImm::make(2.f) + (x * ShortImm::make(3) + ShortImm::make(4) * y); + ExprHandle e1 = Let::make(x, ShortImm::make(3), body); + ExprHandle e2 = Let::make(y, LongImm::make(6), e1); + EXPECT_EQ(e2.dtype(), kLong); + } + // Float to bigger float: + { + KernelScope kernel_scope; + VarHandle x("x", kHalf); + VarHandle y("y", kDouble); + ExprHandle body = + HalfImm::make(2.f) + (x * HalfImm::make(3) + HalfImm::make(4) * y); + ExprHandle e1 = Let::make(x, HalfImm::make(3), body); + ExprHandle e2 = Let::make(y, DoubleImm::make(6), e1); + EXPECT_EQ(e2.dtype(), kDouble); + } + // Int to Float: + { + KernelScope kernel_scope; + VarHandle x("x", kFloat); + VarHandle y("y", kInt); + ExprHandle body = + IntImm::make(2) + (x * IntImm::make(3) + IntImm::make(4) * y); + ExprHandle e1 = Let::make(x, FloatImm::make(3.f), body); + ExprHandle e2 = Let::make(y, IntImm::make(6), e1); + EXPECT_EQ(e2.dtype(), kFloat); + } + // Smaller float, bigger Int: + { + KernelScope kernel_scope; + VarHandle x("x", kHalf); + VarHandle y("y", kLong); + ExprHandle body = + HalfImm::make(2) + (x * HalfImm::make(3) + HalfImm::make(4) * y); + ExprHandle e1 = Let::make(x, HalfImm::make(3), body); + ExprHandle e2 = Let::make(y, LongImm::make(6), e1); + EXPECT_EQ(e2.dtype(), kHalf); + } + // Bigger float, smaller Int: + { + KernelScope kernel_scope; + VarHandle x("x", kChar); + VarHandle y("y", kDouble); + ExprHandle body = + CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); + ExprHandle e1 = Let::make(x, CharImm::make(3), body); + ExprHandle e2 = Let::make(y, DoubleImm::make(6), e1); + EXPECT_EQ(e2.dtype(), kDouble); + } + // Sign change char/byte upgrades to short: + { + KernelScope kernel_scope; + VarHandle x("x", kChar); + VarHandle y("y", kByte); + ExprHandle body = + CharImm::make(2) + (x * CharImm::make(3) + CharImm::make(4) * y); + ExprHandle e1 = Let::make(x, CharImm::make(3), body); + ExprHandle e2 = Let::make(y, ByteImm::make(6), e1); + EXPECT_EQ(e2.dtype(), kShort); + } +} +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/test_utils.h b/test/cpp/tensorexpr/test_utils.h new file mode 100644 index 0000000000000..1468f03b478b0 --- /dev/null +++ b/test/cpp/tensorexpr/test_utils.h @@ -0,0 +1,14 @@ +#pragma once + +#include +#include + +#include 
"test/cpp/tensorexpr/test_base.h" +#include "torch/csrc/jit/testing/file_check.h" + +namespace torch { +namespace jit { +using namespace torch::jit::tensorexpr; + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/tests.h b/test/cpp/tensorexpr/tests.h new file mode 100644 index 0000000000000..ef75bc4ae2404 --- /dev/null +++ b/test/cpp/tensorexpr/tests.h @@ -0,0 +1,183 @@ +#pragma once + +/** + * See README.md for instructions on how to add a new test. + */ +#include +#include + +namespace torch { +namespace jit { +#define TH_FORALL_TESTS(_) \ + _(ExprBasicValueTest) \ + _(ExprBasicValueTest02) \ + _(ExprLetTest01) \ + _(ExprLetStmtTest01) \ + _(ExprLetTest02) \ + _(ExprIntTest) \ + _(ExprFloatTest) \ + _(ExprByteTest) \ + _(ExprCharTest) \ + _(ExprShortTest) \ + _(ExprLongTest) \ + _(ExprHalfTest) \ + _(ExprDoubleTest) \ + _(ExprVectorAdd01) \ + _(ExprCompareSelectEQ) \ + _(ExprSubstitute01) \ + _(ExprMath01) \ + _(ExprUnaryMath01) \ + _(ExprBinaryMath01) \ + _(ExprDynamicShapeAdd) \ + _(IRPrinterBasicValueTest) \ + _(IRPrinterBasicValueTest02) \ + _(IRPrinterLetTest01) \ + _(IRPrinterLetTest02) \ + _(IRPrinterCastTest) \ + _(ExprSimple01) \ + _(ExprLower01) \ + _(ExprSimple02) \ + _(ExprSplitWithTailNone) \ + _(ExprSplitWithMask01) \ + _(ScheduleBroadcastAddBuffer) \ + _(ScheduleFunctionCall01) \ + _(ScheduleInlineFunc01) \ + _(ScheduleFuserStyle) \ + _(ScheduleFuserThreeArg) \ + _(ScheduleDynamicShape2D) \ + _(TypeTest01) \ + _(TypePropagation) \ + _(Cond01) \ + _(IfThenElse01) \ + _(IfThenElse02) \ + _(ATen_cast_Float) \ + _(ATennegInt) \ + _(ATennegFloat) \ + _(ATenaddInt) \ + _(ATenaddFloat) \ + _(ATensubInt) \ + _(ATensubFloat) \ + _(ATenlerp) \ + _(ATenaddcmulInt) \ + _(ATenaddcmulFloat) \ + _(ATenmulInt) \ + _(ATenmulFloat) \ + _(ATendivInt) \ + _(ATendivFloat) \ + _(ATenmaxInt) \ + _(ATenmaxFloat) \ + _(ATenminInt) \ + _(ATenminFloat) \ + _(ATen_sigmoid_backward) \ + _(ATen_tanh_backward) \ + _(ATenreciprocal) \ + _(ATenreluInt) \ + _(ATenreluFloat) \ + _(ATenlogFloat) \ + _(ATenlog10Float) \ + _(ATenlog2Float) \ + _(ATenexpFloat) \ + _(ATenerfFloat) \ + _(ATencosFloat) \ + _(ATeneqInt) \ + _(ATengeInt) \ + _(ATengtInt) \ + _(ATenleInt) \ + _(ATenltInt) + +#define TH_FORALL_TESTS_LLVM(_) \ + _(LLVMByteImmTest) \ + _(LLVMCharImmTest) \ + _(LLVMShortImmTest) \ + _(LLVMIntImmTest) \ + _(LLVMLongImmTest) \ + _(LLVMFloatImmTest) \ + _(LLVMDoubleImmTest) \ + _(LLVMHalfImmTest) \ + _(LLVMByteAddTest) \ + _(LLVMCharAddTest) \ + _(LLVMShortAddTest) \ + _(LLVMIntAddTest) \ + _(LLVMLongAddTest) \ + _(LLVMFloatAddTest) \ + _(LLVMDoubleAddTest) \ + _(LLVMHalfAddTest) \ + _(LLVMByteSubTest) \ + _(LLVMCharSubTest) \ + _(LLVMShortSubTest) \ + _(LLVMIntSubTest) \ + _(LLVMLongSubTest) \ + _(LLVMFloatSubTest) \ + _(LLVMDoubleSubTest) \ + _(LLVMHalfSubTest) \ + _(LLVMByteMulTest) \ + _(LLVMCharMulTest) \ + _(LLVMShortMulTest) \ + _(LLVMIntMulTest) \ + _(LLVMLongMulTest) \ + _(LLVMFloatMulTest) \ + _(LLVMDoubleMulTest) \ + _(LLVMHalfMulTest) \ + _(LLVMByteDivTest) \ + _(LLVMCharDivTest) \ + _(LLVMShortDivTest) \ + _(LLVMIntDivTest) \ + _(LLVMLongDivTest) \ + _(LLVMFloatDivTest) \ + _(LLVMDoubleDivTest) \ + _(LLVMHalfDivTest) \ + _(LLVMIntToFloatCastTest) \ + _(LLVMFloatToIntCastTest) \ + _(LLVMIntToLongCastTest) \ + _(LLVMByteToCharCastTest) \ + _(LLVMHalfToLongCastTest) \ + _(LLVMByteToDoubleCastTest) \ + _(LLVMLetTest01) \ + _(LLVMLetTest02) \ + _(LLVMLetTestMultitype) \ + _(LLVMBufferTest) \ + _(LLVMBlockTest) \ + _(LLVMLoadStoreTest) \ + _(LLVMVecLoadStoreTest) \ + 
_(LLVMMemcpyTest) \ + _(LLVMBzeroTest) \ + _(LLVMElemwiseAdd) \ + _(LLVMElemwiseAddFloat) \ + _(LLVMElemwiseLog10Float) \ + _(LLVMElemwiseMaxInt) \ + _(LLVMElemwiseMinInt) \ + _(LLVMElemwiseMaxNumFloat) \ + _(LLVMElemwiseMaxNumNaNFloat) \ + _(LLVMElemwiseMinNumFloat) \ + _(LLVMElemwiseMinNumNaNFloat) \ + _(LLVMCompareSelectIntEQ) \ + _(LLVMCompareSelectFloatEQ) \ + _(LLVMStoreFloat) \ + _(LLVMSimpleMath01) \ + _(LLVMComputeMul) \ + _(LLVMBroadcastAdd) \ + _(LLVMDynamicShapeAdd) \ + _(LLVMBindDynamicShapeAdd) \ + _(LLVMTensorDynamicShapeAdd) \ + _(LLVMDynamicShape2D) \ + _(LLVMIfThenElseTest) + +#define TH_FORALL_TESTS_CUDA(_) \ + _(CudaTestVectorAdd01) \ + _(CudaTestVectorAdd02) \ + _(CudaDynamicShape2D) \ + _(CudaTestRand01) \ + _(CudaDynamicShapeSplit) + +#define DECLARE_TENSOREXPR_TEST(name) void test##name(); +TH_FORALL_TESTS(DECLARE_TENSOREXPR_TEST) +#ifdef ENABLE_LLVM +TH_FORALL_TESTS_LLVM(DECLARE_TENSOREXPR_TEST) +#endif +#ifdef USE_CUDA +TH_FORALL_TESTS_CUDA(DECLARE_TENSOREXPR_TEST) +#endif +#undef DECLARE_TENSOREXPR_TEST + +} // namespace jit +} // namespace torch diff --git a/test/cpp/tensorexpr/tests_setup.py b/test/cpp/tensorexpr/tests_setup.py new file mode 100644 index 0000000000000..68871d1c21d21 --- /dev/null +++ b/test/cpp/tensorexpr/tests_setup.py @@ -0,0 +1,88 @@ +import sys +import os +import torch + + +class Setup(object): + def setup(self): + raise NotImplementedError() + + def shutdown(self): + raise NotImplementedError() + + +class FileSetup(object): + path = None + + def shutdown(self): + if os.path.exists(self.path): + os.remove(self.path) + pass + + +class EvalModeForLoadedModule(FileSetup): + path = 'dropout_model.pt' + + def setup(self): + class Model(torch.jit.ScriptModule): + def __init__(self): + super(Model, self).__init__() + self.dropout = torch.nn.Dropout(0.1) + + @torch.jit.script_method + def forward(self, x): + x = self.dropout(x) + return x + + model = Model() + model = model.train() + model.save(self.path) + + +class SerializationInterop(FileSetup): + path = 'ivalue.pt' + + def setup(self): + ones = torch.ones(2, 2) + twos = torch.ones(3, 5) * 2 + + value = (ones, twos) + + torch.save(value, self.path, _use_new_zipfile_serialization=True) + + +# See testTorchSaveError in test/cpp/jit/tests.h for usage +class TorchSaveError(FileSetup): + path = 'eager_value.pt' + + def setup(self): + ones = torch.ones(2, 2) + twos = torch.ones(3, 5) * 2 + + value = (ones, twos) + + torch.save(value, self.path, _use_new_zipfile_serialization=False) + + +tests = [ + EvalModeForLoadedModule(), + SerializationInterop(), + TorchSaveError(), +] + +def setup(): + for test in tests: + test.setup() + + +def shutdown(): + for test in tests: + test.shutdown() + + +if __name__ == "__main__": + command = sys.argv[1] + if command == "setup": + setup() + elif command == "shutdown": + shutdown() diff --git a/test/test_jit.py b/test/test_jit.py index 55a0ecb569963..e157d902a01e7 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -6510,6 +6510,18 @@ def my_slice(x): bailout_graph_str = str(my_slice.graph_for(a)) FileCheck().check_count("prim::BailOut", 1).run(bailout_graph_str) + @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "skip if profiling isn't enabled") + def test_unsqueeze_guard_elimination(self): + @torch.jit.script + def my_unsqueeze(x): + return torch.unsqueeze(x, 0) + torch.unsqueeze(x, 0) + + a = torch.rand(32, 4) + + with enable_profiling_mode(): + my_unsqueeze(a) + bailout_graph_str = str(my_unsqueeze.graph_for(a)) + 
FileCheck().check_count("prim::BailOut", 2).run(bailout_graph_str) def test_resize_input_ops(self): # resize_ and resize_as resize the input tensor. because our shape analysis diff --git a/test/test_tensorexpr.py b/test/test_tensorexpr.py new file mode 100644 index 0000000000000..3d4304f1d5cf9 --- /dev/null +++ b/test/test_tensorexpr.py @@ -0,0 +1,1164 @@ +import contextlib +import numpy as np +import torch +import torch.nn.functional as F +import unittest + + +@contextlib.contextmanager +def num_profiled_runs(num_runs): + old_num_runs = torch._C._jit_set_num_profiled_runs(num_runs) + try: + yield + finally: + torch._C._jit_set_num_profiled_runs(old_num_runs) + + +class BaseTestClass(unittest.TestCase): + def setUp(self): + # TODO: read the old value and restore it rather than always set to True + # on exit + torch._C._jit_override_can_fuse_on_gpu(False) + + def tearDown(self): + torch._C._jit_override_can_fuse_on_gpu(True) + +class ExecutionCounter(object): + def __init__(self, name): + self.name = name + self.start_value = torch._C._jit_get_trigger_value(self.name) + + def elapsed_value(self): + value = torch._C._jit_get_trigger_value(self.name) + return value - self.start_value + + +class CudaCodeGenCreated(ExecutionCounter): + def __init__(self): + super(CudaCodeGenCreated, self).__init__("cuda_codegen_created") + + +class CudaCodeGenExecuted(ExecutionCounter): + def __init__(self): + super(CudaCodeGenExecuted, self).__init__("cuda_codegen_executed") + + +class LLVMCodeGenCreated(ExecutionCounter): + def __init__(self): + super(LLVMCodeGenCreated, self).__init__("llvm_codegen_created") + + +class LLVMCodeGenExecuted(ExecutionCounter): + def __init__(self): + super(LLVMCodeGenExecuted, self).__init__("llvm_codegen_executed") + + +class SimpleIREvalExecuted(ExecutionCounter): + def __init__(self): + super(SimpleIREvalExecuted, self).__init__("simple_ir_eval_executed") + +class TestTensorExprFuser(BaseTestClass): + def test_easy(self): + def easy(x, y): + aaa = torch.add(x, y) + return aaa + + traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) + + a = torch.rand(1024) + b = torch.rand(1024) + x = traced(a, b) + np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) + + + def test_three_arg(self): + llvm_executed = LLVMCodeGenExecuted() + simple_ir_eval_executed = SimpleIREvalExecuted() + + def easy(x, y, z): + aaa = torch.add(x, y) + bbb = torch.add(aaa, z) + return bbb + + traced = torch.jit.trace( + easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) + ) + + a = torch.rand(1024) + b = torch.rand(1024) + c = torch.rand(1024) + x = traced(a, b, c) + npr = a.numpy() + b.numpy() + c.numpy() + np.testing.assert_allclose(npr, x.numpy()) + assert ( + llvm_executed.elapsed_value() >= 1 + or simple_ir_eval_executed.elapsed_value() >= 1 + ) + + + def test_four_arg(self): + def run_addcmul(x, y, z, w): + c = torch.addcmul(torch.add(x, y), z, w) + return c + + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] + for dev in device_options: + rand_a = torch.rand(1024, dtype=torch.float, device=dev) + rand_b = torch.rand(1024, dtype=torch.float, device=dev) + rand_c = torch.rand(1024, dtype=torch.float, device=dev) + rand_d = torch.rand(1024, dtype=torch.float, device=dev) + + traced = torch.jit.trace( + run_addcmul, + ( + torch.zeros(1024, dtype=torch.float, device=dev), + torch.zeros(1024, dtype=torch.float, device=dev), + torch.zeros(1024, dtype=torch.float, device=dev), + torch.zeros(1024, dtype=torch.float, device=dev), + ), + ) + + x = 
traced(rand_a, rand_b, rand_c, rand_d) + y = run_addcmul(rand_a, rand_b, rand_c, rand_d) + np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=1e-6) + + + def test_three_arg_cuda(self): + if not torch.cuda.is_available(): + return + cuda_cg_executed = CudaCodeGenExecuted() + cuda_cg_created = CudaCodeGenCreated() + + def test(x, y, z): + aaa = torch.add(x, y) + bbb = torch.add(aaa, z) + return bbb + + M = 32 + N = 32 + traced = torch.jit.trace( + test, + ( + torch.rand(M, N, device="cuda"), + torch.rand(M, N, device="cuda"), + torch.rand(M, N, device="cuda"), + ), + ) + + a = torch.rand(M, N, device="cuda") + b = torch.rand(M, N, device="cuda") + c = torch.rand(M, N, device="cuda") + x = traced(a, b, c) + npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() + np.testing.assert_allclose(npr, x.cpu().numpy()) + assert cuda_cg_executed.elapsed_value() >= 1 + assert cuda_cg_created.elapsed_value() >= 1 + + + def test_broadcast_cuda(self): + if not torch.cuda.is_available(): + return + + def test_body(M, N, L, K): + if not torch.cuda.is_available(): + return + cuda_cg_executed = CudaCodeGenExecuted() + cuda_cg_created = CudaCodeGenCreated() + + def test(x, y, z): + v1 = torch.add(x, y) + v2 = torch.add(v1, z) + return v2 + + a_shape = [M, N] + b_shape = [L, M, 1] + c_shape = [K, L, 1, 1] + traced = torch.jit.trace( + test, + ( + torch.rand(*a_shape, device="cuda"), + torch.rand(*b_shape, device="cuda"), + torch.rand(*c_shape, device="cuda"), + ), + ) + + a = torch.rand(*a_shape, device="cuda") + b = torch.rand(*b_shape, device="cuda") + c = torch.rand(*c_shape, device="cuda") + x = traced(a, b, c) + npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy() + np.testing.assert_allclose(npr, x.cpu().numpy()) + assert cuda_cg_executed.elapsed_value() >= 1 + assert cuda_cg_created.elapsed_value() >= 1 + + test_configs = [[36, 17, 63, 33], [32, 32, 32, 32]] + for test_config in test_configs: + test_body(*test_config) + + + def test_all_combos(self): + def easy(x, y, z): + a = torch.add(x, y) + b = torch.add(a, z) + c = torch.add(x, b) + d = torch.add(c, a) + return d + + def np_easy(x, y, z): + a = x + y + b = a + z + c = x + b + d = c + a + return d + + traced = torch.jit.trace( + easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) + ) + + a = torch.rand(1024) + b = torch.rand(1024) + c = torch.rand(1024) + x = traced(a, b, c) + npr = np_easy(a.numpy(), b.numpy(), c.numpy()) + np.testing.assert_allclose(npr, x.numpy()) + + + def test_rank_two(self): + def easy(x, y, z): + a = torch.add(x, y) + b = torch.add(a, z) + c = torch.add(x, b) + d = torch.add(c, a) + return d + + def np_easy(x, y, z): + a = x + y + b = a + z + c = x + b + d = c + a + return d + + shape = 32, 32 + traced = torch.jit.trace( + easy, (torch.rand(shape), torch.rand(shape), torch.rand(shape)) + ) + + a = torch.rand(shape) + b = torch.rand(shape) + c = torch.rand(shape) + x = traced(a, b, c) + npr = np_easy(a.numpy(), b.numpy(), c.numpy()) + np.testing.assert_allclose(npr, x.numpy()) + + def test_matmul(self): + llvm = LLVMCodeGenExecuted() + def easy(x, y): + aaa, bbb = torch.chunk(y, 2) + y = torch.cat([aaa, bbb], dim=0) + aaa = torch.matmul(x, y) * 3 + return aaa + + shape = (128,128) + a = torch.rand(shape) + b = torch.rand(shape) + traced = torch.jit.trace( + easy, (a, b) + ) + + x = traced(a, b) + y = 3 * (a @ b) + np.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-5, atol=1e-3) + assert llvm.elapsed_value() == 1 + + def test_broadcast(self): + def easy(x, y, z): + a = torch.add(x, y) + b = 
torch.add(a, z) + return b + + def np_easy(x, y, z): + a = x + y + b = a + z + return b + + N = 32 + traced = torch.jit.trace(easy, (torch.rand(N, N), torch.rand(N), torch.rand(N, N))) + + a = torch.rand(N, N) + b = torch.rand(N) + c = torch.rand(N, N) + x = traced(a, b, c) + npr = np_easy(a.numpy(), b.numpy(), c.numpy()) + np.testing.assert_allclose(npr, x.numpy()) + + + def test_broadcast_2(self): + zero = torch.tensor([0.0], dtype=torch.float) + + def foo(x, y, z): + aaa = torch.add(x, y) + bbb = torch.add(zero, aaa) + return torch.add(bbb, z) + + def foo_np(x, y, z): + a = x + y + b = zero.numpy() + a + return b + z + + x = torch.rand(3, 4) + y = torch.ones(3, 1) + z = torch.rand(4) + traced = torch.jit.trace(foo, (x, y, z)) + + r = traced(x, y, z) + rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) + np.testing.assert_allclose(r, rnp) + + + def test_broadcast_big2(self): + zero = torch.tensor([0.0], dtype=torch.float) + + def foo(x, y, z): + aaa = torch.add(x, y) + bbb = torch.add(zero, aaa) + return torch.add(bbb, z) + + def foo_np(x, y, z): + a = x + y + b = zero.numpy() + a + return b + z + + x = torch.rand(32, 1024) + y = torch.ones(32, 1) + z = torch.rand(1024) + traced = torch.jit.trace(foo, (x, y, z)) + + r = traced(x, y, z) + rnp = foo_np(x.numpy(), y.numpy(), z.numpy()) + np.testing.assert_allclose(r, rnp) + + + def test_alpha(self): + def alpha(x): + aaa = torch.add(x, x, alpha=2.0) + return aaa + + traced = torch.jit.trace(alpha, (torch.tensor([1.0]))) + + a = torch.tensor([1.0]) + x = traced(a) + np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy()) + + + def test_constant(self): + def constant(x): + bbb = torch.tensor([1.0]) + aaa = torch.add(x, bbb) + return aaa + + traced = torch.jit.trace(constant, (torch.tensor([1.0]))) + + a = torch.tensor([1.0]) + x = traced(a) + np.testing.assert_allclose(a.numpy() + 1.0, x.numpy()) + + + def test_add_sub(self): + def easy(x, y, z): + aaa = torch.add(x, y) + bbb = torch.sub(aaa, z) + return bbb + + traced = torch.jit.trace( + easy, (torch.rand(1024), torch.rand(1024), torch.rand(1024)) + ) + + a = torch.rand(1024) + b = torch.rand(1024) + c = torch.rand(1024) + x = traced(a, b, c) + np.testing.assert_allclose(a.numpy() + b.numpy() - c.numpy(), x.numpy()) + + + def test_promotion(self): + def easy(x, y): + aaa = torch.add(x, y) + return aaa + + traced = torch.jit.trace( + easy, + (torch.zeros(1024, dtype=torch.int32), torch.rand(1024, dtype=torch.float32)), + ) + + a = torch.zeros(1024, dtype=torch.int32) + b = torch.rand(1024, dtype=torch.float32) + x = traced(a, b) + np.testing.assert_allclose(a.numpy() + b.numpy(), x.numpy()) + + + def test_eq(self): + def easy(x, y): + c = torch.eq(x, y) + return c + + traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) + a = torch.zeros(1024, dtype=torch.int32) + b = torch.zeros(1024, dtype=torch.int32) + x = traced(a, b) + np.testing.assert_allclose(np.ones(1024), x.numpy()) + + + def test_ne(self): + def easy(x, y): + c = torch.ne(x, y) + return c + + traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) + a = torch.zeros(1024, dtype=torch.int32) + b = torch.ones(1024, dtype=torch.int32) + x = traced(a, b) + np.testing.assert_allclose(np.ones(1024), x.numpy()) + + + def test_ge(self): + def easy(x, y): + c = torch.ge(x, y) + return c + + traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) + aa = np.array(1024, dtype=int) + aa.fill(5) + a = torch.from_numpy(aa) + b = torch.zeros(1024, dtype=torch.int32) + x = traced(a, b) + 
np.testing.assert_allclose(np.ones(1024), x.numpy()) + + + def test_gt(self): + def easy(x, y): + c = torch.gt(x, y) + return c + + traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) + a = torch.ones(1024, dtype=torch.int32) + b = torch.zeros(1024, dtype=torch.int32) + x = traced(a, b) + np.testing.assert_allclose(np.ones(1024), x.numpy()) + + + def test_le(self): + def easy(x, y): + c = torch.le(x, y) + return c + + traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024))) + aa = np.array(1024, dtype=int) + aa.fill(5) + a = torch.from_numpy(aa) + b = torch.zeros(1024, dtype=torch.int32) + x = traced(a, b) + np.testing.assert_allclose(np.zeros(1024), x.numpy()) + + + def test_lt(self): + def easy(x, y): + c = torch.lt(x, y) + return c + + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] + for dev in device_options: + traced = torch.jit.trace(easy, (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))) + a = torch.ones(1024, dtype=torch.int32, device=dev) + b = torch.zeros(1024, dtype=torch.int32, device=dev) + x = traced(a, b) + np.testing.assert_allclose(np.zeros(1024), x.cpu().numpy()) + + + def test_min_max(self): + def test(x, y): + return torch.max(torch.min(x, y), torch.tensor([4.0])) + + traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024))) + a = 8.0 * torch.rand(1024) + b = 8.0 * torch.rand(1024) + np.testing.assert_allclose( + traced(a, b), np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0]) + ) + + + def test_clamp(self): + def test(x): + return torch.clamp(x + 3.0, 0.0, 6.0) + + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] + + for dev in device_options: + traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) + a = 20.0 * torch.rand(1024, device=dev) - 10.0 + an = a.cpu().numpy() + np.testing.assert_allclose(traced(a).cpu(), np.clip(an + 3.0, 0.0, 6.0)) + + def test_relu(self): + def test(x): + return torch.clamp(F.relu(x), 0, 0.5) + + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"] + for dev in device_options: + traced = torch.jit.trace(test, (torch.zeros(1024, device=dev))) + a = 20.0 * torch.rand(1024, device=dev) - 10.0 + an = a.cpu().numpy() + np.testing.assert_allclose(traced(a).cpu(), np.clip((np.maximum(0, an)), 0, 0.5)) + + + def test_reps(self): + def easy(x, y): + c = torch.add(x, y) + return c + + traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024))) + + for _ in range(32): + a = torch.ones(1024) + b = torch.zeros(1024) + x = traced(a, b) + np.testing.assert_allclose(np.ones(1024), x.numpy()) + + + def test_add_const_rhs(self): + def test(x): + return x + 3.0 + + traced = torch.jit.trace(test, torch.rand(4)) + x = torch.rand(4) + y = traced(x) + np.testing.assert_allclose(x.numpy() + 3.0, y.numpy()) + + + def test_int_output(self): + def test(x, y, z): + return x * y * z + + xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for i in range(3)] + x, y, z = xs + xn, yn, zn = [t.numpy() for t in xs] + traced = torch.jit.trace(test, (x, y, z)) + res = traced(x, y, z) + np.testing.assert_allclose(xn * yn * zn, res.numpy()) + + def test_binary_ops(self): + def test_atan2(x, y): + c = torch.atan2(torch.add(x, y), y) + return c + + def test_gt(x, y): + c = torch.gt(torch.add(x, y), y) + return c + + def test_ge(x, y): + c = torch.ge(torch.add(x, y), y) + return c + + def test_lt(x, y): + c = torch.lt(torch.add(x, y), y) + return c + + def test_le(x, y): + c = torch.le(torch.add(x, y), y) + return c + + def 
test_lerp(x, y): + c = torch.lerp(torch.add(x, 1), x, 2.0) + return c + + def test_mul(x, y): + c = torch.mul(torch.add(x, y), y) + return c + + def test_ne(x, y): + c = torch.ne(torch.add(x, y), y) + return c + + def test_div(x, y): + c = torch.div(torch.add(x, y), 2) + return c + + def test_eq(x, y): + c = torch.eq(torch.add(x, y), y) + return c + + def test_fmod(x, y): + c = torch.fmod(torch.add(x, y), 2) + return c + + def test_sub(x, y): + c = torch.sub(torch.add(x, y), x) + return c + + def test_remainder(x, y): + c = torch.remainder(torch.add(x, y), 3.0) + return c + + def test_pow(x, y): + c = torch.pow(torch.add(x, y), 2.0) + return c + + def test_sigmoid_backward(x, y): + x_2 = torch.mul(x, x) + c = torch.sigmoid(x_2) + torch.autograd.backward(c, y) + return c.detach() + + def test_tanh_backward(x, y): + x_2 = torch.mul(x, x) + c = torch.tanh(x_2) + torch.autograd.backward(c, y) + return c.detach() + + def test_type_as(x, y): + return x.type_as(torch.add(x, y)) + + fns = { + test_atan2, + test_gt, + test_ge, + test_lt, + test_le, + test_lerp, + test_mul, + test_ne, + test_div, + test_eq, + test_fmod, + test_sub, + test_remainder, + test_pow, + # to fix the backward path, need script instead of trace + # test_sigmoid_backward, + # test_tanh_backward, + test_type_as, + } + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] + for torch_fn in fns: + for dev in device_options: + rand_a = torch.rand(1024, device=dev) + rand_b = torch.rand(1024, device=dev) + in1 = 20 * torch.rand(1024, device=dev) + in2 = 20 * torch.rand(1024, device=dev) + traced = torch.jit.trace(torch_fn, (in1, in2)) + x = traced(rand_a, rand_b) + y = torch_fn(rand_a, rand_b) + np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) + + def test_unary_ops(self): + def test_cast_float(x, y): + c = torch.ops.aten._cast_Float(torch.add(x, y)) + return c + + def test_round(x, y): + c = torch.round(torch.add(x, y)) + return c + + def test_sin(x, y): + c = torch.sin(torch.add(x, y)) + return c + + def test_asin(x, y): + c = torch.asin(torch.add(x, y)) + return c + + def test_sinh(x, y): + c = torch.sinh(torch.add(x, y)) + return c + + def test_cos(x, y): + c = torch.cos(torch.add(x, y)) + return c + + def test_acos(x, y): + c = torch.acos(torch.add(x, y)) + return c + + def test_cosh(x, y): + c = torch.cosh(torch.add(x, y)) + return c + + def test_tan(x, y): + c = torch.tan(torch.add(x, y)) + return c + + def test_atan(x, y): + c = torch.atan(torch.add(x, y)) + return c + + def test_tanh(x, y): + c = torch.tanh(torch.add(x, y)) + return c + + def test_sqrt(x, y): + c = torch.sqrt(torch.add(x, y)) + return c + + def test_rsqrt(x, y): + c = torch.rsqrt(torch.add(x, y)) + return c + + def test_floor(x, y): + c = torch.floor(torch.add(x, y)) + return c + + def test_ceil(x, y): + c = torch.ceil(torch.add(x, y)) + return c + + def test_trunc(x, y): + c = torch.trunc(torch.add(x, y)) + return c + + def test_abs(x, y): + c = torch.abs(torch.add(x, y)) + return c + + def test_log(x, y): + c = torch.log(torch.add(x, y)) + return c + + def test_log2(x, y): + c = torch.log2(torch.add(x, y)) + return c + + def test_log10(x, y): + c = torch.log10(torch.add(x, y)) + return c + + def test_log1p(x, y): + c = torch.log1p(torch.add(x, y)) + return c + + def test_rqrt(x, y): + c = torch.rsqrt(torch.add(x, y)) + return c + + def test_erf(x, y): + c = torch.erf(torch.add(x, y)) + return c + + def test_exp(x, y): + c = torch.exp(torch.add(x, y)) + return c + + def test_expm1(x, y): + c = 
torch.expm1(torch.add(x, y)) + return c + + def test_erfc(x, y): + c = torch.erfc(torch.add(x, y)) + return c + + def test_frac(x, y): + c = torch.frac(torch.add(x, y)) + return c + + def test_lgamma(x, y): + c = torch.lgamma(torch.add(x, y)) + return c + + def test_sigmoid(x, y): + c = torch.sigmoid(torch.add(x, y)) + return c + + def test_reciprocal(x, y): + c = torch.reciprocal(torch.add(x, y)) + return c + + def test_neg(x, y): + c = torch.neg(torch.add(x, y)) + return c + + def test_relu(x, y): + c = torch.relu(torch.add(x, y)) + return c + + def test_threshold(x, y): + c = F.threshold(torch.add(x, y), 0.5, 10) + return c + + fns = { + test_round, + test_sin, + test_asin, + test_sinh, + test_cos, + test_acos, + test_cosh, + test_tan, + test_atan, + test_tanh, + test_sqrt, + test_floor, + test_ceil, + test_trunc, + test_abs, + test_log, + test_log2, + test_log10, + test_log1p, + test_rsqrt, + test_exp, + test_expm1, + test_erf, + test_erfc, + test_frac, + test_lgamma, + test_sigmoid, + test_reciprocal, + test_threshold, + test_neg, + test_relu, + } + device_options = ["cpu", "cuda"] if torch.cuda.is_available() else ['cpu'] + + for torch_fn in fns: + for dev in device_options: + rand_a = torch.rand(1024, device=dev) + rand_b = torch.rand(1024, device=dev) + ins = 20 * torch.rand(1024, device=dev) + cc = np.array(1024, dtype=float) + cc.fill(np.nan) + nans = torch.from_numpy(cc).to(dev) + traced = torch.jit.trace(torch_fn, (ins, ins)) + x = traced(rand_a, rand_b) + y = torch_fn(rand_a, rand_b) + np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy(), atol=2e-3) + # nans + traced = torch.jit.trace(torch_fn, (ins, ins)) + x = traced(nans, rand_b) + y = torch_fn(nans, rand_b) + np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + + + def test_rand_like(self): + devices = ["cuda"] if torch.cuda.is_available() else [] + N = 1 << 16 + + def run_rand_like(x, y): + return torch.rand_like(torch.add(x, y)) + + for device in devices: + x = torch.rand(N, device=device) + traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False) + x_v = traced(x, x) + x_np = x.cpu().numpy() + x1_mean = np.mean(x_np) + x2_mean = np.mean(x_np ** 2) + x3_mean = np.mean(x_np ** 3) + np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2) + np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2) + np.testing.assert_allclose(x3_mean, 1. 
/ 4, rtol=2e-2) + + + def test_nans(self): + def test_max(x, y): + return torch.max(2 * x, 2 * y) + + def test_min(x, y): + return torch.min(2 * x, 2 * y) + + tmax = torch.jit.trace(test_max, (torch.rand(1), torch.rand(1))) + tmin = torch.jit.trace(test_min, (torch.rand(1), torch.rand(1))) + + x = torch.tensor([np.nan]) + y = torch.tensor([1.0]) + + assert not np.isnan(tmin(x, y).item()) + assert np.isnan(tmin(y, x).item()) + assert not np.isnan(tmax(x, y).item()) + assert np.isnan(tmax(y, x).item()) + + + def test_remainder(self): + def run_remainder(x, y): + c = torch.remainder(torch.add(x, y), x) + return c + + a = torch.rand(1024, dtype=float) + b = torch.rand(1024, dtype=float) + zeros = torch.zeros(1024, dtype=float) + cc = np.array(1024, dtype=float) + cc.fill(np.nan) + nans = torch.from_numpy(cc) + + # random floats + traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) + x = traced(a, b) + y = run_remainder(a, b) + np.testing.assert_allclose(x.numpy(), y.numpy()) + + # div by 0 + traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) + x = traced(zeros, a) + y = run_remainder(zeros, a) + np.testing.assert_allclose(x.numpy(), y.numpy()) + + # numerators and denominatos are nan + traced = torch.jit.trace(run_remainder, (torch.zeros(1024), torch.zeros(1024))) + x = traced(nans, a) + y = run_remainder(nans, a) + np.testing.assert_allclose(x.numpy(), y.numpy()) + + + def test_multioutput(self): + def easy(x): + b = x + 1 + c = b + b + return (b, c) + + traced = torch.jit.trace(easy, (torch.zeros(1024))) + + a = torch.zeros(1024) + b, c = traced(a) + bp = a.numpy() + 1 + cp = bp + bp + np.testing.assert_allclose(b.numpy(), bp) + np.testing.assert_allclose(c.numpy(), cp) + + + def test_chunk(self): + def easy(x): + y = x + 1 + aaa, bbb = torch.chunk(y, 2) + return aaa + bbb + + traced = torch.jit.trace(easy, (torch.zeros(1024, 1024))) + + a = torch.zeros(1024, 1024) + x = traced(a) + npr = a.numpy() + npr2 = npr + 1 + npr_a, npr_b = np.array_split(npr2, 2) + np.testing.assert_allclose(npr_a + npr_b, x.numpy()) + + + def test_cat(self): + def easy(x, y): + a = x + 1 + b = y + 2 + c = torch.cat([a, b], dim=1) + return c + + traced = torch.jit.trace(easy, (torch.zeros(1024, 1024), torch.zeros(1024, 1024))) + + a = torch.zeros(1024, 1024) + x = traced(a, a) + npr = a.numpy() + npr_x = npr + 1 + npr_y = npr + 2 + npr_c = np.concatenate((npr_x, npr_y), axis=1) + np.testing.assert_allclose(npr_c, x.numpy()) + + + def test_scalar(self): + @torch.jit.script + def test_float(x, y, z, a, b): + # type: (Tensor, Tensor, Tensor, float, float) -> Tensor + return torch.add(torch.add(x, y, alpha=a), z, alpha=b) + + @torch.jit.script + def test_int(x, y, z, a, b): + # type: (Tensor, Tensor, Tensor, int, int) -> Tensor + return torch.add(torch.add(x, y, alpha=a), z, alpha=b) + + for test in (test_float, test_int): + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + x, y, z = [torch.rand(4) for i in range(3)] + a, b = 1, 2 + test(x, y, z, a, b) + r = test(x, y, z, a, b) + xn, yn, zn = [t.numpy() for t in (x, y, z)] + np.testing.assert_allclose(r.numpy(), xn + yn * a + zn * b) + assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1 + +# FIXME: Blocked on profiling executor changes +# def test_loop(): +# @torch.jit.script +# def test(x, y, z): +# # type: (Tensor, Tensor, int) -> Tensor +# b = y +# for i in range(0, z): +# a = x + y +# b = b + y +# return b +# +# llvm = LLVMCodeGenExecuted() +# interp = SimpleIREvalExecuted() +# x, y, 
z = (torch.zeros(32, 32), torch.ones(32, 32), 4) +# test(x, y, z) +# r = test(x, y, z) +# assert llvm.elapsed_value == 1 or interp.elapsed_value() == 1 + + def test_slice(self): + def easy(x, y): + a = x[0:512:2] + b = y[0:512:2] + return a + b + + traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024))) + + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + + a = torch.ones(1024, 1024) + x = traced(a, a) + npr = a[0:512:2] + npr = npr + npr + np.testing.assert_allclose(npr.numpy(), x.numpy()) + assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1 + + + @unittest.skip("fails on trunk") + def test_unsqueeze(self): + def easy(x, y): + a = torch.unsqueeze(x, 0) + b = torch.unsqueeze(y, 0) + return a + b + + traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024))) + + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + + a = torch.rand(1024, 1024) + x = traced(a, a) + npr = np.expand_dims(a, 0) + npr = npr + npr + np.testing.assert_allclose(npr, x.numpy()) + assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1 + + + def test_transpose(self): + @torch.jit.script + def test(x, y, z): + return x.transpose(0, 1) + y + z + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + x = torch.rand(4, 5, 2, 3) + y = torch.rand(5, 4, 2, 3) + z = torch.rand(5, 4, 2, 3) + ref = test(x, y, z) + res = test(x, y, z) + np.testing.assert_allclose(ref.numpy(), res.numpy()) + assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1 + + + def test_sliced_stride(self): + @torch.jit.script + def test(x, y, z): + return x + y + z + llvm = LLVMCodeGenExecuted() + interp = SimpleIREvalExecuted() + x = torch.rand(16, 4, 2, 3)[::2] + y = torch.rand(8, 4, 2, 3) + z = torch.rand(8, 4, 2, 3) + ref = test(x, y, z) + res = test(x, y, z) + np.testing.assert_allclose(ref.numpy(), res.numpy()) + assert llvm.elapsed_value() == 1 or interp.elapsed_value() == 1 + + + @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA") + @unittest.skip("dynamic shapes are not quite there yet") + def test_dynamic_shape(self): + with num_profiled_runs(2): + @torch.jit.script + def test(x, y, z): + return x * y * z + cuda = CudaCodeGenCreated() + x, y, z = [torch.rand(4, 8).cuda() for _ in range(3)] + ref = test(x, y, z) + _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)]) + res = test(x, y, z) + np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy()) + assert cuda.elapsed_value() == 1 + + # A wild broadcast appears. + x = torch.rand(4, 8).cuda() + y = torch.rand(1, 8).cuda() + z = torch.rand(4, 1).cuda() + res = test(x, y, z) + xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] + np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) + assert cuda.elapsed_value() == 1 + + # Mismatched shapes shouldn't reach codegen. + x = torch.rand(4, 8).cuda() + y = torch.rand(4, 8).cuda() + z = torch.rand(5, 8).cuda() + try: + res = test(x, y, z) + except RuntimeError as e: + assert "The size of tensor a (4) must match" in e.args[0] + assert cuda.elapsed_value() == 1 + + # Changing a static dimension fails guards. 
+ # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)] + # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)] + # res = test(x, y, z) + # print(test.graph_for(x, y, z)) + # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn) + # assert cuda.elapsed_value() == 1 + + @unittest.skip("guarding on static shapes is not working") + def test_guard_fails(): + @torch.jit.script + def test(x, y, z): + return x * y * z + cuda = CudaCodeGenExecuted() + _ = test(*[torch.rand(4).cuda() for _ in range(3)]) + assert cuda.elapsed_value() == 0 + _ = test(*[torch.rand(4).cuda() for _ in range(3)]) + assert cuda.elapsed_value() == 1 + _ = test(*[torch.rand(4).cuda() for _ in range(3)]) + assert cuda.elapsed_value() == 2 + _ = test(*[torch.rand(7).cuda() for _ in range(3)]) + print(test.graph_for(*[torch.rand(7).cuda() for _ in range(3)])) + assert cuda.elapsed_value() == 2 + + def test_bitwise_ops(self): + devices = ["cuda", "cpu"] if torch.cuda.is_available() else ["cpu"] + def run_and(x, y): + return x & (x & y) + + def run_xor(x, y): + return x ^ (x ^ y) + + def run_lshift(x, y): + return x & (x << y) + + def run_rshift(x, y): + return x & (x >> y) + + fns = {run_and, run_xor, run_lshift, run_rshift} + + for device in devices: + for fn in fns: + a = torch.ones(128, dtype=torch.int32, device=device) + b = torch.zeros(128, dtype=torch.int32, device=device) + inp = torch.ones(128, dtype=torch.int32, device=device) + traced = torch.jit.trace(fn, (inp, inp)) + x = traced(a, b) + y = fn(a, b) + np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy()) + + def test_where(self): + def run_where(x, y): + return torch.where(torch.gt(x, y), x, y) + + a = torch.rand(1024, dtype=float) + b = torch.rand(1024, dtype=float) + traced = torch.jit.trace(run_where, (torch.zeros(1024), torch.zeros(1024))) + x = traced(a, b) + y = run_where(a, b) + np.testing.assert_allclose(x.numpy(), y.numpy()) + +if __name__ == '__main__': + unittest.main() diff --git a/tools/build_variables.py b/tools/build_variables.py index 74ec8e42fccc7..08c368f5e5c37 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -80,6 +80,12 @@ "torch/csrc/jit/autodiff.cpp", "torch/csrc/jit/attributes.cpp", "torch/csrc/jit/argument_spec.cpp", + "torch/csrc/jit/compiler/src/asmjit_codegen.cc", + "torch/csrc/jit/compiler/src/expr.cc", + "torch/csrc/jit/compiler/src/function.cc", + "torch/csrc/jit/compiler/src/ir_printer.cc", + "torch/csrc/jit/compiler/src/ir_visitor.cc", + "torch/csrc/jit/compiler/src/types.cc", "torch/csrc/jit/constants.cpp", "torch/csrc/jit/custom_class.cpp", "torch/csrc/jit/node_hashing.cpp", diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index e51a96e3e5c5b..530861ff1f9a3 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -54,6 +54,8 @@ #include #include #include +#include +#include #include #include @@ -382,6 +384,52 @@ void initJITBindings(PyObject* module) { } return nullptr; }) + .def( + "_jit_get_trigger_value", + [](const std::string& trigger_name) { + using namespace torch::jit::tensorexpr; + ExecutionTrigger* trigger = + ExecutionTriggerList::GetInstance().FindByName(trigger_name); + return trigger->value(); + }) + .def( + "_jit_get_te_cuda_pointwise_loop_levels", + []() -> int { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseLoopLevels(); + }) + .def( + "_jit_set_te_cuda_pointwise_loop_levels", + [](int level) { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseLoopLevels() = level; + }) + .def( + 
"_jit_get_te_cuda_pointwise_block_count", + []() -> int { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockCount(); + }) + .def( + "_jit_set_te_cuda_pointwise_block_count", + [](int block_count) { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockCount() = block_count; + }) + .def( + "_jit_get_te_cuda_pointwise_block_size", + []() -> int { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockSize(); + }) + .def( + "_jit_set_te_cuda_pointwise_block_size", + [](int block_size) { + using namespace torch::jit::tensorexpr; + return GetTECudaPointwiseBlockSize() = block_size; + }) + .def( + "_jit_set_texpr_fuser_enabled", &torch::jit::tensorexpr::SetTexprFuserEnabled) .def( "_jit_fuser_get_fused_kernel_code", [](Graph& g, std::vector inps) { diff --git a/torch/csrc/jit/passes/guard_elimination.cpp b/torch/csrc/jit/passes/guard_elimination.cpp index 9b75c4b2ab5b5..7ff26e1e2807a 100644 --- a/torch/csrc/jit/passes/guard_elimination.cpp +++ b/torch/csrc/jit/passes/guard_elimination.cpp @@ -1,10 +1,10 @@ -#include #include #include #include #include #include #include +#include #include namespace torch { @@ -12,8 +12,7 @@ namespace jit { struct GuardElimination { GuardElimination(std::shared_ptr graph) - : graph_(std::move(graph)), - aliasDb_(std::make_unique(graph_)) {} + : graph_(std::move(graph)), aliasDb_(std::make_unique(graph_)) {} void run() { const size_t MAX_ATTEMPTS = 5; @@ -123,8 +122,11 @@ struct GuardElimination { auto it = guard; while (it != output) { if (it->kind() != prim::Guard && it->kind() != prim::Constant) { - GRAPH_DEBUG("found an unexpected node ", *it, - " while trying to eliminate ", *guard); + GRAPH_DEBUG( + "found an unexpected node ", + *it, + " while trying to eliminate ", + *guard); return false; } it = it->prev(); @@ -161,7 +163,7 @@ struct GuardElimination { // on inputs to `n`. The invariants must hold, or an input must // be a `prim::Constant` or be of `NumberType` or be included // as an exception in `except` - bool checkInputs(Node *n, const std::unordered_set &except) { + bool checkInputs(Node* n, const std::unordered_set& except) { bool all_inputs_guarded = true; size_t i = 0; for (auto input : n->inputs()) { @@ -174,8 +176,11 @@ struct GuardElimination { input->node()->kind() != prim::Guard || input->type()->expect()); } else { - GRAPH_DEBUG("input ", input->debugName(), " isn't guarded, type ", - *input->type()); + GRAPH_DEBUG( + "input ", + input->debugName(), + " isn't guarded, type ", + *input->type()); all_inputs_guarded = false; break; } @@ -184,7 +189,7 @@ struct GuardElimination { return all_inputs_guarded; } -private: + private: // `removableGuard` relies on the properties checked by `isSummarized()` // and passes shouldn't insert nodes between a guard and its uses that // may alter those properties. 
@@ -211,108 +216,146 @@ struct GuardElimination { // Guards can be removed if all inputs are guarded and `isSummarized()` // returns // false or inputs are `prim::Constant` - bool removableGuard(Node *n) { - + bool removableGuard(Node* n) { const static auto no_exceptions = std::unordered_set{}; switch (n->kind()) { - case aten::add: - case aten::sub: - case aten::mul: - case aten::div: - case aten::t: - case aten::sigmoid: - case aten::tanh: - case aten::mm: - case aten::min: - case aten::max: - case aten::type_as: - case aten::ge: - case aten::gt: - case aten::lt: - case aten::le: - case aten::eq: - case aten::ne: - case aten::neg: - case prim::ConstantChunk: - case aten::size: - case aten::abs: - case aten::sign: - case aten::pow: - case aten::relu: - case aten::threshold: - case aten::avg_pool2d: - case prim::AutogradAdd: - case prim::AutogradZero: - case aten::rand_like: - case aten::erf: - case aten::erfc: - return checkInputs(n, no_exceptions); - case aten::slice: - return !n->input(0)->type()->expect()->isSummarized() && - // check that the dimension argument is constant - n->input(1)->node()->kind() == prim::Constant && - // the start offset is constant - n->input(2)->node()->kind() == prim::Constant && - // the end offset is constant - n->input(3)->node()->kind() == prim::Constant && - // the stride is constant - n->input(4)->node()->kind() == prim::Constant; - case aten::cat: - // check that the dimension argument is constant - return n->input(1)->node()->kind() == prim::Constant && - n->input(0)->node()->kind() == prim::ListConstruct && - // no extra nodes in between aten::cat and prim::ListConstruct - n->prev() == n->input(0)->node() && - // check the inputs to prim::ListConstruct (not aten::cat) - checkInputs(n->input(0)->node(), no_exceptions); - case aten::clamp: - // the second and third args do not affect shapes - return checkInputs(n, std::unordered_set{1, 2}); - // after some optimizations we might end up with two Guards back-to-back - // which case we can remove the one whose input is also prim::Guard - case aten::_grad_sum_to_size: - // skip checking size argument - if (checkInputs(n, std::unordered_set{1})) { - auto asize = n->input(1)->node(); - if (asize->kind() == prim::Constant) { - return true; - } else if (asize->matches("aten::size(Tensor self) -> int[]")) { - // aten::size is effectively a constant - if (asize->input() - ->type() - ->expect() - ->sizes() - .concrete_sizes()) { + case aten::add: + case aten::sub: + case aten::mul: + case aten::div: + case aten::t: + case aten::sigmoid: + case aten::sin: + case aten::cos: + case aten::tan: + case aten::sinh: + case aten::cosh: + case aten::tanh: + case aten::asin: + case aten::acos: + case aten::atan: + case aten::atan2: + case aten::floor: + case aten::fmod: + case aten::ceil: + case aten::trunc: + case aten::sqrt: + case aten::rsqrt: + case aten::remainder: + case aten::mm: + case aten::matmul: + case aten::min: + case aten::max: + case aten::type_as: + case aten::ge: + case aten::gt: + case aten::lt: + case aten::le: + case aten::eq: + case aten::ne: + case aten::neg: + case prim::ConstantChunk: + case aten::size: + case aten::abs: + case aten::sign: + case aten::pow: + case aten::relu: + case aten::threshold: + case aten::avg_pool2d: + case prim::AutogradAdd: + case prim::AutogradZero: + case aten::rand_like: + case aten::erf: + case aten::erfc: + case aten::exp: + case aten::expm1: + case aten::log: + case aten::log2: + case aten::log10: + case aten::frac: + case aten::lerp: + case aten::lgamma: + case 
aten::reciprocal: + case aten::addcmul: + case aten::_cast_Float: + case aten::_sigmoid_backward: + case aten::_tanh_backward: + case aten::__and__: + case aten::__xor__: + case aten::__lshift__: + case aten::__rshift__: + case aten::where: + return checkInputs(n, no_exceptions); + case aten::slice: + return !n->input(0)->type()->expect()->isSummarized() && + // check that the dimension argument is constant + n->input(1)->node()->kind() == prim::Constant && + // the start offset is constant + n->input(2)->node()->kind() == prim::Constant && + // the end offset is constant + n->input(3)->node()->kind() == prim::Constant && + // the stride is constant + n->input(4)->node()->kind() == prim::Constant; + case aten::unsqueeze: + // check that the dimension argument is constant + return !n->input(0)->type()->expect()->isSummarized() && + n->input(1)->node()->kind() == prim::Constant; + case aten::cat: + // check that the dimension argument is constant + return n->input(1)->node()->kind() == prim::Constant && + n->input(0)->node()->kind() == prim::ListConstruct && + // no extra nodes in between aten::cat and prim::ListConstruct + n->prev() == n->input(0)->node() && + // check the inputs to prim::ListConstruct (not aten::cat) + checkInputs(n->input(0)->node(), no_exceptions); + case aten::clamp: + // the second and third args do not affect shapes + return checkInputs(n, std::unordered_set{1, 2}); + // after some optimizations we might end up with two Guards back-to-back + // which case we can remove the one whose input is also prim::Guard + case aten::_grad_sum_to_size: + // skip checking size argument + if (checkInputs(n, std::unordered_set{1})) { + auto asize = n->input(1)->node(); + if (asize->kind() == prim::Constant) { return true; + } else if (asize->matches("aten::size(Tensor self) -> int[]")) { + // aten::size is effectively a constant + if (asize->input() + ->type() + ->expect() + ->sizes() + .concrete_sizes()) { + return true; + } } } - } - return false; - - // this is checked by one of the tests in test_jit_fuser.py - case prim::ListUnpack: { - // check if the input is a constant chunk - // used for LSTM fusions - auto chunk = n->input(0)->node(); - if (chunk->kind() != aten::chunk) { return false; + + // this is checked by one of the tests in test_jit_fuser.py + case prim::ListUnpack: { + // check if the input is a constant chunk + // used for LSTM fusions + auto chunk = n->input(0)->node(); + if (chunk->kind() != aten::chunk) { + return false; + } + return checkInputs(chunk, no_exceptions); } - return checkInputs(chunk, no_exceptions); - } - // this is checked by one of the tests in test_jit_fuser.py - case aten::broadcast_tensors: { - auto list_construct = n->input(0)->node(); - if (list_construct->kind() != prim::ListConstruct) { - return false; + // this is checked by one of the tests in test_jit_fuser.py + case aten::broadcast_tensors: { + auto list_construct = n->input(0)->node(); + if (list_construct->kind() != prim::ListConstruct) { + return false; + } + return checkInputs(list_construct, no_exceptions); } - return checkInputs(list_construct, no_exceptions); - } - case prim::Guard: - case prim::GradOf: - return true; - default: - GRAPH_DEBUG("cannot remove ", n->kind().toQualString()); - return false; + case prim::Guard: + case prim::GradOf: + return true; + default: + GRAPH_DEBUG("cannot remove ", n->kind().toQualString()); + return false; } } @@ -321,7 +364,6 @@ struct GuardElimination { static std::unordered_set simple_ops_; }; - void EliminateRedundantGuards(std::shared_ptr 
graph) { GuardElimination ge(std::move(graph)); ge.run(); diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index baef9360e4dc0..201d74c4c351b 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -254,7 +254,8 @@ struct PeepholeOptimizeImpl { node->output()->replaceAllUsesWith(input_node->input()); changed_ = true; } - } else if (node->matches("aten::size(Tensor self) -> int[]")) { + } else if (node->matches("aten::size(Tensor self) -> int[]") || + node->kind() == prim::shape) { if (auto ptt = node->input()->type()->cast()) { if (auto sizes = ptt->sizes().concrete_sizes()) { WithInsertPoint guard(node); diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp new file mode 100644 index 0000000000000..6f9d346f1d0ca --- /dev/null +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -0,0 +1,351 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace torch::jit; +using namespace torch::jit::tensorexpr; + +namespace torch { +namespace jit { +namespace tensorexpr { + +static bool texpr_fuser_enabled = true; +TORCH_API void SetTexprFuserEnabled(bool val) { + texpr_fuser_enabled = val; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +namespace { + +const Symbol& getTensorExprSymbol() { + static Symbol s = Symbol::fromQualString("tensorexpr::Group"); + return s; +} + +value_list sortReverseTopological( + ArrayRef inputs, + torch::jit::Block* block) { + value_list result; + for (auto i : inputs) { + if (i->node()->owningBlock() == block) { + result.push_back(i); + } + } + // Sort in reverse topological order + std::sort( + result.begin(), + result.end(), + [&](torch::jit::Value* a, torch::jit::Value* b) { + return a->node()->isAfter(b->node()); + }); + return result; +} + +bool isSupported(Node* node) { + // TODO: + switch (node->kind()) { + case aten::add: + case aten::_cast_Float: + case aten::type_as: + case aten::sub: + case aten::mul: + case aten::div: + case aten::eq: + case aten::ne: + case aten::ge: + case aten::gt: + case aten::le: + case aten::lt: + case aten::min: + case aten::max: + case aten::pow: + case aten::clamp: + case aten::lerp: + case aten::log10: + case aten::log: + case aten::log2: + case aten::exp: + case aten::erf: + case aten::erfc: + case aten::fmod: + case aten::cos: + case aten::sin: + case aten::tan: + case aten::acos: + case aten::asin: + case aten::atan: + case aten::atan2: + case aten::cosh: + case aten::sinh: + case aten::tanh: + case aten::sqrt: + case aten::rsqrt: + case aten::abs: + case aten::floor: + case aten::ceil: + case aten::round: + case aten::trunc: + case aten::threshold: + case aten::remainder: + case prim::ConstantChunk: + case aten::cat: + case prim::ListConstruct: + case aten::sigmoid: + case aten::relu: + case aten::addcmul: + case aten::neg: + case aten::reciprocal: + case aten::expm1: + case aten::lgamma: + case aten::slice: + case aten::unsqueeze: + case aten::frac: + case aten::rand_like: + case aten::_sigmoid_backward: + case aten::_tanh_backward: + case aten::__and__: + case aten::__xor__: + case aten::__lshift__: + case aten::__rshift__: + case aten::where: + return true; + default: { + auto& nfr = getNativeFunctionRegistry(); + if (nfr.count(node->kind().toQualString())) { + return true; + } + } + return false; + } +} + +bool canHandle(Node* node, AliasDb& aliasDb) { + if (node->kind() == prim::Constant) { + return true; + 
} + if (node->kind() == prim::Loop) { + return false; // TODO + } + return isSupported(node); +} + +#define REQ(cond) \ + if (!(cond)) { \ + GRAPH_DEBUG("Failed cond " #cond "\n"); \ + return false; \ + } + +bool canMerge(Node* consumer, Node* producer, AliasDb& aliasDb) { + // Only handle complete tensor types + for (torch::jit::Value* output : consumer->outputs()) { + REQ(output->isCompleteTensor()); + } + + // Only fuse within a block + REQ(consumer->owningBlock() == producer->owningBlock()); + + // Symbolic checks + REQ(canHandle(producer, aliasDb)); + REQ( + (canHandle(consumer, aliasDb) || + consumer->kind() == getTensorExprSymbol())); + + // Alias checks + REQ(aliasDb.couldMoveAfterTopologically(consumer, producer)); + + // Ops that return aliases can only be folded if this is the only use. + if (producer->kind() == aten::slice || producer->kind() == aten::unsqueeze || + producer->kind() == prim::ConstantChunk) { + for (auto& use : producer->output(0)->uses()) { + REQ(use.user == consumer); + } + } + + if (!consumer->hasAttribute(attr::Subgraph) && + consumer->kind() != getTensorExprSymbol()) { + // Don't initiate a fusion group from prim::ListConstruct + REQ(consumer->kind() != prim::ListConstruct); + REQ(consumer->kind() != aten::slice); + REQ(consumer->kind() != aten::unsqueeze); + REQ(consumer->kind() != prim::ConstantChunk); + + // Don't initiate a fusion group just for a constant operand + REQ(producer->kind() != prim::Constant); + } + + if (producer->kind() == aten::cat) { + REQ(producer->inputs()[0]->node()->kind() == prim::ListConstruct); + REQ(producer->inputs()[0]->uses().size() == 1); + REQ(producer->inputs()[1]->node()->kind() == prim::Constant); + } else if (consumer->kind() == aten::cat) { + REQ(consumer->inputs()[0]->node()->kind() == prim::ListConstruct); + REQ(consumer->inputs()[0]->uses().size() == 1); + REQ(consumer->inputs()[1]->node()->kind() == prim::Constant); + } + + return true; +} +#undef REQ + +Node* getOrCreateTensorExprSubgraph(Node* n) { + if (n->hasAttribute(attr::Subgraph) && n->kind() == getTensorExprSymbol()) { + return n; + } + auto te_group = + SubgraphUtils::createSingletonSubgraph(n, getTensorExprSymbol()); + GRAPH_UPDATE("getOrCreateTensorExprSubgraph: ", *te_group); + return te_group; +} + +c10::optional tryMerge( + Node* consumer, + Node* producer, + AliasDb& aliasDb) { + GRAPH_DEBUG( + "Trying producer ", + getHeader(producer), + " and consumer ", + getHeader(consumer), + ":\n"); + + if (!canMerge(consumer, producer, aliasDb)) { + return c10::nullopt; + } + + consumer = getOrCreateTensorExprSubgraph(consumer); + + if (producer->kind() == aten::cat) { + Node* listconstruct = producer->inputs()[0]->node(); + + aliasDb.moveAfterTopologicallyValid(consumer, producer); + GRAPH_UPDATE( + "Merging ", getHeader(producer), " into ", getHeader(consumer)); + SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer); + + aliasDb.moveAfterTopologicallyValid(consumer, listconstruct); + GRAPH_UPDATE( + "Merging ", getHeader(listconstruct), " into ", getHeader(consumer)); + SubgraphUtils::mergeNodeIntoSubgraph(listconstruct, consumer); + } else { + aliasDb.moveAfterTopologicallyValid(consumer, producer); + GRAPH_UPDATE( + "Merging ", getHeader(producer), " into ", getHeader(consumer)); + SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer); + } + + return consumer; +} + +std::pair scanNode( + Node* consumer, + AliasDb& aliasDb) { + auto inputs = + sortReverseTopological(consumer->inputs(), consumer->owningBlock()); + + // Grab the iterator below 
consumer. We'll use that to determine + // where to resume iteration, even if consumer gets relocated within + // the block. + auto iter = --consumer->reverseIterator(); + for (auto input : inputs) { + if (auto group = tryMerge(consumer, input->node(), aliasDb)) { + // Resume iteration from where consumer is/used to be. + return {++iter, true}; + } + } + + // We know consumer didn't move, so skip over it. + return {++(++iter), false}; +} + +void fuseTensorExprs(std::shared_ptr& graph) { + if (!texpr_fuser_enabled) { + return; + } + GRAPH_DUMP("Before TExprFuser: ", graph); + + // Get rid of dead code so that we don't waste effort fusing it. + EliminateDeadCode(graph); + + AliasDb aliasDb(graph); + auto block = graph->block(); + + std::vector> + worklist; + std::unordered_set visited_blocks; + + bool any_changed = true; + while (any_changed) { + any_changed = false; + worklist.push_back({block->nodes().rbegin(), block->nodes().rend()}); + + while (worklist.size()) { + auto& it = worklist.back().first; + auto end = worklist.back().second; + + if (it->blocks().size()) { + Node* n = *it; + ++it; + + if (it == end) { + worklist.pop_back(); + } + + for (auto b : n->blocks()) { + if (!visited_blocks.count(b)) { + worklist.push_back({b->nodes().rbegin(), b->nodes().rend()}); + visited_blocks.insert(b); + } + } + } else { + bool changed; + std::tie(it, changed) = scanNode(*it, aliasDb); + any_changed |= changed; + if (it == end) { + worklist.pop_back(); + } + } + } + } + + EliminateCommonSubexpression(graph); + EliminateDeadCode(graph); + + GRAPH_DUMP("After TExprFuser: ", graph); +} + +Operation createTensorExprOp(const Node* node) { + auto kernel = std::make_shared(*node->g(attr::Subgraph)); + return [kernel](Stack& stack) { + RECORD_FUNCTION("TensorExpr", std::vector()); + kernel->run(stack); + return 0; + }; +} + +c10::OperatorOptions getAliasAnalysisOption(AliasAnalysisKind k) { + auto options = c10::OperatorOptions(); + options.setAliasAnalysis(k); + return options; +} + +RegisterOperators TensorExprOps({ + torch::jit::Operator( + getTensorExprSymbol(), + createTensorExprOp, + getAliasAnalysisOption(AliasAnalysisKind::PURE_FUNCTION)), +}); + +RegisterPass pass(fuseTensorExprs); + +} // namespace diff --git a/torch/csrc/jit/tensorexpr/DesignOverview.md b/torch/csrc/jit/tensorexpr/DesignOverview.md new file mode 100644 index 0000000000000..28afe53fa1ebb --- /dev/null +++ b/torch/csrc/jit/tensorexpr/DesignOverview.md @@ -0,0 +1,113 @@ +# Current workflow + +## Step 1: input from the user. + +User construct a kernel from tensor expressions, like: +``` + Buffer a_buf("a", kFloat32, {M, N}); + Buffer b_buf("b", kFloat32, {N, K}); + Buffer c_buf("c", kFloat32, {M, N}); + Buffer d_buf("d", kFloat32, {M, K}); + + Tensor* x = Compute( + "x", + {{M, "m1"}, {N, "n1"}, {K, "k1"}}, + [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) { + return a_buf(m, n) * b_buf(n, k); + }); + Tensor* y = ...; + Tensor* z = ...; + std::vector tensors_to_compute = {x, z}; // Tensor y might be used in x or z - in this case it will also be computed. +``` + +## Step 2: Create schedule for the tensor expressions: +``` + Schedule s(tensors_to_compute); +``` +This constructs a tree-like data structure (`TensorExprNode`) representing loop nests for the given tensor computation. +A node in this IR is either a loop-axis(LoopAxis) or a tensor expression (`TensorExprOp`). +If it is a loop-axis, it also contains children that again might be either a loop-axes or a tensor expression, and so on. 
+If it is a tensor-expression, it is lowered to a statement (`Stmt`). Currently this just means that we create a `Store` for every tensor expression. We also keep a pointer to the original tensor expression. + It could look like this: + ``` +loop-axis i + loop-axis j + Store(to: a[i, j], what: x[i] + y[j]) +loop-axis k + loop-axis l + Store(to: b[k, l], what: a[i, j] + 1) + loop-axis m + Store(to: c[k,l,m], what: b[k,l] + z[m]) +``` + +## Step 3: Apply scheduling primitives +Scheduling primitives mutate the tree structure: they can create or remove loop-axes, replace statements with other statements (updating `element_stmt` for each affected tensor expression), or remove them. The transformations also record their history. +The output of this step is a modified tree-like structure (same format as in step 2). + +## Step 4: Lower the tree structure to statements. +This step creates a `For` statement for each loop-axis and emits `element_stmt` for the bodies of the loops. + +## Step 5: Pass the final statement to codegen (LLVM/CUDA/IREval) +Codegen is implemented as an IR visitor over the statements produced in the previous step. (A minimal end-to-end sketch of these steps is shown right after the Example section below.) + +# Tensor Expressions Language +There are several core concepts in the Tensor Expression engine; this section defines them and shows how they connect to each other. + +## Expr +Expr represents a node in the abstract syntax tree of a tensor expression. Leaf nodes in such a tree are either a symbolic variable (`Var`), a constant (`IntImm` or `FloatImm`), a `Buffer`, or a `Tensor`. Non-leaf nodes refer to other expressions and represent various operations. E.g. `Add` has two operands: `lhs` and `rhs`, both of which are also `Expr`. + +## Tensor +`Tensor` is a bundle of +1) a variable `Var` defining which tensor this `Tensor` expression is describing +2) a list of indices `args` (each of them is a `Var`) +3) a list of expressions for dimensions `dims` (each of them is an `Expr`) +4) a computational expression `body` (of `Expr` type) + +## Buffer +`Buffer`s are essentially `Tensor`s without a `body` - they represent an indexed access to "tensors" that is outside the tensor-expression system. +`Buffer` is a bundle of +1) a `Var` defining which buffer this `Buffer` expression is describing +2) a list of indices `args` (each of them is a `Var`) +3) a list of expressions for dimensions `dims` (each of them is an `Expr`) + +## Example +Suppose we'd like to represent the following expression: +``` +A[i,j] = B[i,j] + 7 +``` +where both `A` and `B` are 100x100 tensors. +On the top level we would have a single `Tensor` expression with: +1) a variable referring to "A" +2) a list of two indices referring to "i" and "j" +3) a list of two `IntImm` constants describing the sizes (both of them would carry the value of 100) +4) a body expression, which is an `Add` with two operands: a `Buffer` access describing `B[i,j]` and an `IntImm` constant `7`. + +The buffer expression describing `B[i,j]` would have similar properties: +1) a variable referring to "B" +2) a list of two indices referring to "i" and "j" +3) a list of two `IntImm` constants describing the sizes (both of them would carry the value of 100) + +In contrast to the tensor expression, the buffer expression would not have a body - it represents a symbolic access. + +The code for constructing such an expression could look like this: + +``` + Buffer B("B", kFloat32, {100, 100}); + Tensor* A = Compute( + "A", + {{100, "i"}, {100, "j"}}, + [&](const VarHandle& i, const VarHandle& j) { + return B(i, j) + 7; + }); +```
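To connect the workflow steps above with this example, here is a minimal end-to-end sketch. It is illustrative only: the `Lower()` call on `Schedule` and the `"simple_ir_eval"` codegen name are assumptions made for the sketch; `CreateCodeGen` and the buffer/call argument types are the ones declared in `codegen.h` below.
```
  Buffer B("B", kFloat32, {100, 100});
  Tensor* A = Compute(
      "A",
      {{100, "i"}, {100, "j"}},
      [&](const VarHandle& i, const VarHandle& j) {
        return B(i, j) + 7;
      });

  // Step 2: build the loop-nest tree for the tensors we want computed.
  Schedule s({A});
  // Step 3: scheduling primitives would be applied to `s` here.
  // Step 4: lower the tree to a statement ("Lower" is an assumed name).
  Stmt* stmt = s.Lower();
  // Step 5: hand the statement to a codegen (the codegen name is an assumption).
  std::vector<float> b_data(100 * 100, 1.0f);
  std::vector<float> a_data(100 * 100, 0.0f);
  auto cg = CreateCodeGen("simple_ir_eval", stmt, {B, A});
  cg->call({b_data, a_data});  // after this call, a_data holds B[i,j] + 7
```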
+ +## Function +`Function` represents several tensor computations bundled together. In fact, `Tensor`s are implemented via `Function`s. A function allows us to specify that several different tensor expressions operate over the same set of indices and dimensions. + +## Stmt +`Stmt`s are what tensor expressions are lowered to before codegen. They represent the computation in a less abstract way than pure tensor expressions. Statements are built upon expressions, i.e. they can contain expressions as operands. A statement is the unit a codegen works with; it is incorrect to pass a bare expression to a codegen. +Examples of statements are `Store` and `For`. +TODO: provide more detailed example/description for the stmt. + +# Memory model +TBD diff --git a/torch/csrc/jit/tensorexpr/HowToRebaseOnMaster.md b/torch/csrc/jit/tensorexpr/HowToRebaseOnMaster.md new file mode 100644 index 0000000000000..408d5b55fd160 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/HowToRebaseOnMaster.md @@ -0,0 +1,64 @@ +1. Make sure both Bert's repo and the official pytorch repo are added as remotes. + +``` +$ git remote -v +bert git@github.com:bertmaher/pytorch.git (fetch) +bert git@github.com:bertmaher/pytorch.git (push) +origin git@github.com:pytorch/pytorch.git (fetch) +origin git@github.com:pytorch/pytorch.git (push) +... +``` +You might see an https address instead of the ssh one (e.g. `https://github.com/pytorch/pytorch.git`), which is also fine if you only plan to pull from it. + +If you don't have these remotes, add the missing ones with +``` +git remote add <remote-name> <remote-url> +``` + +E.g. +``` +git remote add pt https://github.com/pytorch/pytorch.git +``` + +You can remove a remote if you need to with +``` +git remote remove <remote-name> +``` + +2. Fetch all the remotes: +``` +git fetch --all +``` + +3. Stash or commit all your local changes: +``` +git stash # OR +git commit -a -m "My local changes" +``` + +4. Check out the branch that you'd like to rebase on top of master. Assuming we want to rebase the `pytorch_fusion` branch from Bert's repo, you could do: +``` +git checkout pytorch_fusion # Checkout the local 'pytorch_fusion' branch +git reset --hard bert/pytorch_fusion # Replace the current 'pytorch_fusion' branch with the version from Bert's repo +``` + +5. Rebase your branch on top of the latest master branch: +``` +git rebase origin/master +``` +If you're lucky and there are no conflicts, you will end up with a rebased branch. +Otherwise, resolve the conflicts manually: for every conflict, do: + - `git status` to find "both modified" files - that's where the conflicts are + - Manually edit these files to resolve the conflict. + - Mark the conflict as resolved by adding these files with `git add FILENAME` + - Once conflicts in all files are resolved, run `git rebase --continue` + - At any point you can run `git rebase --abort` to return to the state before the rebase. + +6. Push to Bert's repo. This has to be a force-push, so make sure to: + - Double check what you're going to push (e.g. with `git log`): the new branch and the old branch (`bert/pytorch_fusion`) should have the same commits on top; the only difference should be the latest master commits at the base of the branch. + - Announce that you're going to force-push the main branch. 
Other people will have to rebase their changes after that. + - Push with local branch 'pytorch_fusion' to the Bert's repo under the same name: `git push bert -f pytorch_fusion:pytorch_fusion` + +7. ... + +8. Profit! diff --git a/torch/csrc/jit/tensorexpr/buffer.cpp b/torch/csrc/jit/tensorexpr/buffer.cpp new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/torch/csrc/jit/tensorexpr/buffer.h b/torch/csrc/jit/tensorexpr/buffer.h new file mode 100644 index 0000000000000..fcdf6784b9987 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/buffer.h @@ -0,0 +1,108 @@ +#pragma once + +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +class Buffer { + public: + Buffer(const VarHandle& data, const Dtype& dtype, const std::vector& dims) + : data_(data.node()), dtype_(dtype), dims_(ExprHandleVectorToExprVector(dims)) { + CHECK_EQ(data.dtype(), kHandle); + std::vector stride_handles(dims.size()); + for (int i = ndim() - 1; i >= 0; i--) { + if (i == ndim() - 1) { + stride_handles[i] = 1; + } else { + stride_handles[i] = stride_handles[i + 1] * ExprHandle(dim(i + 1)); + } + } + strides_ = ExprHandleVectorToExprVector(stride_handles); + } + Buffer( + const std::string& name, + const Dtype& dtype, + const std::vector& dims) + : Buffer(VarHandle(name, kHandle), dtype, dims) {} + + const Var* data() const { + return data_; + } + const Dtype& dtype() const { + return dtype_; + } + int ndim() const { + return dims_.size(); + } + const Expr* dim(int index) const { + return dims_[index]; + } + + // TODO: consider defer the storage flatten to a later stage. + template + ExprHandle operator()(Args... args) const { + ExprHandle index = Index(std::forward(args)...); + return LoadValue(index); + } + + template + ExprHandle call(const std::vector& args) const { + std::vector params(args.begin(), args.end()); + ExprHandle index = Index(params); + return LoadValue(index); + } + + private: + ExprHandle Index(const ExprHandle& x) const { + CHECK(ndim() == 1); + return x; + } + ExprHandle Index(const ExprHandle& x, const ExprHandle& y) const { + CHECK(ndim() == 2); + return x * ExprHandle(strides_[0]) + y; + } + ExprHandle Index(const ExprHandle& x, const ExprHandle& y, const ExprHandle& z) const { + CHECK(ndim() == 3); + return x * ExprHandle(strides_[0]) + y * ExprHandle(strides_[1]) + z; + } + ExprHandle Index(const ExprHandle& x, const ExprHandle& y, const ExprHandle& z, const ExprHandle& w) const { + CHECK(ndim() == 4); + return x * ExprHandle(strides_[0]) + y * ExprHandle(strides_[1]) + z * ExprHandle(strides_[2]) + w; + } + ExprHandle Index(const std::vector& indices) const { + CHECK(ndim() == (int)indices.size()); + ExprHandle total_index; + for (size_t i = 0; i < indices.size(); i++) { + ExprHandle index; + if (i == indices.size() - 1) { + index = indices[i]; + } else { + index = indices[i] * ExprHandle(strides_[i]); + } + if (i == 0) { + total_index = index; + } else { + total_index = total_index + index; + } + } + return total_index; + } + + ExprHandle LoadValue(const ExprHandle& index) const; + + const Var* data_; + Dtype dtype_; + std::vector dims_; + std::vector strides_; + // TODO: add strides +}; + +inline ExprHandle Buffer::LoadValue(const ExprHandle& index) const { + return Load::make(*this, index, ExprHandle(1)); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/codegen.cpp b/torch/csrc/jit/tensorexpr/codegen.cpp new file mode 100644 index 0000000000000..be4a171f335bd --- 
/dev/null +++ b/torch/csrc/jit/tensorexpr/codegen.cpp @@ -0,0 +1,51 @@ +#include "torch/csrc/jit/tensorexpr/codegen.h" + +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +RegisterCodeGenList::StmtFactoryMethod RegisterCodeGenList:: + FindStmtFactoryMethod(const std::string& name) { + auto iter = stmt_factory_methods_.find(name); + if (iter == stmt_factory_methods_.end()) { + std::ostringstream oss; + oss << "Invalid stmt codegen name: " << name << ". "; + oss << "Existing codegen names: ["; + int index = 0; + for (const auto& entry : stmt_factory_methods_) { + if (index != 0) { + oss << ", "; + } + oss << entry.first; + index++; + } + oss << "]"; + throw std::runtime_error(oss.str()); + } + return iter->second; +} + +void RegisterCodeGenList::AddStmtFactoryMethod( + const std::string& name, + const StmtFactoryMethod& stmt_factory_method) { + auto insert_ret = + stmt_factory_methods_.insert(std::make_pair(name, stmt_factory_method)); + if (!insert_ret.second) { + throw std::runtime_error("Duplicated CodeGen names: " + name); + } +} + +std::unique_ptr CreateCodeGen( + const std::string& name, + Stmt* stmt, + const std::vector& params) { + RegisterCodeGenList::StmtFactoryMethod method = + RegisterCodeGenList::GetInstance().FindStmtFactoryMethod(name); + return method(stmt, params); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/codegen.h b/torch/csrc/jit/tensorexpr/codegen.h new file mode 100644 index 0000000000000..3883086a4f166 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/codegen.h @@ -0,0 +1,171 @@ +#pragma once + +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +class PaddedBuffer; + +class CodeGen { + public: + class BufferArg; + class CallArg; + + template + CodeGen(Stmt* stmt, Ts... 
ts) + : stmt_(stmt), buffer_args_({BufferArg(ts)...}) {} + + CodeGen(Stmt* stmt, const std::vector& buffer_args) + : stmt_(stmt), buffer_args_(buffer_args) {} + + virtual ~CodeGen() {} + + Stmt* stmt() const { + return stmt_; + } + + std::vector& buffer_args() { + return buffer_args_; + } + + const std::vector& buffer_args() const { + return buffer_args_; + } + + TORCH_API virtual void call(const std::vector& args) { + LOG(FATAL) << "unimplemented call"; + } + + private: + Stmt* stmt_; + std::vector buffer_args_; +}; + +class CodeGen::BufferArg { + public: + BufferArg(const Buffer& buffer) + : var_(buffer.data()), dtype_(buffer.dtype()) {} + BufferArg(Tensor* tensor) + : var_(tensor->function()->func_var(tensor->output_index())), + dtype_(tensor->function()->body(tensor->output_index())->dtype()) {} + BufferArg(const Function& func) + : var_(func.func_var(0)), dtype_(func.body(0)->dtype()) { + // TODO: Support multiple-output functions + CHECK(func.func_vars().size() == 1); + } + BufferArg(const VarHandle& var) : var_(var.node()), dtype_(var.dtype()), isVar_(true) {} + + const Var* var() const { + return var_; + } + Dtype dtype() const { + return dtype_; + } + + bool isVar() const { + return isVar_; + } + + private: + const Var* var_; + Dtype dtype_; + bool isVar_{false}; +}; + +class CodeGen::CallArg { + public: + template + CallArg(const PaddedBuffer& buffer); + + template + CallArg(const std::vector& buffer) : ptr_(const_cast(buffer.data())) {} + + CallArg(void* ptr) : ptr_(ptr) {} + +#define ARG_TYPE_CTOR(Type, Name) \ + CallArg(Type v) : Name##val_(v) {} + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_TYPE_CTOR); +#undef ARG_TYPE_CTOR + + void* data() const { + return ptr_; + } + +#define ARG_DATA_DEFINE(Type, Name) \ + Type Name##Data() const { \ + return Name##val_; \ + } + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_DATA_DEFINE); +#undef ARG_DATA_DEFINE + +#define ARG_PTR_DEFINE(Type, Name) \ + Type* Name##Ptr() const { \ + return const_cast(&Name##val_); \ + } + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_PTR_DEFINE); +#undef ARG_PTR_DEFINE + + private: + union { + void* ptr_; + +#define ARG_BACKING(Type, Name) \ + Type Name##val_; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, ARG_BACKING); +#undef ARG_BACKING + }; +}; + +class RegisterCodeGenList { + public: + TORCH_API static RegisterCodeGenList& GetInstance() { + static RegisterCodeGenList codegen_list; + return codegen_list; + } + + using StmtFactoryMethod = std::function( + Stmt* stmt, + const std::vector&)>; + + TORCH_API StmtFactoryMethod FindStmtFactoryMethod(const std::string& name); + + private: + template + friend class RegisterCodeGen; + RegisterCodeGenList() {} + TORCH_API void AddStmtFactoryMethod( + const std::string& name, + const StmtFactoryMethod& stmt_factory_method); + RegisterCodeGenList(const RegisterCodeGenList&) = delete; + RegisterCodeGenList& operator=(const RegisterCodeGenList&) = delete; + + std::unordered_map stmt_factory_methods_; +}; + +template +class RegisterCodeGen { + public: + explicit RegisterCodeGen(const std::string& name) { + RegisterCodeGenList& codegen_list = RegisterCodeGenList::GetInstance(); + codegen_list.AddStmtFactoryMethod( + name, + [](Stmt* stmt, const std::vector& params) { + std::unique_ptr method(new CodeGenType(stmt, params)); + return method; + }); + } +}; + +TORCH_API std::unique_ptr CreateCodeGen( + const std::string& name, + Stmt* stmt, + const std::vector& params); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git 
a/torch/csrc/jit/tensorexpr/cuda_codegen.cpp b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp new file mode 100644 index 0000000000000..5c1d99f0c876b --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.cpp @@ -0,0 +1,656 @@ +#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" +#include "torch/csrc/jit/tensorexpr/cuda_half_support.h" + +#include "ATen/CUDAGenerator.h" +#include "c10/cuda/CUDAFunctions.h" +#include "torch/csrc/jit/tensorexpr/cuda_random.h" +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/execution_counter.h" + +#define DEBUG_PRINT 0 + +namespace torch { +namespace jit { +namespace tensorexpr { + +DEFINE_TRIGGER(cuda_codegen_created); +DEFINE_TRIGGER(cuda_codegen_executed); + +// A RAII wrapper to manage a variable and name pair in the look-up table. +// TODO: move this to a more shared place. +class ScopedVarName { + public: + ScopedVarName( + VarNameMap* mapping, + const Var* var, + const std::string& name) + : mapping_(mapping), var_(var) { + auto iter = mapping->find(var); + if (iter != mapping->end()) { + throw std::runtime_error("Duplicate var entry: " + var->name_hint()); + } + mapping->insert(std::make_pair(var, name)); + } + + ScopedVarName( + UniqueNameManager* manager, + const Var* var, + const std::string& name) + : ScopedVarName(&manager->unique_name_mapping_, var, name) {} + + ~ScopedVarName() noexcept(false) { + auto iter = mapping_->find(var_); + TORCH_CHECK(iter != mapping_->end(), "Invalid var entry"); + mapping_->erase(var_); + } + + private: + ScopedVarName(const ScopedVarName&) = delete; + ScopedVarName& operator=(const ScopedVarName&) = delete; + + VarNameMap* mapping_ = nullptr; + const Var* var_ = nullptr; +}; + +static int as_int(const Expr* expr) { + auto v = dynamic_cast(expr); + TORCH_CHECK(v, "Expression is not an integer constant"); + return v->value(); +} + +static bool is_zero(const Expr* expr) { + return as_int(expr) == 0; +} + +static const at::cuda::NVRTC& nvrtc() { + return at::globalContext().getNVRTC(); +} + +static void getMajorMinor( + const cudaDeviceProp* const prop, + int& major, + int& minor) { + using CudaVersion = std::pair; + CudaVersion nvrtc_version; + AT_CUDA_NVRTC_CHECK( + nvrtc().nvrtcVersion(&nvrtc_version.first, &nvrtc_version.second)); + + AT_ASSERT(nvrtc_version.first >= 6); + + CudaVersion dev_version = CudaVersion(prop->major, prop->minor); + CudaVersion max_dev_version(dev_version); + if (nvrtc_version.first <= 7) { // 7 supports 2-5.x + max_dev_version = CudaVersion(5, 0); + } else if (nvrtc_version.first <= 8) { // 8 supports 2-6.x + max_dev_version = CudaVersion(6, 0); + } else if (nvrtc_version.first <= 9) { // 9 supports 3-7.2 + max_dev_version = CudaVersion(7, 2); + } else if (nvrtc_version.first <= 10) { // 10 supports 3-7.5 + max_dev_version = CudaVersion(7, 5); + } + if (dev_version > max_dev_version) { + dev_version = max_dev_version; + } + major = dev_version.first; + minor = dev_version.second; +} + +void CudaPrinter::visit(const For* v) { + const LoopOptions& loop_options = v->loop_options(); + if (loop_options.is_gpu_block_index()) { + ScopedVarName var_name( + name_manager(), v->var(), loop_options.gpu_block_index_str()); + v->body()->accept(this); + int gpu_block_index = loop_options.gpu_block_index(); + if (gpu_block_extents_.size() <= gpu_block_index) { + gpu_block_extents_.resize(gpu_block_index + 1); + } + if (!is_zero(v->start())) { + throw std::runtime_error( + "start must be zero for gpu_block_index: " + + std::to_string(ExprHandle(v->start()))); + } + 
gpu_block_extents_[gpu_block_index] = v->stop(); + } else if (loop_options.is_gpu_thread_index()) { + ScopedVarName var_name( + name_manager(), v->var(), loop_options.gpu_thread_index_str()); + v->body()->accept(this); + int gpu_thread_index = loop_options.gpu_thread_index(); + if (gpu_thread_extents_.size() <= gpu_thread_index) { + gpu_thread_extents_.resize(gpu_thread_index + 1); + } + if (!is_zero(v->start())) { + throw std::runtime_error( + "start must be zero for gpu_thread_index: " + + std::to_string(ExprHandle(v->start()))); + } + gpu_thread_extents_[gpu_thread_index] = v->stop(); + } else { + IRPrinter::visit(v); + } +} + +void CudaPrinter::visit(const Intrinsics* v) { + if (v->op_type() == IntrinsicsOp::kRand) { + os() << "Uint32ToFloat(" << *rand_func_ << "())"; + return; + } + + std::string func_name = v->func_name(); + + // get type of resulting expression. + ScalarType returnType = v->param(0)->dtype().scalar_type(); + for (int i = 1; i < v->nparams(); ++i) { + returnType = + promoteTypes(returnType, v->param(i)->dtype().scalar_type()); + } + + if (returnType == ScalarType::Half || returnType == ScalarType::Float) { + func_name = func_name + "f"; + } + + os() << func_name << "("; + for (int i = 0; i < v->nparams(); i++) { + if (i > 0) { + os() << ", "; + } + os() << *v->param(i); + } + os() << ")"; +} + +void CudaPrinter::visit(const Load* v) { + // TODO: find a better heuristic for using ldg or not. Support different dtypes. + if (v->dtype().scalar_type() == ScalarType::Half) { + os() << "__half2float(" << *v->base_handle() << "[" << *v->index() << "])"; + } else { + os() << "__ldg(" << *v->base_handle() << " + " << *v->index() << ")"; + } +} + +void CudaPrinter::visit(const Store* v) { + os() << *v->base_handle() << "[" << *v->index() << "] = "; + if (v->value()->dtype().scalar_type() == ScalarType::Half) { + os() << "__float2half(" << *v->value() << ");"; + } else { + os() << *v->value() << ";"; + } +} + +void CudaPrinter::visit(const Max* v) { + auto dtype = v->dtype().scalar_type(); + switch (dtype) { + case ScalarType::Half: + // doing Half math in float. + case ScalarType::Float: + os() << "fmaxf"; + break; + case ScalarType::Double: + os() << "fmax"; + break; + default: + os() << "max"; + break; + } + os() << "("; + v->lhs()->accept(this); + os() << ","; + v->rhs()->accept(this); + os() << ")"; +} + +void CudaPrinter::visit(const Min* v) { + auto dtype = v->dtype().scalar_type(); + switch (dtype) { + case ScalarType::Half: + // doing Half math in float. + case ScalarType::Float: + os() << "fminf"; + break; + case ScalarType::Double: + os() << "fmin"; + break; + default: + os() << "min"; + break; + } + os() << "("; + v->lhs()->accept(this); + os() << ","; + v->rhs()->accept(this); + os() << ")"; +} + +std::string cudaDtypeCppString(const Dtype& dtype) { + switch (dtype.scalar_type()) { + case ScalarType::Half: + return "half"; + case ScalarType::Char: + return "char"; + case ScalarType::Byte: + return "unsigned char"; + case ScalarType::Short: + return "short"; + case ScalarType::Long: + return "long"; + default: + ;/* nothing */ + } + return dtype.ToCppString(); +} + +void CudaPrinter::visit(const LetStmt* v) { + const Var* var = v->var(); + if (var->dtype().scalar_type() == ScalarType::Half) { + // we do math in floats so use that.
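// For example (illustrative sketch; buffer names `a`, `b` and index `i` are assumed),
// a LetStmt that binds a Half value loaded from a buffer is printed roughly as
//   float v = __half2float(a[i]);
// and a Half store of `v` becomes `b[i] = __float2half(v);`, matching the Load/Store
// printers above.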
+ os() << "float"; + } else { + os() << cudaDtypeCppString(var->dtype()); + } + os() << " " << *var << " = " << *v->value() << "; " + << std::endl; + v->body()->accept(this); +} + +void CudaPrinter::visit(const IfThenElse* v) { + os() << "("; + v->condition()->accept(this); + os() << ") ? "; + v->true_value()->accept(this); + os() << " : "; + v->false_value()->accept(this); +} + +class PrioritizeLoad : public IRMutator { + public: + virtual const Expr* mutate(const Load* v) { + MemLoadList& load_list = load_stack_.back(); + const Var* load_new_var = new Var("v", v->dtype()); + const Expr* new_value = IRMutator::mutate(v); + load_list.push_back(std::make_pair(load_new_var, new_value)); + return load_new_var; + } + + // TODO: merge this with the IRMutator::mutate version. + virtual Stmt* mutate(const For* v) { + const Var* var = v->var(); + const Expr* start = v->start(); + const Expr* stop = v->stop(); + Stmt* body = v->body(); + LoopOptions loop_options = v->loop_options(); + const Var* var_new = dynamic_cast(var->accept_mutator(this)); + const Expr* start_new = start->accept_mutator(this); + const Expr* stop_new = stop->accept_mutator(this); + PushList(); + Stmt* body_new = body->accept_mutator(this); + if (!body_new) { + return nullptr; + } + Stmt* body_with_loads = AddMemLoadsFromList(body_new); + PopList(); + if (var == var_new && start == start_new && + stop == stop_new && body == body_with_loads) { + return (Stmt*)v; + } + return new For( + var_new, start_new, stop_new, body_with_loads, loop_options); + } + + virtual Stmt* mutate(const LetStmt* v) { + const Var* var = v->var(); + const Expr* value = v->value(); + Stmt* body = v->body(); + const Var* var_new = dynamic_cast(var->accept_mutator(this)); + if (var_new == nullptr) { + throw std::runtime_error("LetStmt var must be variable"); + } + const Expr* value_new = value->accept_mutator(this); + PushList(); + Stmt* body_new = body->accept_mutator(this); + Stmt* body_with_loads = AddMemLoadsFromList(body_new); + PopList(); + if (var == var_new && value == value_new && + body == body_with_loads) { + return (Stmt*)v; + } + return new LetStmt(var_new, value_new, body_with_loads); + } + + virtual Stmt* mutate(const Cond* v) { + const Expr* cond_old = v->condition(); + Stmt* true_old = v->true_stmt(); + Stmt* false_old = v->false_stmt(); + + const Expr* cond_new = cond_old->accept_mutator(this); + PushList(); + Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old; + Stmt* true_with_loads = AddMemLoadsFromList(true_new); + PopList(); + PushList(); + Stmt* false_new = false_old ? 
false_old->accept_mutator(this) : false_old; + Stmt* false_with_loads = AddMemLoadsFromList(false_new); + PopList(); + + if (cond_old == cond_new && true_old == true_with_loads && + false_old == false_with_loads) { + return (Stmt*)v; + } + return new Cond(cond_new, true_with_loads, false_with_loads); + } + + Stmt* Process(Stmt* stmt) { + this->PushList(); + Stmt* stmt_v = stmt; + Stmt* stmt_new = stmt_v->accept_mutator(this); + Stmt* stmt_with_loads = AddMemLoadsFromList(stmt_new); + this->PopList(); + return stmt_with_loads; + } + + private: + using MemLoadEntry = std::pair; + using MemLoadList = std::vector; + using MemoryLoadStack = std::vector; + + void PushList() { + load_stack_.push_back(MemLoadList()); + } + + void PopList() { + load_stack_.pop_back(); + } + + Stmt* AddMemLoadsFromList(Stmt* stmt) { + MemLoadList& load_list = load_stack_.back(); + Stmt* stmt_v = stmt; + for (int i = load_list.size() - 1; i >= 0; i--) { + const MemLoadEntry& entry = load_list[i]; + Var* var_ptr = const_cast(entry.first); + stmt_v = new LetStmt(var_ptr, entry.second, stmt_v); + } + return stmt_v; + } + + MemoryLoadStack load_stack_; +}; + +class HasRand : public IRVisitor { + public: + HasRand(Stmt* stmt) : stmt_(stmt) { + stmt_->accept(this); + } + + bool has_rand() const { + return has_rand_; + } + + private: + virtual void visit(const Intrinsics* v) { + if (v->op_type() == IntrinsicsOp::kRand) { + has_rand_ = true; + } else { + IRVisitor::visit(v); + } + } + Stmt* stmt_; + bool has_rand_ = false; +}; + +void CudaCodeGen::Initialize() { + // TODO: handle multiple kernels. + // TODO: handle dynamic dimension. + // TODO: call nvrtc. + HasRand has_rand_func(stmt()); + has_random_ = has_rand_func.has_rand(); + printer_.reset(new CudaPrinter(&oss_, has_random_)); + if (has_random_) { + os() << philox_random_string << std::endl; + } + + // Check whether the statement uses the Half type, if so add the + // half_support_literal. + CudaHalfChecker halfChecker; + stmt()->accept(&halfChecker); + if (halfChecker.hasHalf()) { + os() << fuser::cuda::half_support_literal << std::endl; + } + + os() << "extern \"C\" __global__" << std::endl << "void f("; + const std::vector buffer_args = this->buffer_args(); + for (int i = 0; i < buffer_args.size(); i++) { + if (i > 0) { + os() << ", "; + } + const BufferArg& buffer_arg = buffer_args[i]; + const Var* var = buffer_arg.var(); + Dtype dtype = buffer_arg.dtype(); + + os() << cudaDtypeCppString(dtype) + << (buffer_arg.isVar() ? " " : "* ") + << name_manager()->get_unique_name(var); + } + const Var* rand_seed; + const Var* rand_offset; + if (has_random_) { + // TODO: switch to kUint64 when it is available. + rand_seed = new Var("rand_seed", kInt); + rand_offset = new Var("rand_offset", kInt); + std::string uint64_str = "unsigned long long"; + os() << ", " << uint64_str << " " << *rand_seed << ", " << uint64_str << " " + << *rand_offset; + } + os() << ") {"; + os() << std::endl; + + if (has_random_) { + const Var* idx = new Var("idx", kInt); + os() << "int " << *idx << " = blockIdx.x*blockDim.x + threadIdx.x;" + << std::endl; + const Var* rand_func = printer_->rand_func(); + os() << "Philox " << *rand_func << "(" << *rand_seed << ", " << *idx << ", " + << *rand_offset << ");" << std::endl; + os() << std::endl; + } + + Stmt* stmt_v = stmt(); + PrioritizeLoad prioritize_load; + stmt_v = prioritize_load.Process(stmt_v); + stmt_v->accept(printer_.get()); + os() << std::endl; + os() << "}"; + + // Check that all block extents had been set. 
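// Illustrative case: if a loop was bound to block index 1 (blockIdx.y) but no loop
// was bound to block index 0, gpu_block_extents[0] stays null and the loop below
// fails with "Missing gpu_block_index: 0".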
+ const std::vector& gpu_block_extents = printer_->gpu_block_extents(); + const std::vector& gpu_thread_extents = printer_->gpu_thread_extents(); + for (int i = 0; i < gpu_block_extents.size(); i++) { + if (!gpu_block_extents[i]) { + throw std::runtime_error("Missing gpu_block_index: " + std::to_string(i)); + } + } + +#if DEBUG_PRINT + std::cout << "stmt: " << std::endl; + std::cout << oss_.str() << std::endl; + std::cout << "block("; + for (int i = 0; i < gpu_block_extents.size(); i++) { + if (i > 0) { + std::cout << ", "; + } + std::cout << gpu_block_extents[i]; + } + std::cout << "), thread("; + for (int i = 0; i < gpu_thread_extents.size(); i++) { + if (i > 0) { + std::cout << ", "; + } + std::cout << gpu_thread_extents[i]; + } + std::cout << ")" << std::endl; + ; +#endif + + CompileToNVRTC(oss_.str()); + USE_TRIGGER(cuda_codegen_created); +} + +void CudaCodeGen::call(const std::vector& args) { + CHECK_EQ(args.size(), buffer_args().size()); + + // TODO: move as much of this into the constructors. + const std::vector& gpu_block_extents = printer_->gpu_block_extents(); + const std::vector& gpu_thread_extents = printer_->gpu_thread_extents(); + CHECK(gpu_block_extents.size() <= 3); + CHECK(gpu_thread_extents.size() <= 3); + std::vector gpu_block_extents_v(3, 1); + std::vector gpu_thread_extents_v(3, 1); + // evaluate all the block/thread extents into values + // TODO: eventually, codegen these calculations and make them part of the + // module. + for (int i = 0; i < gpu_block_extents.size(); i++) { + ExprEval eval( + ExprHandle(gpu_block_extents[i]), buffer_args()); + gpu_block_extents_v[i] = eval.value(args); + } + for (int i = 0; i < gpu_thread_extents.size(); i++) { + ExprEval eval( + ExprHandle(gpu_thread_extents[i]), buffer_args()); + gpu_thread_extents_v[i] = eval.value(args); + } + + // Bind the buffer addresses into arguments + auto const& buffer_args = this->buffer_args(); + int ptr_count = buffer_args.size(); + if (has_random_) { + ptr_count += 2; + } + std::vector args_data(buffer_args.size()); + std::vector ptr_to_args(ptr_count); + uint64_t rand_seed = uint64_t(-1); + uint64_t rand_offset = uint64_t(-1); + for (int i = 0; i < buffer_args.size(); i++) { + auto const& bufferArg = buffer_args[i]; + if (bufferArg.isVar()) { + auto stype = bufferArg.dtype().scalar_type(); + switch (stype) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + ptr_to_args[i] = args[i].Name##Ptr(); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unhandled dtype in argument"; + } + } else { + args_data[i] = args[i].data(); + ptr_to_args[i] = &args_data[i]; + } + } + + if (has_random_) { + auto gen = at::cuda::detail::getDefaultCUDAGenerator(); + // TODO: total hack. Switch to numel when it is available. 
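// The (1 << 28) below is a conservative per-thread element count: it is handed to
// philox_engine_inputs so the generator reserves a large enough offset range for this
// launch, and the resulting (seed, offset) pair is appended to the kernel arguments
// after the buffer pointers.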
+ int64_t total_elements_per_thread = (1LL << 28); + { + std::lock_guard lock(gen->mutex_); + auto philox_engine_inputs = + gen->philox_engine_inputs(total_elements_per_thread); + rand_seed = philox_engine_inputs.first; + rand_offset = philox_engine_inputs.second; + } + ptr_to_args[buffer_args.size()] = &rand_seed; + ptr_to_args[buffer_args.size() + 1] = &rand_offset; + } + + // Launch the kernels + auto stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_DRIVER_CHECK(nvrtc().cuLaunchKernel( + function_, + gpu_block_extents_v[0], + gpu_block_extents_v[1], + gpu_block_extents_v[2], + gpu_thread_extents_v[0], + gpu_thread_extents_v[1], + gpu_thread_extents_v[2], + 0, + stream, + ptr_to_args.data(), + nullptr)); + USE_TRIGGER(cuda_codegen_executed); +} + +void CudaCodeGen::CompileToNVRTC(const std::string& code) { + // Initializes driver's API context (if necessary) + CUdevice device = 0; + CUcontext pctx = 0; + AT_CUDA_DRIVER_CHECK(nvrtc().cuCtxGetCurrent(&pctx)); + if (!pctx) { + std::unique_lock cudaFreeMutexLock( + *(c10::cuda::CUDACachingAllocator::getFreeMutex())); + cudaFree(0); + } + + // Note: hacked at::DeviceGuard since at::DeviceGuard was failing to work + // properly in some scenarios + const auto prior_device = at::cuda::current_device(); + at::cuda::set_device(device); + + // Acquires device and NVRTC properties (for compile arch and occupancy + // calculations) + cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties(); + int major, minor; + getMajorMinor(prop, major, minor); + +#if DEBUG_PRINT + std::cout << "major: " << major << ", " + << "minor: " << minor << std::endl; +#endif + + // Creates the NVRTC program + nvrtcProgram program; + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcCreateProgram( + &program, code.c_str(), nullptr, 0, nullptr, nullptr)); + +#ifdef __HIP_PLATFORM_HCC__ + std::vector args = {}; +#else + const std::string compute = "--gpu-architecture=compute_" + + std::to_string(major) + std::to_string(minor); + const std::vector args = { + "--std=c++14", compute.c_str(), "-default-device"}; +#endif + + const auto result = + nvrtc().nvrtcCompileProgram(program, args.size(), args.data()); + if (result != NVRTC_SUCCESS) { + size_t logsize; + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLogSize(program, &logsize)); + std::vector log(logsize); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetProgramLog(program, log.data())); + std::stringstream cu; + cu << log.data() << std::endl; + cu << "nvrtc compilation failed: " << std::endl; + cu << code << std::endl; + throw std::runtime_error(cu.str()); + } + ResourceGuard holdProgram( + [&] { AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcDestroyProgram(&program)); }); + AT_CUDA_NVRTC_CHECK(result); + size_t ptx_size; + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTXSize(program, &ptx_size)); + std::vector ptx; + ptx.resize(ptx_size); + AT_CUDA_NVRTC_CHECK(nvrtc().nvrtcGetPTX(program, ptx.data())); + + CUmodule module; + std::string name = "f"; + AT_CUDA_DRIVER_CHECK(nvrtc().cuModuleLoadData(&module, ptx.data())); + AT_CUDA_DRIVER_CHECK( + nvrtc().cuModuleGetFunction(&function_, module, name.c_str())); +} + +RegisterCodeGen reg("cuda_codegen"); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_codegen.h b/torch/csrc/jit/tensorexpr/cuda_codegen.h new file mode 100644 index 0000000000000..df7fff2822b42 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_codegen.h @@ -0,0 +1,121 @@ +#pragma once + +#include +#include + +#include "ATen/ATen.h" +#include "ATen/cuda/CUDAContext.h" +#include 
"ATen/cuda/nvrtc_stub/ATenNVRTC.h" +#include "c10/cuda/CUDACachingAllocator.h" +#include "c10/cuda/CUDAGuard.h" +#include "torch/csrc/jit/resource_guard.h" +#include "torch/csrc/jit/tensorexpr/codegen.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" +#include "torch/csrc/jit/tensorexpr/unique_name_manager.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +// A class that overrides the underlying IRPrinter to produce Cuda C. +class CudaPrinter : public IRPrinter { + public: + explicit CudaPrinter(std::ostream* os, bool has_random) : IRPrinter(*os) { + if (has_random) { + rand_func_ = new Var("rand", kHandle); + } + } + + void visit(const Cast* v) override { + auto dtype = v->dtype(); + if (dtype == kHalf) { + os() << "half"; + } else { + os() << dtype; + } + os() << "("; + v->src_value()->accept(this); + os() << ")"; + } + + void visit(const Intrinsics* v); + void visit(const For* v); + + void visit(const Load* v) override; + void visit(const Store* v) override; + void visit(const Max* v) override; + void visit(const Min* v) override; + void visit(const LetStmt* v) override; + void visit(const IfThenElse* v) override; + + const std::vector& gpu_block_extents() const { + return gpu_block_extents_; + } + + const std::vector& gpu_thread_extents() const { + return gpu_thread_extents_; + } + + const Var* rand_func() const { + return rand_func_; + } + + using IRPrinter::name_manager; + + private: + std::vector gpu_block_extents_; + std::vector gpu_thread_extents_; + const Var* rand_func_; +}; + +// Construct Cuda C from the buffer and tensor input, and invoke the kernel +// when real arguments are provided. +class TORCH_API CudaCodeGen : public CodeGen { + public: + template + CudaCodeGen(Stmt* stmt, Ts... ts) + : CodeGen(stmt, std::forward(ts)...) { + Initialize(); + } + + CudaCodeGen(Stmt* stmt, const std::vector& buffer_args) + : CodeGen(stmt, buffer_args) { + Initialize(); + } + + ~CudaCodeGen() override {} + + TORCH_API void call(const std::vector& args) override; + + template + void operator()(const Ts&... ts) { + call(std::vector({CallArg(ts)...})); + } + + private: + TORCH_API void Initialize(); + + void CompileToNVRTC(const std::string& code); + + UniqueNameManager* name_manager() { + if (!printer_) { + throw std::runtime_error("Null IRPrinter is not expected"); + } + return printer_->name_manager(); + } + + std::ostream& os() { + return printer_->os(); + } + + std::ostringstream oss_; + std::unique_ptr printer_; + CUfunction function_; + bool has_random_ = false; +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_half_support.h b/torch/csrc/jit/tensorexpr/cuda_half_support.h new file mode 100644 index 0000000000000..249c445117b93 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_half_support.h @@ -0,0 +1,31 @@ +#pragma once + +#include "torch/csrc/jit/tensorexpr/cuda_codegen.h" +#include "torch/csrc/jit/fuser/cuda/resource_strings.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +// Walk the Statment looking for Half size loads/stores. 
+class CudaHalfChecker : public IRVisitor { + public: + bool hasHalf() { + return hasHalf_; + } + + void visit(const Load* v) override { + hasHalf_ |= v->dtype().scalar_type() == ScalarType::Half; + } + void visit(const Store* v) override { + hasHalf_ |= v->value()->dtype().scalar_type() == ScalarType::Half; + } + + private: + bool hasHalf_{false}; +}; + + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/cuda_random.h b/torch/csrc/jit/tensorexpr/cuda_random.h new file mode 100644 index 0000000000000..c8629ccaa9d9c --- /dev/null +++ b/torch/csrc/jit/tensorexpr/cuda_random.h @@ -0,0 +1,104 @@ +#pragma once + +namespace torch { +namespace jit { +namespace tensorexpr { + +constexpr auto philox_random_string = R"( + +class Philox { +public: + __device__ inline Philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + counter = make_uint4(0, 0, 0, 0); + counter.z = (unsigned int)(subsequence); + counter.w = (unsigned int)(subsequence >> 32); + STATE = 0; + incr_n(offset / 4); + } + + __device__ inline unsigned long operator()() { + if(STATE == 0) { + uint4 counter_ = counter; + uint2 key_ = key; + for(int i = 0; i < 9; i++) { + counter_ = single_round(counter_, key_); + key_.x += (kPhilox10A); key_.y += (kPhilox10B); + } + output = single_round(counter_, key_); + incr(); + } + unsigned long ret; + switch(STATE) { + case 0: ret = output.x; break; + case 1: ret = output.y; break; + case 2: ret = output.z; break; + case 3: ret = output.w; break; + } + STATE = (STATE + 1) % 4; + return ret; + } + +private: + uint4 counter; + uint4 output; + uint2 key; + unsigned int STATE; + __device__ inline void incr_n(unsigned long long n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + counter.x += nlo; + if (counter.x < nlo) + nhi++; + counter.y += nhi; + if (nhi <= counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + __device__ inline void incr() { + if (++counter.x) + return; + if (++counter.y) + return; + if (++counter.z) + return; + ++counter.w; + } + __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int *result_high) { + *result_high = __umulhi(a, b); + return a*b; + } + + __device__ inline uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); + + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; + } + + static const unsigned long kPhilox10A = 0x9E3779B9; + static const unsigned long kPhilox10B = 0xBB67AE85; + static const unsigned long kPhiloxSA = 0xD2511F53; + static const unsigned long kPhiloxSB = 0xCD9E8D57; +}; + +// Inverse of 2^32. 
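// 1 / 2^32 = 2.3283064365...e-10, which M_RAN_INVM32 below approximates, so
// Uint32ToFloat scales a raw 32-bit Philox draw into the unit interval
// (x = 0 maps to 0.0f). The kRand intrinsic printer above wraps this helper
// around the per-thread Philox instance.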
+#define M_RAN_INVM32 2.3283064e-10f +__device__ __inline__ float Uint32ToFloat(unsigned int x) { + return x * M_RAN_INVM32; +} + +)"; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/eval.cpp b/torch/csrc/jit/tensorexpr/eval.cpp new file mode 100644 index 0000000000000..d41a2a343718c --- /dev/null +++ b/torch/csrc/jit/tensorexpr/eval.cpp @@ -0,0 +1,13 @@ +#include "torch/csrc/jit/tensorexpr/eval.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +DEFINE_TRIGGER(simple_ir_eval_executed); + +RegisterCodeGen reg("simple_ir_eval"); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/eval.h b/torch/csrc/jit/tensorexpr/eval.h new file mode 100644 index 0000000000000..4ebcf5b712577 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/eval.h @@ -0,0 +1,836 @@ +#pragma once + +#include +#include +#include + +#include +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/codegen.h" +#include "torch/csrc/jit/tensorexpr/execution_counter.h" +#include "torch/csrc/jit/tensorexpr/function.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" +#include "torch/csrc/jit/tensorexpr/types.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +DECLARE_TRIGGER(simple_ir_eval_executed); + +class Value { + public: + Value() : dtype_(kInt) { + Intvalues.push_back(0); + } + +#define VALUE_CTOR(Type, Name) \ + Value(Type v) : dtype_(k##Name) { \ + Name##values.push_back(v); \ + } +AT_FORALL_SCALAR_TYPES_AND(Half, VALUE_CTOR); +#undef VALUE_CTOR + +#define VALUE_VEC_CTOR(Type, Name) \ + Value(const std::vector& v) \ + : dtype_(Dtype(k##Name, v.size())), Name##values(v) {} +AT_FORALL_SCALAR_TYPES_AND(Half, VALUE_VEC_CTOR); +#undef VALUE_VEC_CTOR + + template + T as() const; + + template + const std::vector& as_vec() const; + + Dtype dtype() const { + return dtype_; + } + + private: + Dtype dtype_; + +#define VALUE_STORAGE(Type, Name) \ + std::vector Name##values; +AT_FORALL_SCALAR_TYPES_AND(Half, VALUE_STORAGE); +#undef VALUE_STORAGE + void* ptr; +}; + + +#define VALUE_AS_DISPATCH(Type, Name) \ + template <> \ + inline Type Value::as() const { \ + CHECK_EQ(dtype_, k##Name) << "invalid dtype"; \ + return Name##values[0];\ +} +AT_FORALL_SCALAR_TYPES_AND(Half, VALUE_AS_DISPATCH); +#undef VALUE_AS_DISPATCH + +#define VALUE_AS_VEC_DISPATCH(Type, Name) \ +template <> \ +inline const std::vector& Value::as_vec() const { \ + CHECK_EQ(dtype_.scalar_type(), ScalarType::Name) << "invalid dtype"; \ + return Name##values; \ +} +AT_FORALL_SCALAR_TYPES_AND(Half, VALUE_AS_VEC_DISPATCH); +#undef VALUE_AS_VEC_DISPATCH + +template +class PaddedBuffer; + +template +inline typename std::enable_if::value, T>::type mod_value( + T lhs, + T rhs) { + return lhs % rhs; +} + +template +inline typename std::enable_if::value, T>::type +mod_value(T lhs, T rhs) { + return std::fmod(lhs, rhs); +} + +inline bool mod_value(bool lhs, bool rhs) { + LOG(FATAL) << "Attempted modulus of bool"; + return false; +} + +class SimpleIREvaluator : public CodeGen, public IRVisitor { + public: + using CodeGen::CodeGen; + + ~SimpleIREvaluator() override {} + + TORCH_API void call(const std::vector& args) override { + CHECK_EQ(args.size(), buffer_args().size()); + for (size_t i = 0; i < args.size(); i++) { + bind(buffer_args()[i], args[i]); + } + stmt()->accept(this); + eval_context_.clear(); + 
buffer_mapping_.clear(); + internal_buffers_.clear(); + USE_TRIGGER(simple_ir_eval_executed); + } + + void bind(const BufferArg& buf, const CallArg& data) { + if (!buf.isVar()) { + buffer_mapping_[buf.var()] = data.data(); + return; + } + + switch (buf.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + eval_context_[buf.var()] = data.Name##Data(); \ + break; +AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unhandled dtype for argument " << buf.var()->name_hint() + << ": " << buf.dtype(); + } + } + + template + void operator()(const Ts&... ts) { + std::vector args({CallArg(ts)...}); + call(args); + } + + TORCH_API void visit(const Add* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Sub* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Mul* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Div* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Mod* v) override { + visit_binary_op(v); + } + TORCH_API void visit(const Max* v) override { + visit_binary_op(v, v->propagate_nans()); + } + TORCH_API void visit(const Min* v) override { + visit_binary_op(v, v->propagate_nans()); + } + + void visit(const CompareSelect* v) override { + visit_compare_select_op(v, v->compare_select_op()); + } + + template + Value binary_op( + const Value& lhs, + const Value& rhs, + IRNodeType op_type, + bool option = false) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (op_type) { + case IRNodeType::kAdd: + result_v[i] = lhs_v[i] + rhs_v[i]; + break; + case IRNodeType::kSub: + result_v[i] = lhs_v[i] - rhs_v[i]; + break; + case IRNodeType::kMul: + result_v[i] = lhs_v[i] * rhs_v[i]; + break; + case IRNodeType::kDiv: + result_v[i] = lhs_v[i] / rhs_v[i]; + break; + case IRNodeType::kMod: + result_v[i] = mod_value(lhs_v[i], rhs_v[i]); + break; + case IRNodeType::kMax: + if (option) { + // Propagate NaNs + if (is_floating_point(lhs.dtype().scalar_type()) && + is_floating_point(rhs.dtype().scalar_type())) { + result_v[i] = lhs_v[i]; + } else if (std::isnan((float)rhs_v[i])) { + result_v[i] = rhs_v[i]; + } + } else { + result_v[i] = lhs_v[i] > rhs_v[i] ? lhs_v[i] : rhs_v[i]; + } + break; + case IRNodeType::kMin: + if (option) { + // Propagate NaNs + if (is_floating_point(lhs.dtype().scalar_type()) && + is_floating_point(rhs.dtype().scalar_type())) { + result_v[i] = lhs_v[i]; + } else if (std::isnan((float)rhs_v[i])) { + result_v[i] = rhs_v[i]; + } + } else { + result_v[i] = lhs_v[i] < rhs_v[i] ? 
lhs_v[i] : rhs_v[i]; + } + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + Value bitwise_binary_op( + const Value& lhs, + const Value& rhs, + IRNodeType op_type) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (op_type) { + case IRNodeType::kAnd: + result_v[i] = lhs_v[i] & rhs_v[i]; + break; + case IRNodeType::kXor: + result_v[i] = lhs_v[i] ^ rhs_v[i]; + break; + case IRNodeType::kLshift: + result_v[i] = lhs_v[i] << rhs_v[i]; + break; + case IRNodeType::kRshift: + result_v[i] = lhs_v[i] >> rhs_v[i]; + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + template + Value compare_select_op( + const Value& lhs, + const Value& rhs, + const Value& retval1, + const Value& retval2, + CompareSelectOperation cmp_op) { + std::vector lhs_v = lhs.as_vec(); + std::vector rhs_v = rhs.as_vec(); + std::vector ret_val1_v = retval1.as_vec(); + std::vector ret_val2_v = retval2.as_vec(); + std::vector result_v(lhs_v.size()); + for (size_t i = 0; i < lhs_v.size(); i++) { + switch (cmp_op) { + case CompareSelectOperation::kEQ: + result_v[i] = (lhs_v[i] == rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kNE: + result_v[i] = (lhs_v[i] != rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kGT: + result_v[i] = (lhs_v[i] > rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kGE: + result_v[i] = (lhs_v[i] >= rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kLT: + result_v[i] = (lhs_v[i] < rhs_v[i]) ? ret_val1_v[i] : ret_val2_v[i]; + break; + case CompareSelectOperation::kLE: + result_v[i] = (lhs_v[i] <= rhs_v[i]) ? 
ret_val1_v[i] : ret_val2_v[i]; + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } + return Value(result_v); + } + + template + void visit_binary_op(const BinaryOpNode* v, bool option = false) { + v->lhs()->accept(this); + Value lhs_v = value_; + v->rhs()->accept(this); + Value rhs_v = value_; + CHECK_EQ(lhs_v.dtype(), rhs_v.dtype()); + IRNodeType expr_type = v->expr_type(); + if (expr_type == IRNodeType::kAnd || expr_type == IRNodeType::kXor || + expr_type == IRNodeType::kLshift || expr_type == IRNodeType::kLshift) { + value_ = bitwise_binary_op(lhs_v, rhs_v, expr_type); + return; + } + + switch (lhs_v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value_ = binary_op(lhs_v, rhs_v, expr_type); \ + break; +AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "invalid dtype: " << lhs_v.dtype(); + } + } + + void visit_compare_select_op( + const CompareSelect* v, + CompareSelectOperation cmp_op) { + v->lhs()->accept(this); + Value lhs_v = value_; + v->rhs()->accept(this); + Value rhs_v = value_; + v->ret_val1()->accept(this); + Value ret_val1_v = value_; + v->ret_val2()->accept(this); + Value ret_val2_v = value_; + + CHECK_EQ(lhs_v.dtype(), rhs_v.dtype()); + CHECK_EQ(ret_val1_v.dtype(), ret_val2_v.dtype()); + + switch (lhs_v.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + value_ = compare_select_op( \ + lhs_v, rhs_v, ret_val1_v, ret_val2_v, cmp_op); \ + break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "invalid dtype: " << lhs_v.dtype(); + } + } + +#define IMM_VISIT(Type, Name) \ + TORCH_API void visit(const Name##Imm* v) override { \ + value_ = Value(v->value()); \ + } +AT_FORALL_SCALAR_TYPES_AND(Half, IMM_VISIT); +#undef IMM_VISIT + + TORCH_API void visit(const Let* v) override { + const Var* var = dynamic_cast(v->var()); + CHECK(var != nullptr); + v->value()->accept(this); + Value value = value_; + auto iter = eval_context_.find(var); + // TODO: make the same value settable multiple times. + CHECK(iter == eval_context_.end()) + << "var must not exist in the context before"; + eval_context_[var] = value_; + + v->body()->accept(this); + + eval_context_.erase(var); + } + + TORCH_API void visit(const LetStmt* v) override { + const Var* var = v->var(); + CHECK(var != nullptr); + v->value()->accept(this); + Value value = value_; + auto iter = eval_context_.find(var); + // TODO: make the same value settable multiple times. 
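// As in visit(Let) above, the binding is strictly scoped: the Var is inserted into
// eval_context_ only while the body is evaluated and erased again afterwards, so
// rebinding a Var that is already live trips the CHECK below.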
+ CHECK(iter == eval_context_.end()) + << "var must not exist in the context before"; + eval_context_[var] = value_; + + v->body()->accept(this); + + eval_context_.erase(var); + } + + TORCH_API void visit(const Var* v) override { + auto iter = eval_context_.find(v); + CHECK(iter != eval_context_.end()) + << "var must be defined in the context before"; + value_ = iter->second; + } + + template + std::vector castValues(const Dtype& src_dtype, const Value& v) { + const std::vector& src_values = v.as_vec(); + std::vector dst_values(src_values.size()); + for (int i = 0; i < src_dtype.lanes(); ++i) { + dst_values[i] = static_cast(src_values[i]); + } + return dst_values; + } + + template + void doCastFromSrc( + const Dtype& src_dtype, + const Dtype& dst_dtype, + const Value& v) { + switch (dst_dtype.scalar_type()) { +#define DST_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + this->value_ = Value(castValues(src_dtype, v)); \ + break; + AT_FORALL_SCALAR_TYPES_AND(Half, DST_TYPE_CASE); +#undef DST_TYPE_CASE + default: + LOG(FATAL) << "Cast invalid dst type " << dst_dtype << "\n"; + } + } + + TORCH_API void visit(const Cast* v) override { + const Expr* src_value = v->src_value(); + src_value->accept(this); + Dtype dst_dtype = v->dtype(); + Dtype src_dtype = src_value->dtype(); + CHECK_EQ(src_dtype.lanes(), dst_dtype.lanes()); + + if (src_dtype != dst_dtype) { + switch (src_dtype.scalar_type()) { +#define SRC_TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + doCastFromSrc(src_dtype, dst_dtype, value_); \ + break; + AT_FORALL_SCALAR_TYPES_AND(Half, SRC_TYPE_CASE); +#undef SRC_TYPE_CASE + default: + LOG(FATAL) << "Cast invalid src type " << src_dtype << "\n"; + } + } + } + + TORCH_API void visit(const For* v) override { + const Expr* var_node = v->var(); + v->start()->accept(this); + int start = value_.as(); + v->stop()->accept(this); + int stop = value_.as(); + auto iter = eval_context_.find(var_node); + CHECK(iter == eval_context_.end()) + << "var in For must not exist in eval context"; + for (int i = start; i < stop; i++) { + eval_context_[var_node] = Value(i); + if (v->body()) { + v->body()->accept(this); + } + } + eval_context_.erase(var_node); + } + + TORCH_API void visit(const Ramp* v) override { + v->base()->accept(this); + int base = value().as(); + v->stride()->accept(this); + int stride = value().as(); + int lanes = v->lanes(); + + std::vector values(lanes); + for (int i = 0; i < lanes; i++) { + values[i] = base + i * stride; + } + + value_ = Value(values); + } + + TORCH_API void visit(const Broadcast* v) override { + v->value()->accept(this); + Value value = this->value(); + int lanes = v->lanes(); + switch (value.dtype().scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + std::vector v(lanes, value.as()); \ + value_ = Value(v); \ + } break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "invalid dtype: " << value.dtype(); + } + } + + TORCH_API void visit(const IfThenElse* v) override { + v->condition()->accept(this); + if (value_.as()) { + v->true_value()->accept(this); + } else { + v->false_value()->accept(this); + } + } + + TORCH_API void visit(const Load* v) override { + const Var* base_node = v->base_handle(); + auto iter = buffer_mapping_.find(base_node); + CHECK(iter != buffer_mapping_.end()) + << "missing buffer binding: " << base_node->name_hint(); + void* ptr = iter->second; + + v->index()->accept(this); + std::vector index = value().as_vec(); + v->mask()->accept(this); + std::vector mask = 
value().as_vec(); + ScalarType v_sdtype = v->dtype().scalar_type(); + switch (v_sdtype) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + Type* ptr##Name = static_cast(ptr); \ + std::vector v(index.size()); \ + for (size_t i = 0; i < index.size(); i++) { \ + if (mask[i]) { \ + v[i] = ptr##Name[index[i]]; \ + } \ + } \ + value_ = Value(v); \ + } break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Invalid dtype: " << v_sdtype; + } + } + + TORCH_API void visit(const Store* v) override { + const Var* base_node = v->base_handle(); + auto iter = buffer_mapping_.find(base_node); + CHECK(iter != buffer_mapping_.end()); + void* ptr = iter->second; + + v->index()->accept(this); + std::vector index = value().as_vec(); + v->mask()->accept(this); + std::vector mask = value().as_vec(); + CHECK_EQ(index.size(), mask.size()); + ScalarType v_sdtype = v->value()->dtype().scalar_type(); + + switch (v_sdtype) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + v->value()->accept(this); \ + std::vector value = this->value().as_vec(); \ + CHECK_EQ(index.size(), value.size()); \ + Type* ptr##Name = static_cast(ptr); \ + for (size_t i = 0; i < index.size(); i++) { \ + if (mask[i]) { \ + ptr##Name[index[i]] = value[i]; \ + } \ + } \ + } break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Invalid dtype: " << v_sdtype; + } + } + + TORCH_API void visit(const BaseCallNode* v) override { + LOG(FATAL) << "unsupported visit to BaseCallNode"; + } + + TORCH_API void visit(const Intrinsics* v) override { + std::vector values(v->nparams()); + for (int i = 0; i < v->nparams(); i++) { + v->param(i)->accept(this); + values[i] = this->value(); + } + std::vector v1; + if (values.size() >= 1ULL) { + v1 = values[0].as_vec(); + } + std::vector v2; + if (values.size() >= 2ULL) { + v2 = values[1].as_vec(); + CHECK_EQ(v1.size(), v2.size()) << "mismatch vectorize sizes"; + } + CHECK_LE(values.size(), 2ULL) + << "no support for intrinsics for more than two operand yet"; + std::vector result(v1.size(), -1); + if (values.size() == 1ULL) { + for (size_t i = 0; i < v1.size(); i++) { + result[i] = compute_intrinsics(v->op_type(), v1[i]); + } + } else { + for (size_t i = 0; i < v1.size(); i++) { + result[i] = compute_intrinsics(v->op_type(), v1[i], v2[i]); + } + } + value_ = Value(result); + } + + void visit(const Allocate* v) override { + const Var* buffer_var = v->buffer_var(); + std::vector dims = v->dims(); + int total_byte_size = v->dtype().byte_size(); + for (size_t i = 0; i < dims.size(); i++) { + dims[i]->accept(this); + total_byte_size *= value_.as(); + } + int int_count = (total_byte_size + sizeof(int) - 1) / sizeof(int); + std::unique_ptr> buffer(new std::vector(int_count)); + auto iter = buffer_mapping_.find(buffer_var); + if (iter != buffer_mapping_.end() && iter->second != nullptr) { + throw std::runtime_error( + "Allocate a buffer that has already been allocated: " + + buffer_var->name_hint()); + } + buffer_mapping_[buffer_var] = buffer->data(); + internal_buffers_.insert(std::make_pair(buffer_var, std::move(buffer))); + } + + void visit(const Free* v) override { + const Var* buffer_var = v->buffer_var(); + int count = internal_buffers_.erase(buffer_var); + if (count == 0) { + throw std::runtime_error( + "Free a buffer that is not currently bound: " + + buffer_var->name_hint()); + } + } + + void visit(const Cond* v) override { + v->condition()->accept(this); + if (value().as()) { + if (v->true_stmt()) { + 
v->true_stmt()->accept(this); + } + } else { + if (v->false_stmt()) { + v->false_stmt()->accept(this); + } + } + } + + Value value() const { + return value_; + } + + private: + static float compute_intrinsics(IntrinsicsOp op_type, float v) { + switch (op_type) { + case kSin: + return std::sin(v); + case kCos: + return std::cos(v); + case kTan: + return std::tan(v); + case kAsin: + return std::asin(v); + case kAcos: + return std::acos(v); + case kAtan: + return std::atan(v); + case kSinh: + return std::sinh(v); + case kCosh: + return std::cosh(v); + case kTanh: + return std::tanh(v); + case kExp: + return std::exp(v); + case kFabs: + return std::fabs(v); + case kExpm1: + return std::expm1(v); + case kLog: + return std::log(v); + case kLog2: + return std::log2(v); + case kLog10: + return std::log10(v); + case kLog1p: + return std::log1p(v); + case kErf: + return std::erf(v); + case kErfc: + return std::erfc(v); + case kSqrt: + return std::sqrt(v); + case kRsqrt: + return 1.0f / std::sqrt(v); + case kCeil: + return std::ceil(v); + case kFloor: + return std::floor(v); + case kRound: + return std::round(v); + case kTrunc: + return std::trunc(v); + case kLgamma: + return std::lgamma(v); + case kFrac: + float intpart; + return std::modf(v, &intpart); + default: + throw std::runtime_error("invalid op_type: " + std::to_string(op_type)); + } + } + + static float compute_intrinsics(IntrinsicsOp op_type, float v1, float v2) { + switch (op_type) { + case kPow: + return std::pow(v1, v2); + case kFmod: + return std::fmod(v1, v2); + case kRemainder: + return std::remainderf(v1, v2); + case kAtan2: + return std::atan2(v1, v2); + default: + throw std::runtime_error("nvalid op_type: " + std::to_string(op_type)); + } + } + + Value value_; + std::unordered_map eval_context_; + std::unordered_map buffer_mapping_; + std::unordered_map>> + internal_buffers_; +}; + +using VarMapping = std::vector>; + +class VarSubMutator : public IRMutator { + public: + VarSubMutator(const VarMapping& var_mapping) { + for (const auto& entry : var_mapping) { + const ExprHandle& key = entry.first; + const ExprHandle& value = entry.second; + const Var* key_var = key.AsNode(); + CHECK(key_var != nullptr); + var_mapping_[key_var] = value; + } + } + + const Expr* mutate(const Var* var) override { + auto iter = var_mapping_.find(var); + if (iter == var_mapping_.end()) { + return const_cast(var); + } + return iter->second.node(); + } + + private: + std::unordered_map var_mapping_; +}; + +template +class ExprEval { + public: + using BufferArg = CodeGen::BufferArg; + using CallArg = CodeGen::CallArg; + + template + ExprEval(const ExprHandle& expr, Ts... ts) : ExprEval(expr, {BufferArg(ts)...}) {} + + ExprEval(const ExprHandle& expr, const std::vector& buffer_args) + : dtype_(expr.dtype()) { + std::vector buffer_args_extended = buffer_args; + Buffer ret_buf("ret_val", dtype_, {1}); + Stmt* store_stmt = Store::make(VarHandle(ret_buf.data()), 0, expr); + buffer_args_extended.push_back(ret_buf); + codegen_.reset(new CodeGenType(store_stmt, buffer_args_extended)); + } + + template + void operator()(Ts... ts) { + call(ts...); + } + + void operator()(const std::vector& call_args) { + call(call_args); + } + + template + void call(Ts... 
ts) { + call({CallArg(ts)...}); + } + + void call(const std::vector& call_args) { + std::vector call_args_extended = call_args; + switch (dtype_.scalar_type()) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: { \ + std::vector ret_val_arg(1); \ + call_args_extended.push_back(CallArg(ret_val_arg)); \ + codegen_->call(call_args_extended); \ + ret_value_ = Value(ret_val_arg[0]); \ + } break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw std::runtime_error("Invalid dtype"); + } + } + + template + T value(Ts... ts) { + call(std::forward(ts)...); + return ret_value_.as(); + } + + Dtype dtype() { return dtype_; } + + private: + Dtype dtype_; + std::unique_ptr codegen_; + Value ret_value_; +}; + +inline ExprHandle Substitute(ExprHandle* expr, const VarMapping& var_mapping) { + VarSubMutator var_sub(var_mapping); + return ExprHandle(expr->node()->accept_mutator(&var_sub)); +} + +inline Stmt* Substitute(Stmt* stmt, const VarMapping& var_mapping) { + VarSubMutator var_sub(var_mapping); + return stmt->accept_mutator(&var_sub); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/execution_counter.h b/torch/csrc/jit/tensorexpr/execution_counter.h new file mode 100644 index 0000000000000..7377b62a2ef23 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/execution_counter.h @@ -0,0 +1,118 @@ +#pragma once + +#include "torch/csrc/WindowsTorchApiMacro.h" + +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +/* +ExecutionTrigger and ExecutionCounter builds instrumentation counters so +underlying functionalities can be checked. + +In the code to be instrumented: + +// worker.cpp +DEFINE_TRIGGER(useful_work_done); // this defines a trigger "useful_work_done" +void run() { + USE_TRIGGER(useful_work_done); // this triggers the underlying counter + // in "useful_work_done" +} + +// in C++ client.cpp + +DECLARE_TRIGGER(useful_work_done); // Optional: this declares a trigger that + // will be defined elsewhere +ExecutionCounter counter(useful_work_done); // This starts the counter from the + // underlying trigger. +... call run() ... +counter.elapsed_value(); // this returns the incremented value from the + // trigger since the creation of the counter + +// in Python client.py +counter = ExecutionCounter("useful_work_done") // this starts the counter from + // the underlying trigger +... call C++ run() ... +counter.elapsed_value() // This returns the incremented value from the + // trigger since the creation of the counter. 
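A concrete sketch with a trigger defined in this patch (assumes a constructed
CudaCodeGen `cg` and a matching argument vector `args`):

DECLARE_TRIGGER(cuda_codegen_executed); // defined in cuda_codegen.cpp
ExecutionCounter counter(cuda_codegen_executed);
cg->call(args);
// counter.elapsed_value() returns 1 if the CUDA kernel was actually launched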
+*/ + +class ExecutionTrigger; +class ExecutionTriggerList { + public: + TORCH_API static ExecutionTriggerList& GetInstance() { + static ExecutionTriggerList instance; + return instance; + } + + ExecutionTrigger* FindByName(const std::string& name) const { + auto iter = trigger_list_.find(name); + if (iter == trigger_list_.end()) { + throw std::runtime_error("Invalid trigger name: " + name); + } + return iter->second; + } + + private: + friend class ExecutionTrigger; + + ExecutionTriggerList() {} + ExecutionTriggerList(const ExecutionTriggerList&) = delete; + ExecutionTriggerList& operator=(const ExecutionTriggerList&) = delete; + + void AddTrigger(const std::string& name, ExecutionTrigger* trigger) { + auto insert_ret = trigger_list_.insert(std::make_pair(name, trigger)); + if (!insert_ret.second) { + throw std::runtime_error("Duplicated trigger name: " + name); + } + } + + std::unordered_map trigger_list_; +}; + +class ExecutionTrigger { + public: + explicit ExecutionTrigger(const std::string& name) : name_(name) { + ExecutionTriggerList::GetInstance().AddTrigger(name, this); + } + + int value() const { + return value_; + } + + void trigger() { + value_++; + } + + private: + ExecutionTrigger(const ExecutionTrigger&) = delete; + ExecutionTrigger& operator=(const ExecutionTrigger&) = delete; + int value_ = 0; + const std::string name_; +}; + +class ExecutionCounter { + public: + explicit ExecutionCounter(ExecutionTrigger& trigger) : trigger_(trigger) { + start_value_ = trigger_.value(); + } + + int elapsed_value() const { + return trigger_.value() - start_value_; + } + + private: + ExecutionTrigger& trigger_; + int start_value_ = 0; +}; + +#define DEFINE_TRIGGER(name) TORCH_API ExecutionTrigger name(#name) +#define DECLARE_TRIGGER(name) TORCH_API extern ExecutionTrigger name +#define USE_TRIGGER(name) (name).trigger() + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp new file mode 100644 index 0000000000000..2bd8aaef7edfd --- /dev/null +++ b/torch/csrc/jit/tensorexpr/expr.cpp @@ -0,0 +1,200 @@ +#include "torch/csrc/jit/tensorexpr/expr.h" + +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +ExprHandle ExprHandle::operator+(const ExprHandle& other) const { + return Add::make(*this, other); +} + +ExprHandle ExprHandle::operator-(const ExprHandle& other) const { + return Sub::make(*this, other); +} + +ExprHandle ExprHandle::operator*(const ExprHandle& other) const { + return Mul::make(*this, other); +} + +ExprHandle ExprHandle::operator/(const ExprHandle& other) const { + return Div::make(*this, other); +} + +ExprHandle ExprHandle::operator%(const ExprHandle& other) const { + return Mod::make(*this, other); +} + +ExprHandle ExprHandle::operator==(const ExprHandle& other) const { + return CompareSelect::make(*this, other, CompareSelectOperation::kEQ); +} + +ExprHandle ExprHandle::operator!=(const ExprHandle& other) const { + return CompareSelect::make(*this, other, CompareSelectOperation::kNE); +} + +ExprHandle ExprHandle::operator>(const ExprHandle& other) const { + return CompareSelect::make(*this, other, CompareSelectOperation::kGT); +} + +ExprHandle ExprHandle::operator>=(const ExprHandle& other) const { + return CompareSelect::make(*this, other, CompareSelectOperation::kGE); +} + +ExprHandle ExprHandle::operator<(const ExprHandle& other) const { + return CompareSelect::make(*this, other, CompareSelectOperation::kLT); +} + 
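// A minimal end-to-end sketch of these operator overloads together with ExprEval
// from eval.h (assumes the usual kFloat dtype from types.h; names are illustrative):
//   VarHandle x("x", kFloat);
//   VarHandle y("y", kFloat);
//   ExprHandle e = x * y + ExprHandle(1.0f);
//   ExprEval<SimpleIREvaluator> eval(e, x, y);
//   float r = eval.value<float>(2.0f, 3.0f); // r == 7.0f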
+ExprHandle ExprHandle::operator<=(const ExprHandle& other) const { + return CompareSelect::make(*this, other, CompareSelectOperation::kLE); +} + +ExprHandle ExprHandle::operator&(const ExprHandle& other) const { + return And::make(*this, other); +} + +ExprHandle ExprHandle::operator^(const ExprHandle& other) const { + return Xor::make(*this, other); +} + +ExprHandle ExprHandle::operator<<(const ExprHandle& other) const { + return Lshift::make(*this, other); +} + +ExprHandle ExprHandle::operator>>(const ExprHandle& other) const { + return Rshift::make(*this, other); +} + +#define IMM_EXPR_DECLARE(Type, Name) \ + ExprHandle::ExprHandle(Type v) : ExprHandle(Name##Imm::make(v)) {} +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE); +#undef IMM_EXPR_DECLARE + +ExprHandle sin(const ExprHandle& v) { + return Intrinsics::make(kSin, v); +} + +ExprHandle cos(const ExprHandle& v) { + return Intrinsics::make(kCos, v); +} + +ExprHandle tan(const ExprHandle& v) { + return Intrinsics::make(kTan, v); +} + +ExprHandle asin(const ExprHandle& v) { + return Intrinsics::make(kAsin, v); +} + +ExprHandle acos(const ExprHandle& v) { + return Intrinsics::make(kAcos, v); +} + +ExprHandle atan(const ExprHandle& v) { + return Intrinsics::make(kAtan, v); +} + +ExprHandle sinh(const ExprHandle& v) { + return Intrinsics::make(kSinh, v); +} + +ExprHandle cosh(const ExprHandle& v) { + return Intrinsics::make(kCosh, v); +} + +ExprHandle tanh(const ExprHandle& v) { + return Intrinsics::make(kTanh, v); +} + +ExprHandle exp(const ExprHandle& v) { + return Intrinsics::make(kExp, v); +} + +ExprHandle expm1(const ExprHandle& v) { + return Intrinsics::make(kExpm1, v); +} + +ExprHandle fabs(const ExprHandle& v) { + return Intrinsics::make(kFabs, v); +} + +ExprHandle log(const ExprHandle& v) { + return Intrinsics::make(kLog, v); +} + +ExprHandle log2(const ExprHandle& v) { + return Intrinsics::make(kLog2, v); +} + +ExprHandle log10(const ExprHandle& v) { + return Intrinsics::make(kLog10, v); +} + +ExprHandle log1p(const ExprHandle& v) { + return Intrinsics::make(kLog1p, v); +} + +ExprHandle erf(const ExprHandle& v) { + return Intrinsics::make(kErf, v); +} + +ExprHandle erfc(const ExprHandle& v) { + return Intrinsics::make(kErfc, v); +} + +ExprHandle sqrt(const ExprHandle& v) { + return Intrinsics::make(kSqrt, v); +} + +ExprHandle rsqrt(const ExprHandle& v) { + return Intrinsics::make(kRsqrt, v); +} + +ExprHandle ceil(const ExprHandle& v) { + return Intrinsics::make(kCeil, v); +} + +ExprHandle floor(const ExprHandle& v) { + return Intrinsics::make(kFloor, v); +} + +ExprHandle round(const ExprHandle& v) { + return Intrinsics::make(kRound, v); +} + +ExprHandle trunc(const ExprHandle& v) { + return Intrinsics::make(kTrunc, v); +} + +ExprHandle frac(const ExprHandle& v) { + return Intrinsics::make(kFrac, v); +} + +ExprHandle lgamma(const ExprHandle& v) { + return Intrinsics::make(kLgamma, v); +} + +ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2) { + return Intrinsics::make(kAtan2, v1, v2); +} + +ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2) { + return Intrinsics::make(kPow, v1, v2); +} + +ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2) { + return Intrinsics::make(kFmod, v1, v2); +} + +ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2) { + return Intrinsics::make(kRemainder, v1, v2); +} + +ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f) { + return IfThenElse::make(c, t, f); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace 
torch diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h new file mode 100644 index 0000000000000..d14ee7bebd9b6 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/expr.h @@ -0,0 +1,203 @@ +/** + * This file implements the core classes for Tensor Expressions. + * + * The structure of the expressions is inspired by Halide/TVM IR. + */ +#pragma once + +#include "torch/csrc/jit/tensorexpr/ir_mutator.h" +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" +#include "torch/csrc/jit/tensorexpr/types.h" +#include "torch/csrc/jit/tensorexpr/mem_arena.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +// The common base between all expression node. +class Expr : public KernelScopedObject { + public: + explicit Expr(Dtype dtype) : dtype_(dtype) {} + Dtype dtype() const { + return dtype_; + } + TORCH_API virtual void accept(IRVisitor* visitor) const = 0; + virtual const Expr* accept_mutator(IRMutator* mutator) const = 0; + + private: + Dtype dtype_; +}; + +// A CRTP pattern to accept visitors for children class, +// and dispatch back to the children. +template +class ExprNode : public Base { + public: + using ExprNodeBase = ExprNode; + void accept(IRVisitor* visitor) const override { + visitor->visit(static_cast(this)); + } + const Expr* accept_mutator(IRMutator* mutator) const override; + // pass the constructor to the base class + using Base::Base; +}; + +// A wrapper object to the underlying ExprNode. +// Also serves the primary way to build and operate on other expressions. +class TORCH_API ExprHandle { + public: + ExprHandle() {} + explicit ExprHandle(const Expr* node) + : base_expr_node_(const_cast(node)) {} + + Expr* node() { + return base_expr_node_; + } + + const Expr* node() const { + return base_expr_node_; + } + + bool empty() const { + return base_expr_node_ == nullptr; + } + +#define IMM_EXPR_DECLARE(Type, Name) \ + ExprHandle(Type v); +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_EXPR_DECLARE); +#undef IMM_EXPR_DECLARE + + template + Op* AsNode() { + return dynamic_cast(this->node()); + } + + template + const Op* AsNode() const { + return const_cast(this)->AsNode(); + } + + Dtype dtype() const { + return node()->dtype(); + } + + // Handling the math operators. + ExprHandle operator+(const ExprHandle& other) const; + ExprHandle operator-(const ExprHandle& other) const; + ExprHandle operator*(const ExprHandle& other) const; + ExprHandle operator/(const ExprHandle& other) const; + ExprHandle operator%(const ExprHandle& other) const; + ExprHandle operator==(const ExprHandle& other) const; + ExprHandle operator!=(const ExprHandle& other) const; + ExprHandle operator>(const ExprHandle& other) const; + ExprHandle operator>=(const ExprHandle& other) const; + ExprHandle operator<(const ExprHandle& other) const; + ExprHandle operator<=(const ExprHandle& other) const; + ExprHandle operator&(const ExprHandle& other) const; + ExprHandle operator^(const ExprHandle& other) const; + ExprHandle operator<<(const ExprHandle& other) const; + ExprHandle operator>>(const ExprHandle& other) const; + + private: + Expr* base_expr_node_ = nullptr; +}; + +// The underlying representation node to a Var. +// Currently, each Var object represents a unique variable, even though the +// names might be the same. We should consider add a unique_name as well. 
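// For example, two separate calls to Var::make("x", kInt) produce distinct nodes,
// so their VarHandles compare unequal even though both carry the hint "x"; the
// printers rely on UniqueNameManager to give each a distinct emitted name.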
+class Var : public ExprNode { + public: + static ExprHandle make(const std::string& name_hint, Dtype dtype) { + return ExprHandle(new Var(name_hint, dtype)); + } + static ExprHandle make(Dtype dtype) { + return ExprHandle(new Var("", dtype)); + } + + // TODO: unique_name + const std::string& name_hint() const { + return name_hint_; + } + + Var(const std::string& name_hint, Dtype dtype) + : ExprNodeBase(dtype), name_hint_(name_hint) {} + + private: + std::string name_hint_; +}; + +// An expression to construct the underlying variable node. +// Note: do not store any info here, since it is often possible to slice this +// object. For example: VarHandle x('x'); ExprHandle x2 = x; +class VarHandle : public ExprHandle { + public: + VarHandle() : ExprHandle(nullptr) {} + explicit VarHandle(Dtype dtype) : ExprHandle(Var::make(dtype)) {} + VarHandle(const std::string& name_hint, Dtype dtype) + : ExprHandle(Var::make(name_hint, dtype)) {} + explicit VarHandle(const Var* node) : ExprHandle(node) {} + const Var* node() const { + return static_cast(ExprHandle::node()); + } + bool operator==(const VarHandle& other) const { + return this->node() == other.node(); + } + bool operator!=(const VarHandle& other) const { + return !(*this == other); + } + + const std::string& name_hint() const { + return this->node()->name_hint(); + } + bool empty() const { + return (this->node() == nullptr); + } +}; + +template +const Expr* ExprNode::accept_mutator(IRMutator* mutator) const { + ExprNode* this_mutable = const_cast(this); + return mutator->mutate(static_cast(this_mutable)); +} + +inline bool same_node(const ExprHandle& expr1, const ExprHandle& expr2) { + return expr1.AsNode() == expr2.AsNode(); +} + +TORCH_API ExprHandle sin(const ExprHandle& v); +TORCH_API ExprHandle cos(const ExprHandle& v); +TORCH_API ExprHandle tan(const ExprHandle& v); +TORCH_API ExprHandle asin(const ExprHandle& v); +TORCH_API ExprHandle acos(const ExprHandle& v); +TORCH_API ExprHandle atan(const ExprHandle& v); +TORCH_API ExprHandle sinh(const ExprHandle& v); +TORCH_API ExprHandle cosh(const ExprHandle& v); +TORCH_API ExprHandle tanh(const ExprHandle& v); +TORCH_API ExprHandle exp(const ExprHandle& v); +TORCH_API ExprHandle expm1(const ExprHandle& v); +TORCH_API ExprHandle fabs(const ExprHandle& v); +TORCH_API ExprHandle log(const ExprHandle& v); +TORCH_API ExprHandle log2(const ExprHandle& v); +TORCH_API ExprHandle log10(const ExprHandle& v); +TORCH_API ExprHandle log1p(const ExprHandle& v); +TORCH_API ExprHandle erf(const ExprHandle& v); +TORCH_API ExprHandle erfc(const ExprHandle& v); +TORCH_API ExprHandle sqrt(const ExprHandle& v); +TORCH_API ExprHandle rsqrt(const ExprHandle& v); +TORCH_API ExprHandle ceil(const ExprHandle& v); +TORCH_API ExprHandle floor(const ExprHandle& v); +TORCH_API ExprHandle round(const ExprHandle& v); +TORCH_API ExprHandle trunc(const ExprHandle& v); +TORCH_API ExprHandle frac(const ExprHandle& v); +TORCH_API ExprHandle lgamma(const ExprHandle& v); +TORCH_API ExprHandle atan2(const ExprHandle& v1, const ExprHandle& v2); +TORCH_API ExprHandle pow(const ExprHandle& v1, const ExprHandle& v2); +TORCH_API ExprHandle fmod(const ExprHandle& v1, const ExprHandle& v2); +TORCH_API ExprHandle remainder(const ExprHandle& v1, const ExprHandle& v2); + +TORCH_API ExprHandle ifThenElse(const ExprHandle& c, const ExprHandle& t, const ExprHandle& f); + + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/function.cpp b/torch/csrc/jit/tensorexpr/function.cpp new file 
mode 100644 index 0000000000000..6c4b7c10891fc --- /dev/null +++ b/torch/csrc/jit/tensorexpr/function.cpp @@ -0,0 +1,152 @@ +#include "torch/csrc/jit/tensorexpr/function.h" + +#include +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +namespace { + +static void unpack_dim_args( + const std::vector& dim_args, + std::vector* dims, + std::vector* vars) { + dims->clear(); + vars->clear(); + for (size_t i = 0; i < dim_args.size(); i++) { + dims->push_back(dim_args[i].dim().node()); + vars->push_back(new Var(dim_args[i].name_hint(), kInt)); + } +} + +} // namespace + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function&)> body_func) { + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func(VarVectorToVarHandleVector(args)).node(); + Function* func = new Function( + func_name, std::move(dims), std::move(args), std::move(body)); + return new Tensor(func, 0); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function body_func) { + CHECK_EQ(dim_args.size(), 1ULL); + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func(VarHandle(args[0])).node(); + Function* func = new Function( + func_name, std::move(dims), std::move(args), std::move(body)); + return new Tensor(func, 0); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function body_func) { + CHECK_EQ(dim_args.size(), 2ULL); + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = body_func(VarHandle(args[0]), VarHandle(args[1])).node(); + Function* func = new Function( + func_name, std::move(dims), std::move(args), std::move(body)); + return new Tensor(func, 0); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function< + ExprHandle(const VarHandle&, const VarHandle&, const VarHandle&)> + body_func) { + CHECK_EQ(dim_args.size(), 3ULL); + std::vector dims; + std::vector args; + unpack_dim_args(dim_args, &dims, &args); + const Expr* body = + body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2])) + .node(); + Function* func = new Function( + func_name, std::move(dims), std::move(args), std::move(body)); + return new Tensor(func, 0); +} + +Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function body_func) { + CHECK_EQ(dim_args.size(), 4ULL); + std::vector dims; + std::vector args_nodes; + unpack_dim_args(dim_args, &dims, &args_nodes); + auto args = VarVectorToVarHandleVector(args_nodes); + const Expr* body = body_func(args[0], args[1], args[2], args[3]).node(); + Function* func = new Function( + func_name, std::move(dims), std::move(args_nodes), std::move(body)); + return new Tensor(func, 0); +} + +Stmt* Function::ElementStmt(size_t index) { + std::vector strides(dims_.size()); + auto* ce = dynamic_cast(body(index)); + if (ce != nullptr) { + std::vector input_vars; + std::vector input_args; + for (auto p : ce->params()) { + auto fc = dynamic_cast(p); + if (fc) { + input_vars.emplace_back(fc->tensor()->function()->func_var(index)); + } else { + input_args.emplace_back(p); + } + } + return OpaqueCall::make( + ce->name(), func_var(index), input_vars, input_args); + } + for (size_t i = 0; i < strides.size(); i++) { + if (i == strides.size() - 1) { + strides[i] = ExprHandle(1); + continue; + } + ExprHandle 
stride = ExprHandle(dims_[i + 1]); + for (size_t j = i + 2; j < dims_.size(); j++) { + stride = stride * ExprHandle(dims_[j]); + } + strides[i] = stride; + } + + ExprHandle total_index; + for (size_t i = 0; i < dims_.size(); i++) { + ExprHandle index = VarHandle(this->args_[i]) * ExprHandle(strides[i]); + if (i == 0) { + total_index = index; + } else { + total_index = total_index + index; + } + } + + const Expr* mask = new IntImm(1); + + Stmt* update_stmt = + new Store(func_var(index), total_index.node(), body(index), mask); + return update_stmt; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/function.h b/torch/csrc/jit/tensorexpr/function.h new file mode 100644 index 0000000000000..a5b87d471384d --- /dev/null +++ b/torch/csrc/jit/tensorexpr/function.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include + +#include "torch/csrc/jit/tensorexpr/expr.h" +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +// represent a range [start, stop) +class Range { + public: + Range() {} + Range(const ExprHandle& start, const ExprHandle& stop) : start_(start), stop_(stop) {} + const ExprHandle& start() const { + return start_; + } + const ExprHandle& stop() const { + return stop_; + } + + private: + ExprHandle start_; + ExprHandle stop_; +}; + +class Function : public KernelScopedObject { + public: + Function( + const std::string& func_name, + const std::vector& dims, + const std::vector& args, + const Expr* body) + : func_vars_({VarHandle(func_name, kHandle).node()}), dims_(dims), args_(args), bodies_({body}) {} + Function( + const std::vector& func_names, + const std::vector& dims, + const std::vector& args, + const std::vector& bodies) + : func_vars_(func_names.size()), dims_(dims), args_(args), bodies_(bodies) { + for (size_t i = 0; i < func_names.size(); i++) { + func_vars_[i] = new Var(func_names[i], kHandle); + } + } + + int ndim() const { + return dims_.size(); + } + const Expr* dim(int index) const { + CHECK_GE(index, 0) << "index out of lower bound"; + CHECK_LT(index, ndim()) << "index out of upper bound"; + return dims_[index]; + } + const std::vector& dims() const { + return dims_; + } + const Var* arg(int index) const { + CHECK_GE(index, 0) << "index out of lower bound"; + CHECK_LT(index, ndim()) << "index out of upper bound"; + return args_[index]; + } + const std::vector& args() const { + return args_; + } + + std::vector bodies() const { + return bodies_; + } + const Expr* body(size_t index) const { + CHECK(index < bodies_.size()); + return bodies_[index]; + } + + std::vector func_vars() const { + return func_vars_; + } + const Var* func_var(size_t index) const { + CHECK(index < func_vars_.size()); + return func_vars_[index]; + } + + Stmt* ElementStmt(size_t index); + + private: + std::vector func_vars_; + std::vector dims_; + std::vector args_; + std::vector bodies_; +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir.cpp b/torch/csrc/jit/tensorexpr/ir.cpp new file mode 100644 index 0000000000000..085d1496fff28 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir.cpp @@ -0,0 +1,149 @@ +#include "torch/csrc/jit/tensorexpr/ir.h" + +#include "torch/csrc/jit/tensorexpr/buffer.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +static Dtype ChooseDtype(const Dtype& buffer_dtype, const Dtype& index_dtype) { + return Dtype(buffer_dtype, index_dtype.lanes()); +} + +Load::Load(const Buffer& buffer, 
const Expr* index, const Expr* mask) + : Load( + ChooseDtype(buffer.dtype(), index->dtype()), + buffer.data(), + index, + mask) {} + +Load::Load( + Dtype dtype, + const Var* base_handle, + const Expr* index, + const Expr* mask) + : ExprNodeBase(dtype), + base_handle_(base_handle), + index_(index), + mask_(mask) { + CHECK_EQ(base_handle_->dtype(), kHandle); + CHECK_EQ(index->dtype().lanes(), mask->dtype().lanes()); + CHECK_EQ(index->dtype().scalar_type(), ScalarType::Int); +} + +Store::Store( + const Buffer& buffer, + const Expr* index, + const Expr* value, + const Expr* mask) + : Store(buffer.data(), index, value, mask) { + CHECK_EQ(buffer.dtype().scalar_type(), value->dtype().scalar_type()); + CHECK_EQ(buffer.dtype().scalar_type(), value->dtype().scalar_type()); +} + +Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1) { + // TODO: check the op_type and make a real decision + return dt1; +} + +Dtype Intrinsics::IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1, Dtype dt2) { + // TODO: check the op_type and make a real decision + return dt1; +} + +Dtype Intrinsics::IntrinsicsDtype( + IntrinsicsOp op_type, + const std::vector& params) { + // TODO: check the op_type an dmake a real decision + CHECK_GE(params.size(), 1ULL); + return params[0]->dtype(); +} + +Dtype CallExternal::CallExternalDtype( + std::string name, + const std::vector& params) { + // TODO: check the op_type an dmake a real decision + CHECK_GE(params.size(), 1ULL); + return params[0]->dtype(); +} + +int Intrinsics::OpArgCount(IntrinsicsOp op_type) { + switch (op_type) { + case kSin: + case kCos: + case kTan: + case kAsin: + case kAcos: + case kAtan: + case kSinh: + case kCosh: + case kTanh: + case kExp: + case kExpm1: + case kFabs: + case kLog: + case kLog2: + case kLog10: + case kLog1p: + case kErf: + case kErfc: + case kSqrt: + case kRsqrt: + case kCeil: + case kFloor: + case kRound: + case kTrunc: + case kFrac: + case kLgamma: + return 1; + case kRand: + return 0; + case kAtan2: + case kFmod: + case kPow: + case kRemainder: + return 2; + default: + throw std::runtime_error("invalid op_type: " + std::to_string(op_type)); + } +} + +std::vector ExprHandleVectorToExprVector( + const std::vector& v) { + std::vector result(v.size()); + for (size_t i = 0; i < v.size(); i++) { + result[i] = v[i].node(); + } + return result; +} + +std::vector ExprVectorToExprHandleVector( + const std::vector& v) { + std::vector result(v.size()); + for (size_t i = 0; i < v.size(); i++) { + result[i] = ExprHandle(v[i]); + } + return result; +} + +std::vector VarHandleVectorToVarVector( + const std::vector& v) { + std::vector result(v.size()); + for (size_t i = 0; i < v.size(); i++) { + result[i] = v[i].node(); + } + return result; +} + +std::vector VarVectorToVarHandleVector( + const std::vector& v) { + std::vector result(v.size()); + for (size_t i = 0; i < v.size(); i++) { + result[i] = VarHandle(v[i]); + } + return result; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h new file mode 100644 index 0000000000000..9868df606b5af --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir.h @@ -0,0 +1,775 @@ +#pragma once + +#include +#include + +#include "torch/csrc/jit/tensorexpr/expr.h" +#include "torch/csrc/jit/tensorexpr/stmt.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +enum IRNodeType { + kAdd, + kSub, + kMul, + kDiv, + kMod, + kMax, + kMin, + kAnd, + kLshift, + kRshift, + kXor, + kCompareSelect, +}; + +enum 
CompareSelectOperation { + kEQ, + kGT, + kGE, + kLT, + kLE, + kNE, +}; + +class Buffer; + +class Cast : public ExprNode { + public: + const Expr* src_value() const { + return src_value_; + } + static ExprHandle make(Dtype dtype, const ExprHandle& src_value) { + return ExprHandle(new Cast(dtype, src_value.node())); + } + Cast(Dtype dtype, const Expr* src_value) + : ExprNodeBase(dtype), src_value_(src_value) {} + + private: + const Expr* src_value_; +}; + +template +ExprHandle cast(const ExprHandle& src_value) { + return Cast::make(Dtype(ToDtype(), src_value.dtype().lanes()), src_value); +} + +// Represent the expression node for binary operators. +// A CRTP pattern to share common code among the operators. +template +class BinaryOpNode : public ExprNode { + public: + const Expr* lhs() const { + return this->lhs_; + } + const Expr* rhs() const { + return this->rhs_; + } + IRNodeType expr_type() const { + return expr_type_; + } + + static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) { + return ExprHandle(new Op(lhs.node(), rhs.node())); + } + + BinaryOpNode( + const Expr* lhs_v, + const Expr* rhs_v, + IRNodeType expr_type, + ScalarType ret_type = ScalarType::None) + : ExprNode(BinaryOpDtype(lhs_v->dtype(), rhs_v->dtype(), ret_type)), + lhs_(CastIfNeeded(lhs_v, ExprNode::dtype())), + rhs_(CastIfNeeded(rhs_v, ExprNode::dtype())), + expr_type_(expr_type) {} + + private: + static const Expr* CastIfNeeded(const Expr* expr, Dtype dst_dtype) { + if (expr->dtype() == dst_dtype) { + return expr; + } + return Cast::make(dst_dtype, ExprHandle(expr)).node(); + } + + const Expr* lhs_; + const Expr* rhs_; + IRNodeType expr_type_; +}; + +class Add : public BinaryOpNode { + public: + Add(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kAdd) {} +}; + +class Sub : public BinaryOpNode { + public: + Sub(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kSub) {} +}; + +class Mul : public BinaryOpNode { + public: + Mul(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kMul) {} +}; + +class Div : public BinaryOpNode
{ + public: + Div(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kDiv) {} +}; + +class Mod : public BinaryOpNode { + public: + Mod(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kMod) {} +}; + +class And : public BinaryOpNode { + public: + And(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kAnd) { + CHECK_EQ(lhs->dtype().scalar_type(), ScalarType::Int); + CHECK_EQ(lhs->dtype(), rhs->dtype()); + } +}; + +class Xor : public BinaryOpNode { + public: + Xor(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kXor) { + CHECK_EQ(lhs->dtype().scalar_type(), ScalarType::Int); + CHECK_EQ(lhs->dtype(), rhs->dtype()); + } +}; + +class Lshift : public BinaryOpNode { + public: + Lshift(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kLshift) { + CHECK_EQ(lhs->dtype().scalar_type(), ScalarType::Int); + CHECK_EQ(lhs->dtype(), rhs->dtype()); + } +}; + +class Rshift : public BinaryOpNode { + public: + Rshift(const Expr* lhs, const Expr* rhs) + : BinaryOpNode(lhs, rhs, IRNodeType::kRshift) { + CHECK_EQ(lhs->dtype().scalar_type(), ScalarType::Int); + CHECK_EQ(lhs->dtype(), rhs->dtype()); + } +}; + +class Max : public BinaryOpNode { + private: + bool propagate_nans_; + + public: + Max(const Expr* lhs, const Expr* rhs, bool propagate_nans) + : BinaryOpNode(lhs, rhs, IRNodeType::kMax), + propagate_nans_(propagate_nans) {} + + bool propagate_nans() const { + return propagate_nans_; + } + + static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete; + static ExprHandle make( + const ExprHandle& lhs, + const ExprHandle& rhs, + bool propagate_nans) { + return ExprHandle(new Max(lhs.node(), rhs.node(), propagate_nans)); + } +}; + +class Min : public BinaryOpNode { + private: + bool propagate_nans_; + + public: + Min(const Expr* lhs, const Expr* rhs, bool propagate_nans) + : BinaryOpNode(lhs, rhs, IRNodeType::kMin), + propagate_nans_(propagate_nans) {} + + bool propagate_nans() const { + return propagate_nans_; + } + + static ExprHandle make(const ExprHandle& lhs, const ExprHandle& rhs) = delete; + static ExprHandle make( + const ExprHandle& lhs, + const ExprHandle& rhs, + bool propagate_nans) { + return ExprHandle(new Min(lhs.node(), rhs.node(), propagate_nans)); + } +}; + +// Encode typed immediate values e.g. IntImm, FloatImm. +#define IMM_DECLARE(Type, Name) \ + class Name##Imm : public ExprNode { \ + public: \ + Name##Imm(Type value) : ExprNodeBase(k##Name), value_(value) {} \ + Type value() const { \ + return value_; \ + } \ + static ExprHandle make(Type value) { \ + return ExprHandle(new Name##Imm(value)); \ + } \ + \ + private: \ + Type value_; \ + }; +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); +#undef IMM_DECLARE + +// Bind the value to the var and evaluate the body. +class Let : public ExprNode { + public: + const Expr* var() const { + return var_; + } + const Expr* value() const { + return value_; + } + const Expr* body() const { + return body_; + } + + static ExprHandle make( + const ExprHandle& var, + const ExprHandle& value, + const ExprHandle& body) { + return ExprHandle(new Let(var.node(), value.node(), body.node())); + } + + Let(const Expr* var, const Expr* value, const Expr* body) + : ExprNodeBase(body->dtype()), var_(var), value_(value), body_(body) {} + + private: + const Expr* var_; + const Expr* value_; + const Expr* body_; +}; + +// Represents a ramp vector node: +// [base, base + 1 * stride, ... 
, base + (lanes - 1) * stride] +class Ramp : public ExprNode { + public: + const Expr* base() const { + return base_; + } + const Expr* stride() const { + return stride_; + } + static ExprHandle make( + const ExprHandle& base, + const ExprHandle& stride, + int lanes) { + return ExprHandle(new Ramp(base.node(), stride.node(), lanes)); + } + int lanes() const { + return lanes_; + } + + Ramp(const Expr* base, const Expr* stride, int lanes) + : ExprNodeBase(Dtype(base->dtype(), lanes)), + base_(base), + stride_(stride), + lanes_(lanes) { + CHECK_EQ(stride->dtype(), base->dtype()); + } + + private: + const Expr* base_; + const Expr* stride_; + int lanes_; +}; + +class TORCH_API Load : public ExprNode { + public: + const Var* base_handle() const { + return base_handle_; + } + const Expr* index() const { + return index_; + } + const Expr* mask() const { + return mask_; + } + static ExprHandle make( + const Buffer& buffer, + const ExprHandle& index, + const ExprHandle& mask) { + return ExprHandle(new Load(buffer, index.node(), mask.node())); + } + static ExprHandle make( + Dtype dtype, + const VarHandle& base_handle, + const ExprHandle& index, + const ExprHandle& mask) { + return ExprHandle( + new Load(dtype, base_handle.node(), index.node(), mask.node())); + } + + Load(const Buffer& buffer, const Expr* index, const Expr* mask); + Load( + Dtype dtype, + const Var* base_handle, + const Expr* index, + const Expr* mask); + + private: + const Var* base_handle_; + const Expr* index_; + const Expr* mask_; +}; + +class TORCH_API OpaqueCall : public StmtNode { + public: + const std::string name() const { + return name_; + } + + const Var* output_handle() const { + return output_handle_; + } + + const std::vector& input_handles() const { + return input_handles_; + } + + const std::vector& arguments() const { + return arguments_; + } + + static Stmt* make( + const std::string& name, + const Var* output_handle, + const std::vector& input_handles, + const std::vector& arguments) { + return new OpaqueCall(name, output_handle, input_handles, arguments); + } + + private: + OpaqueCall( + const std::string& name, + const Var* output_handle, + const std::vector& input_handles, + const std::vector& arguments) + : name_(name), + output_handle_(output_handle), + input_handles_(input_handles), + arguments_(arguments) {} + + std::string name_; + const Var* output_handle_; + std::vector input_handles_; + std::vector arguments_; +}; + +class Broadcast : public ExprNode { + public: + const Expr* value() const { + return value_; + } + int lanes() const { + return lanes_; + } + static ExprHandle make(const ExprHandle& value, int lanes) { + return ExprHandle(new Broadcast(value.node(), lanes)); + } + Broadcast(const Expr* value, int lanes) + : ExprNodeBase(Dtype(value->dtype(), lanes)), + value_(value), + lanes_(lanes) {} + + private: + const Expr* value_; + int lanes_; +}; + +class IfThenElse : public ExprNode { + public: + const Expr* condition() const { + return condition_; + } + + // Lazily evaluated only if condition is true + const Expr* true_value() const { + return true_; + } + + // Lazily evaluated only if condition is false + const Expr* false_value() const { + return false_; + } + + static ExprHandle make( + const ExprHandle& c, + const ExprHandle& t, + const ExprHandle& f) { + return ExprHandle(new IfThenElse(c.node(), t.node(), f.node())); + } + + IfThenElse(const Expr* c, const Expr* t, const Expr* f) + : ExprNodeBase(t->dtype()), condition_(c), true_(t), false_(f) { + CHECK_EQ(c->dtype().scalar_type(), 
ScalarType::Int); + CHECK_EQ(c->dtype().lanes(), 1); + CHECK_EQ(t->dtype(), f->dtype()); + } + + private: + const Expr* condition_; + const Expr* true_; + const Expr* false_; +}; + +class BaseCallNode : public Expr { + public: + enum CallType { + kIntrinsics, + kCallExternal, + kFunctionCall, + }; + + int nparams() const { + return params_.size(); + } + + const Expr* param(int index) const { + return params_[index]; + } + const std::vector& params() const { + return params_; + } + + virtual std::string func_name() const = 0; + + CallType call_type() const { + return call_type_; + } + + protected: + BaseCallNode( + Dtype dtype, + CallType call_type, + const std::vector& params) + : Expr(dtype), call_type_(call_type), params_(params) {} + + private: + // The handler for the default ir_mutator to make a copy of this node with new + // params. + virtual const Expr* DefaultMutator( + const std::vector& new_params) const = 0; + + template + friend class ExprNode; + friend class IRMutator; + + CallType call_type_; + std::vector params_; +}; + +template +class CallNode : public ExprNode { + public: + using BaseClass = ExprNode; + using BaseClass::BaseClass; +}; + +class TORCH_API CompareSelect : public ExprNode { + public: + CompareSelectOperation compare_select_op() const { + return compare_op_; + } + const Expr* lhs() const { + return this->lhs_; + } + const Expr* rhs() const { + return this->rhs_; + } + const Expr* ret_val1() const { + return this->ret_val1_; + } + const Expr* ret_val2() const { + return this->ret_val2_; + } + + static ExprHandle make( + const ExprHandle& lhs, + const ExprHandle& rhs, + CompareSelectOperation cmp_op) { + CHECK_EQ(lhs.dtype(), rhs.dtype()); + return ExprHandle(new CompareSelect( + lhs.node(), + rhs.node(), + IntImm::make(1).node(), + IntImm::make(0).node(), + cmp_op)); + } + + static ExprHandle make( + const ExprHandle& lhs, + const ExprHandle& rhs, + const ExprHandle& ret_val1, + const ExprHandle& ret_val2, + CompareSelectOperation cmp_op) { + CHECK_EQ(lhs.dtype(), rhs.dtype()); + CHECK_EQ(ret_val1.dtype(), ret_val2.dtype()); + return ExprHandle(new CompareSelect( + lhs.node(), rhs.node(), ret_val1.node(), ret_val2.node(), cmp_op)); + } + + private: + const Expr* lhs_; + const Expr* rhs_; + const Expr* ret_val1_; + const Expr* ret_val2_; + CompareSelectOperation compare_op_; + CompareSelect( + const Expr* lhs, + const Expr* rhs, + const Expr* ret_val1, + const Expr* ret_val2, + CompareSelectOperation cmp_op) + : ExprNodeBase(ToDtype()), + lhs_(lhs), + rhs_(rhs), + ret_val1_(ret_val1), + ret_val2_(ret_val2), + compare_op_(cmp_op) {} +}; + +enum IntrinsicsOp { + kSin, + kCos, + kTan, + kAsin, + kAcos, + kAtan, + kAtan2, + kSinh, + kCosh, + kTanh, + kExp, + kExpm1, + kFabs, + kLog, + kLog2, + kLog10, + kLog1p, + kErf, + kErfc, + kSqrt, + kRsqrt, + kPow, + kCeil, + kFloor, + kRound, + kTrunc, + kFmod, + kRemainder, + kLgamma, + kFrac, + kRand, // We need more discussions on this. Should we consider stateful? 
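  // kRand is the only zero-argument intrinsic (see Intrinsics::OpArgCount);
  // every other op listed here takes one or two arguments.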
+}; + +class CallExternal : public CallNode { + public: + static const Expr* make( + std::string name, + const std::vector& params) { + return new CallExternal(name, params); + } + std::string func_name() const override { + return name_; + } + inline std::string name() const { + return name_; + } + const Expr* DefaultMutator( + const std::vector& new_params) const override { + return CallExternal::make(name_, new_params); + } + + private: + using BaseClass = CallNode; + CallExternal(std::string name, const std::vector& params) + : BaseClass(CallExternalDtype(name, params), kCallExternal, params), + name_(name), + params_(params) {} + TORCH_API static Dtype CallExternalDtype( + std::string name, + const std::vector& params); + + std::string name_; + const std::vector& params_; +}; + +class Intrinsics : public CallNode { + public: + static ExprHandle make(IntrinsicsOp op_type, const ExprHandle& v1) { + return ExprHandle(new Intrinsics(op_type, v1.node())); + } + + static ExprHandle make( + IntrinsicsOp op_type, + const ExprHandle& v1, + const ExprHandle& v2) { + return ExprHandle(new Intrinsics(op_type, v1.node(), v2.node())); + } + + static ExprHandle make( + IntrinsicsOp op_type, + const std::vector& params) { + std::vector params_nodes(params.size()); + for (size_t i = 0; i < params.size(); i++) { + params_nodes[i] = params[i].node(); + } + return ExprHandle(new Intrinsics(op_type, params_nodes)); + } + + static ExprHandle make(IntrinsicsOp op_type, Dtype dtype) { + return ExprHandle(new Intrinsics(op_type, dtype)); + } + + IntrinsicsOp op_type() const { + return op_type_; + } + + std::string func_name() const override { + switch (op_type()) { + case kSin: + return "sin"; + case kCos: + return "cos"; + case kTan: + return "tan"; + case kAsin: + return "asin"; + case kAcos: + return "acos"; + case kAtan: + return "atan"; + case kAtan2: + return "atan2"; + case kSinh: + return "sinh"; + case kCosh: + return "cosh"; + case kTanh: + return "tanh"; + case kExp: + return "exp"; + case kFabs: + return "fabs"; + case kLog: + return "log"; + case kLog2: + return "log2"; + case kLog10: + return "log10"; + case kLog1p: + return "log1p"; + case kErf: + return "erf"; + case kSqrt: + return "sqrt"; + case kRsqrt: + return "rsqrt"; + case kPow: + return "pow"; + case kCeil: + return "ceil"; + case kFloor: + return "floor"; + case kRound: + return "round"; + case kTrunc: + return "trunc"; + case kRand: + return "rand"; + case kFmod: + return "fmod"; + case kRemainder: + return "remainder"; + case kLgamma: + return "lgamma"; + case kExpm1: + return "expm1"; + case kErfc: + return "erfc"; + case kFrac: + return "frac"; + default: + throw std::runtime_error( + "invalid op_type: " + std::to_string(op_type())); + } + } + using BaseClass = CallNode; + + Intrinsics(IntrinsicsOp op_type, Dtype dtype) + : BaseClass(IntrinsicsDtype(op_type, dtype), kIntrinsics, {}), + op_type_(op_type) { + CHECK_EQ(OpArgCount(op_type), 0); + } + + Intrinsics(IntrinsicsOp op_type, const Expr* v1) + : BaseClass(IntrinsicsDtype(op_type, v1->dtype()), kIntrinsics, {v1}), + op_type_(op_type) { + CHECK_EQ(OpArgCount(op_type), 1); + } + + Intrinsics(IntrinsicsOp op_type, const Expr* v1, const Expr* v2) + : BaseClass( + IntrinsicsDtype(op_type, v1->dtype(), v2->dtype()), + kIntrinsics, + {v1, v2}), + op_type_(op_type) { + CHECK_EQ(OpArgCount(op_type), 2); + } + + Intrinsics(IntrinsicsOp op_type, const std::vector& params) + : BaseClass(IntrinsicsDtype(op_type, params), kIntrinsics, params), + op_type_(op_type) { + 
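    // The variadic form must still agree with the fixed arity table: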
CHECK_EQ(OpArgCount(op_type), nparams()); + } + + private: + TORCH_API static int OpArgCount(IntrinsicsOp op_type); + + const Expr* DefaultMutator( + const std::vector& new_params) const override { + return new Intrinsics(this->op_type(), new_params); + } + + TORCH_API static Dtype IntrinsicsDtype(IntrinsicsOp op_type, Dtype dt1); + TORCH_API static Dtype IntrinsicsDtype( + IntrinsicsOp op_type, + Dtype dt1, + Dtype dt2); + TORCH_API static Dtype IntrinsicsDtype( + IntrinsicsOp op_type, + const std::vector& params); + + IntrinsicsOp op_type_; +}; + +class FunctionCall; + +TORCH_API std::vector ExprHandleVectorToExprVector( + const std::vector&); +TORCH_API std::vector ExprVectorToExprHandleVector( + const std::vector&); +TORCH_API std::vector VarHandleVectorToVarVector( + const std::vector&); +TORCH_API std::vector VarVectorToVarHandleVector( + const std::vector&); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.cpp b/torch/csrc/jit/tensorexpr/ir_mutator.cpp new file mode 100644 index 0000000000000..d8768fd762994 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_mutator.cpp @@ -0,0 +1,370 @@ +#include "torch/csrc/jit/tensorexpr/ir_mutator.h" + +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +static const Expr* mutate_binary_op( + const BinaryOpNode* v, + IRMutator* mutator, + bool option = false) { + const Expr* lhs = v->lhs(); + const Expr* rhs = v->rhs(); + const Expr* lhs_new = lhs->accept_mutator(mutator); + const Expr* rhs_new = rhs->accept_mutator(mutator); + if (lhs == lhs_new && rhs == rhs_new) { + return v; + } + IRNodeType expr_type = v->expr_type(); + switch (expr_type) { + case IRNodeType::kAdd: + return new Add(lhs_new, rhs_new); + case IRNodeType::kSub: + return new Sub(lhs_new, rhs_new); + case IRNodeType::kMul: + return new Mul(lhs_new, rhs_new); + case IRNodeType::kDiv: + return new Div(lhs_new, rhs_new); + case IRNodeType::kMod: + return new Mod(lhs_new, rhs_new); + case IRNodeType::kMax: + return new Max(lhs_new, rhs_new, option); + case IRNodeType::kMin: + return new Min(lhs_new, rhs_new, option); + case IRNodeType::kAnd: + return new And(lhs_new, rhs_new); + case IRNodeType::kXor: + return new Xor(lhs_new, rhs_new); + case IRNodeType::kLshift: + return new Lshift(lhs_new, rhs_new); + case IRNodeType::kRshift: + return new Rshift(lhs_new, rhs_new); + default: + LOG(FATAL) << "unsupported expr_type: " << static_cast(expr_type); + return nullptr; + } +} + +const Expr* IRMutator::mutate(const Add* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Sub* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Mul* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Div* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Mod* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const And* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Xor* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Lshift* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Rshift* v) { + return mutate_binary_op(v, this); +} + +const Expr* IRMutator::mutate(const Max* v) { + return mutate_binary_op(v, this, v->propagate_nans()); +} + +const Expr* IRMutator::mutate(const Min* v) 
{ + return mutate_binary_op(v, this, v->propagate_nans()); +} + +const Expr* IRMutator::mutate(const CompareSelect* v) { + const Expr* lhs = v->lhs(); + const Expr* rhs = v->rhs(); + const Expr* retval1 = v->ret_val1(); + const Expr* retval2 = v->ret_val2(); + const Expr* lhs_new = lhs->accept_mutator(this); + const Expr* rhs_new = rhs->accept_mutator(this); + const Expr* retval1_new = retval1->accept_mutator(this); + const Expr* retval2_new = retval2->accept_mutator(this); + if (lhs == lhs_new && rhs == rhs_new && retval1 == retval1_new && + retval2 == retval2_new) { + return v; + } + return CompareSelect::make( + ExprHandle(lhs_new), + ExprHandle(rhs_new), + ExprHandle(retval1_new), + ExprHandle(retval2_new), + v->compare_select_op()) + .node(); +} + +#define IMM_MUTATE_DEFINE(_1, Name) \ + const Expr* IRMutator::mutate(const Name##Imm* v) { \ + return v; \ + } +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DEFINE); +#undef IMM_MUTATE_DEFINE + +const Expr* IRMutator::mutate(const Cast* v) { + const Expr* src_value = v->src_value(); + const Expr* src_value_new = src_value->accept_mutator(this); + if (src_value_new == v->src_value()) { + return v; + } + return new Cast(v->dtype(), src_value_new); +} + +const Expr* IRMutator::mutate(const Var* v) { + return v; +} + +const Expr* IRMutator::mutate(const Let* v) { + const Expr* var = v->var(); + const Expr* value = v->value(); + const Expr* body = v->body(); + const Expr* var_new = var->accept_mutator(this); + const Expr* value_new = value->accept_mutator(this); + const Expr* body_new = body->accept_mutator(this); + if ((var == var_new) && (value == value_new) && (body == body_new)) { + return v; + } + return new Let(var_new, value_new, body_new); +} + +Stmt* IRMutator::mutate(const LetStmt* v) { + const Var* var = v->var(); + const Expr* value = v->value(); + Stmt* body = v->body(); + const Var* var_new = dynamic_cast(var->accept_mutator(this)); + if (var_new == nullptr) { + throw std::runtime_error("LetStmt var must be variable"); + } + const Expr* value_new = value->accept_mutator(this); + Stmt* body_new = body->accept_mutator(this); + if ((var == var_new) && (value == value_new) && (body == body_new)) { + return (Stmt*)v; + } + return new LetStmt(var_new, value_new, body_new); +} + +const Expr* IRMutator::mutate(const Ramp* v) { + const Expr* base = v->base(); + const Expr* stride = v->stride(); + const Expr* base_new = base->accept_mutator(this); + const Expr* stride_new = stride->accept_mutator(this); + if (base == base_new && stride == stride_new) { + return v; + } + return new Ramp(base_new, stride_new, v->lanes()); +} + +const Expr* IRMutator::mutate(const Load* v) { + Dtype dtype = v->dtype(); + const Var* base_handle = v->base_handle(); + const Expr* index = v->index(); + const Expr* mask = v->mask(); + const Expr* base_handle_expr = base_handle->accept_mutator(this); + const Var* base_handle_new = dynamic_cast(base_handle_expr); + const Expr* index_new = index->accept_mutator(this); + const Expr* mask_new = mask->accept_mutator(this); + if (base_handle == base_handle_new && index == index_new && + mask == mask_new) { + return v; + } + return new Load(dtype, base_handle_new, index_new, mask_new); +} + +const Expr* IRMutator::mutate(const Broadcast* v) { + const Expr* value = v->value(); + int lanes = v->lanes(); + const Expr* value_new = value->accept_mutator(this); + if (value == value_new) { + return v; + } + return new Broadcast(value_new, lanes); +} + +const Expr* IRMutator::mutate(const IfThenElse* v) { + const Expr* 
condition = v->condition(); + const Expr* true_value = v->true_value(); + const Expr* false_value = v->false_value(); + const Expr* condition_new = condition->accept_mutator(this); + const Expr* true_value_new = true_value->accept_mutator(this); + const Expr* false_value_new = false_value->accept_mutator(this); + if (condition == condition_new && true_value == true_value_new && + false_value == false_value_new) { + return v; + } + + return new IfThenElse(condition_new, true_value_new, false_value_new); +} + +const Expr* IRMutator::mutate(const Intrinsics* v) { + const BaseCallNode* base = v; + return this->mutate(base); +} + +const Expr* IRMutator::mutate(const FunctionCall* v) { + const BaseCallNode* base = v; + return this->mutate(base); +} + +const Expr* IRMutator::mutate(const BaseCallNode* v) { + std::vector params(v->nparams()); + bool any_change = false; + for (int i = 0; i < v->nparams(); i++) { + const Expr* value = v->param(i); + const Expr* value_new = value->accept_mutator(this); + if (value != value_new) { + any_change = true; + } + params[i] = std::move(value_new); + } + if (!any_change) { + return v; + } + return v->DefaultMutator(params); +} + +Stmt* IRMutator::mutate(const For* v) { + const Expr* var = v->var(); + const Expr* start = v->start(); + const Expr* stop = v->stop(); + Stmt* body = v->body(); + LoopOptions loop_options = v->loop_options(); + const Expr* var_new_expr = var->accept_mutator(this); + const Var* var_new = dynamic_cast(var_new_expr); + const Expr* start_new = start->accept_mutator(this); + const Expr* stop_new = stop->accept_mutator(this); + Stmt* body_new = body->accept_mutator(this); + if (!body_new) { + return nullptr; + } + if (var == var_new && start == start_new && stop == stop_new && + body == body_new) { + return (Stmt*)v; + } + return new For(var_new, start_new, stop_new, body_new, loop_options); +} + +Stmt* IRMutator::mutate(const Block* v) { + bool any_change = false; + std::vector stmts; + for (int i = 0; i < v->nstmts(); i++) { + Stmt* stmt = v->stmt(i); + Stmt* stmt_new = stmt->accept_mutator(this); + if (stmt != stmt_new) { + any_change = true; + } + if (stmt_new) { + stmts.push_back(stmt_new); + } + } + if (!any_change) { + return (Stmt*)v; + } + return Block::make(stmts); +} + +Stmt* IRMutator::mutate(const Store* v) { + const Var* base_handle = v->base_handle(); + const Expr* index = v->index(); + const Expr* value = v->value(); + const Expr* mask = v->mask(); + const Expr* base_handle_expr = base_handle->accept_mutator(this); + const Var* base_handle_new = dynamic_cast(base_handle_expr); + const Expr* index_new = index->accept_mutator(this); + const Expr* value_new = value->accept_mutator(this); + const Expr* mask_new = mask->accept_mutator(this); + if (base_handle == base_handle_new && index == index_new && + value == value_new && mask == mask_new) { + return (Stmt*)v; + } + return new Store(base_handle_new, index_new, value_new, mask_new); +} + +Stmt* IRMutator::mutate(const OpaqueCall* v) { + const Var* output_handle = v->output_handle(); + std::vector input_handles = v->input_handles(); + std::vector arguments = v->arguments(); + const Var* output_handle_new = + dynamic_cast(output_handle->accept_mutator(this)); + std::vector input_handles_new; + for (auto ih : input_handles) { + input_handles_new.emplace_back( + dynamic_cast(ih->accept_mutator(this))); + } + std::vector arguments_new; + for (auto a : arguments) { + arguments_new.emplace_back(a->accept_mutator(this)); + } + // TODO: if same_node checks + return OpaqueCall::make( 
+ v->name(), output_handle_new, input_handles_new, arguments_new); +} + +Stmt* IRMutator::mutate(const Allocate* v) { + const Var* buffer_var_old = v->buffer_var(); + const Var* buffer_var_new = + dynamic_cast(buffer_var_old->accept_mutator(this)); + bool any_change = buffer_var_new == buffer_var_old; + std::vector dims_old = v->dims(); + std::vector dims_new(dims_old.size()); + for (size_t i = 0; i < dims_old.size(); i++) { + dims_new[i] = dims_old[i]->accept_mutator(this); + any_change |= (dims_new[i] == dims_old[i]); + } + + if (!any_change) { + return (Stmt*)v; + } + + return new Allocate(buffer_var_new, v->dtype(), dims_new); +} + +Stmt* IRMutator::mutate(const Free* v) { + const Expr* buffer_var_old = v->buffer_var(); + const Var* buffer_var_new = + dynamic_cast(buffer_var_old->accept_mutator(this)); + if (buffer_var_new == buffer_var_old) { + return (Stmt*)v; + } + + return new Free(buffer_var_new); +} + +Stmt* IRMutator::mutate(const Cond* v) { + const Expr* cond_old = v->condition(); + Stmt* true_old = v->true_stmt(); + Stmt* false_old = v->false_stmt(); + + const Expr* cond_new = cond_old->accept_mutator(this); + Stmt* true_new = true_old ? true_old->accept_mutator(this) : true_old; + Stmt* false_new = false_old ? false_old->accept_mutator(this) : false_old; + + if (cond_old == cond_new && true_old == true_new && false_old == false_new) { + return (Stmt*)v; + } + return new Cond(cond_new, true_new, false_new); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_mutator.h b/torch/csrc/jit/tensorexpr/ir_mutator.h new file mode 100644 index 0000000000000..623e89437b970 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_mutator.h @@ -0,0 +1,99 @@ +#pragma once +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +class Add; +class Sub; +class Mul; +class Div; +class Mod; +class Max; +class Min; +class And; +class Xor; +class Lshift; +class Rshift; +class CompareSelect; + +#define IMM_DECLARE(Type, Name) class Name##Imm; +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE); +#undef IMM_DECLARE + +class Cast; +class Var; +class Let; +class LetStmt; +class Ramp; +class Load; +class For; +class Block; +class Store; +class OpaqueCall; +class Broadcast; +class IfThenElse; +class ExprHandle; +class Expr; +class BaseCallNode; +class Intrinsics; +class CallExternal; +class FunctionCall; +class Allocate; +class Free; +class Cond; +class Stmt; + +class TORCH_API IRMutator { + public: + virtual ~IRMutator() {} + virtual const Expr* mutate(const Add* v); + virtual const Expr* mutate(const Sub* v); + virtual const Expr* mutate(const Mul* v); + virtual const Expr* mutate(const Div* v); + virtual const Expr* mutate(const Mod* v); + virtual const Expr* mutate(const Max* v); + virtual const Expr* mutate(const Min* v); + virtual const Expr* mutate(const And* v); + virtual const Expr* mutate(const Xor* v); + virtual const Expr* mutate(const Lshift* v); + virtual const Expr* mutate(const Rshift* v); + virtual const Expr* mutate(const CompareSelect* v); +#define IMM_MUTATE_DECLARE(Type, Name) \ + virtual const Expr* mutate(const Name##Imm* v); + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_MUTATE_DECLARE); +#undef IMM_MUTATE_DECLARE + virtual const Expr* mutate(const Cast* v); + virtual const Expr* mutate(const Var* v); + virtual const Expr* mutate(const Let* v); + virtual Stmt* mutate(const LetStmt* v); + virtual const Expr* mutate(const Ramp* v); + virtual const Expr* mutate(const Load* v); + virtual const Expr* 
mutate(const Broadcast* v); + virtual const Expr* mutate(const IfThenElse* v); + + // BaseCallNode is the base class for all call nodes. + // For any visitors that only needs the common behavior, only override this + // function is enough. This is because all derived class handlers will call + // this function by default. + // Override the derived class handler only if the logic is more specific to + // that. + virtual const Expr* mutate(const BaseCallNode* v); + virtual const Expr* mutate(const Intrinsics* v); + virtual const Expr* mutate(const FunctionCall* v); + + virtual Stmt* mutate(const For* v); + virtual Stmt* mutate(const Block* v); + virtual Stmt* mutate(const Store* v); + virtual Stmt* mutate(const OpaqueCall* v); + + virtual Stmt* mutate(const Allocate* v); + virtual Stmt* mutate(const Free* v); + virtual Stmt* mutate(const Cond* v); +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_printer.cpp b/torch/csrc/jit/tensorexpr/ir_printer.cpp new file mode 100644 index 0000000000000..34831bdf9f3a7 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_printer.cpp @@ -0,0 +1,360 @@ +#include "torch/csrc/jit/tensorexpr/ir_printer.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +void IRPrinter::print(ExprHandle expr) { + expr.node()->accept(this); +} + +void IRPrinter::print(const Expr& expr) { + expr.accept(this); +} + +void IRPrinter::print(const Stmt& stmt) { + stmt.accept(this); +} + +// TODO: change whether to include the parenthesis to the parent expression, +// we need to look at the operator precedence to make the output simpler. +template +void visitBinaryOp( + const BinaryOpNode* v, + const std::string& op_str, + IRPrinter* printer) { + std::ostream& os = printer->os(); + os << "("; + v->lhs()->accept(printer); + os << " " << op_str << " "; + v->rhs()->accept(printer); + os << ")"; +} + +void IRPrinter::visit(const Add* v) { + visitBinaryOp(v, "+", this); +} + +void IRPrinter::visit(const Sub* v) { + visitBinaryOp(v, "-", this); +} + +void IRPrinter::visit(const Mul* v) { + visitBinaryOp(v, "*", this); +} + +void IRPrinter::visit(const Div* v) { + visitBinaryOp(v, "/", this); +} + +void IRPrinter::visit(const And* v) { + visitBinaryOp(v, "&", this); +} + +void IRPrinter::visit(const Xor* v) { + visitBinaryOp(v, "^", this); +} + +void IRPrinter::visit(const Lshift* v) { + visitBinaryOp(v, "<<", this); +} + +void IRPrinter::visit(const Rshift* v) { + visitBinaryOp(v, ">>", this); +} + +void IRPrinter::visit(const Mod* v) { + if (v->dtype().is_integral()) { + visitBinaryOp(v, "%", this); + } else if (v->dtype().is_floating_point()) { + os() << "mod(" << v->lhs() << ", " << v->rhs() << ")"; + } else { + throw std::runtime_error("invalid dtype: " + std::to_string(v->dtype())); + } +} + +void IRPrinter::visit(const Max* v) { + os() << "Max("; + v->lhs()->accept(this); + os() << ", "; + v->rhs()->accept(this); + os() << ", " << (unsigned int)v->propagate_nans() << ")"; +} + +void IRPrinter::visit(const Min* v) { + os() << "Min("; + v->lhs()->accept(this); + os() << ", "; + v->rhs()->accept(this); + os() << ", " << (unsigned int)v->propagate_nans() << ")"; +} + +void IRPrinter::visit(const CompareSelect* v) { + CompareSelectOperation cmp_op = v->compare_select_op(); + os() << "("; + v->lhs()->accept(this); + switch (cmp_op) { + case CompareSelectOperation::kEQ: + os() << "=="; + break; + case CompareSelectOperation::kNE: + os() << "!="; + break; + case CompareSelectOperation::kGT: + os() << ">"; + break; + 
case CompareSelectOperation::kGE: + os() << ">="; + break; + case CompareSelectOperation::kLT: + os() << "<"; + break; + case CompareSelectOperation::kLE: + os() << "<="; + break; + default: + throw std::runtime_error("invalid compare select operator"); + } + v->rhs()->accept(this); + os() << ")"; +} + +#define IMM_PRINT_VISIT(Type, Name) \ + void IRPrinter::visit(const Name##Imm* v) { \ + if (v->dtype().is_floating_point()) { \ + std::ostringstream oss; \ + oss << v->value(); \ + std::string s = oss.str(); \ + if (s.find('.') == std::string::npos) { \ + s += ".f"; \ + } else { \ + s += "f"; \ + } \ + os() << s; \ + } else { \ + os() << v->value(); \ + } \ + } +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT); +#undef IMM_PRINT_VISIT + +void IRPrinter::visit(const Cast* v) { + auto dtype = v->dtype(); + os() << dtype << "("; + v->src_value()->accept(this); + os() << ")"; +} + +void IRPrinter::visit(const Var* v) { + os() << name_manager_.get_unique_name(v); +} + +void IRPrinter::visit(const Let* v) { + os() << "(let "; + v->var()->accept(this); + os() << " = "; + v->value()->accept(this); + os() << " in "; + v->body()->accept(this); + os() << ")"; +} + +void IRPrinter::visit(const LetStmt* v) { + const Var* var = v->var(); + os() << var->dtype().ToCppString() << " " << *var << " = " << *v->value() + << "; " << std::endl; + v->body()->accept(this); +} + +void IRPrinter::visit(const Ramp* v) { + emitIndent(); + os() << "Ramp(" << v->base() << ", " << v->stride() << ", " << v->lanes() + << ")"; +} + +void IRPrinter::visit(const Load* v) { + // TODO: support the mask case + os() << *v->base_handle() << "[" << *v->index() << "]"; +} + +void IRPrinter::visit(const For* v) { + const Var* var = v->var(); + VarHandle vv(var); + emitIndent(); + os() << "for (" << var->dtype().ToCppString() << " " << vv << " = " + << ExprHandle(v->start()) << "; " << vv << " < " << ExprHandle(v->stop()) + << "; " << vv << "++) {"; + std::string loop_options_str = v->loop_options().ToString(); + if (!loop_options_str.empty()) { + os() << " // " << loop_options_str; + } + os() << std::endl; + if (v->body()) { + indent_++; + os() << *v->body() << std::endl; + indent_--; + } + emitIndent(); + os() << "}"; +} + +void IRPrinter::visit(const Block* v) { + for (int i = 0; i < v->nstmts(); ++i) { + os() << *v->stmt(i) << std::endl; + } +} + +void IRPrinter::visit(const Store* v) { + // TODO: handle the mask + emitIndent(); + os() << *v->base_handle() << "[" << *v->index() << "] = " << *v->value() + << ";"; +} + +void IRPrinter::visit(const OpaqueCall* v) { + os() << *v->output_handle() << " = " << v->name() << "("; + for (auto& ih : v->input_handles()) { + os() << *ih; + if (&ih != &v->input_handles().back()) { + os() << ", "; + } + } + os() << ")"; +} + +void IRPrinter::visit(const Broadcast* v) { + os() << "Broadcast(" << v->value() << ", " << v->lanes() << ")"; +} + +void IRPrinter::visit(const IfThenElse* v) { + os() << "IfThenElse(" << *v->condition() << ", " << *v->true_value() << ", " + << *v->false_value() << ")"; +} + +void IRPrinter::visit(const BaseCallNode* v) { + os() << v->func_name() << "("; + for (int i = 0; i < v->nparams(); i++) { + if (i > 0) { + os() << ", "; + } + os() << *v->param(i); + } + os() << ")"; +} + +void IRPrinter::visit(const CallExternal* v) { + os() << v->name() << "("; + for (auto p : v->params()) { + os() << p; + if (&p != &v->params().back()) { + os() << ", "; + } + } + os() << ")"; +} + +void IRPrinter::visit(const Allocate* v) { + emitIndent(); + os() << "Allocate(" << 
*v->buffer_var() << ", " << v->dtype(); + os() << ", {"; + const std::vector& dims = v->dims(); + for (size_t i = 0; i < dims.size(); i++) { + if (i != 0) { + os() << ", "; + } + os() << *dims[i]; + } + os() << "});"; +} + +void IRPrinter::visit(const Free* v) { + emitIndent(); + os() << "Free(" << *v->buffer_var() << ");"; +} + +void IRPrinter::visit(const Cond* v) { + const Expr* cond = v->condition(); + Stmt* true_stmt = v->true_stmt(); + Stmt* false_stmt = v->false_stmt(); + if (!true_stmt) { + emitIndent(); + os() << "if (!" << *cond << ") {" << std::endl; + indent_++; + os() << *false_stmt << std::endl; + indent_--; + emitIndent(); + os() << "}"; + } else { + emitIndent(); + os() << "if (" << *cond << ") {" << std::endl; + indent_++; + os() << *true_stmt << std::endl; + indent_--; + emitIndent(); + os() << "}"; + if (false_stmt) { + os() << " else {" << std::endl; + indent_++; + os() << *false_stmt << std::endl; + indent_--; + emitIndent(); + os() << "}"; + } + } +} + +void IRPrinter::emitIndent() { + os() << std::setw(2 * indent_) << ""; +} + +std::ostream& operator<<(std::ostream& stream, const ExprHandle& expr) { + IRPrinter::PrinterStream* printer_stream = + dynamic_cast(&stream); + if (printer_stream != nullptr) { + expr.node()->accept(printer_stream->printer()); + } else { + IRPrinter p(stream); + p.print(expr); + } + return stream; +} + +std::ostream& operator<<(std::ostream& stream, const Expr& expr) { + IRPrinter::PrinterStream* printer_stream = + dynamic_cast(&stream); + if (printer_stream != nullptr) { + expr.accept(printer_stream->printer()); + } else { + IRPrinter p(stream); + p.print(expr); + } + return stream; +} + +std::ostream& operator<<(std::ostream& stream, const Stmt& stmt) { + IRPrinter::PrinterStream* printer_stream = + dynamic_cast(&stream); + if (printer_stream != nullptr) { + stmt.accept(printer_stream->printer()); + } else { + IRPrinter p(stream); + p.print(stmt); + } + return stream; +} + +std::ostream& operator<<(std::ostream& stream, Stmt* stmt) { + IRPrinter::PrinterStream* printer_stream = + dynamic_cast(&stream); + if (printer_stream != nullptr) { + stmt->accept(printer_stream->printer()); + } else { + IRPrinter p(stream); + p.print(*stmt); + } + return stream; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_printer.h b/torch/csrc/jit/tensorexpr/ir_printer.h new file mode 100644 index 0000000000000..6e04c0e9d56cb --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_printer.h @@ -0,0 +1,108 @@ +#pragma once + +#include + +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" +#include "torch/csrc/jit/tensorexpr/unique_name_manager.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +class TORCH_API IRPrinter : public IRVisitor { + public: + explicit IRPrinter(std::ostream& os) : printer_os_(this, os) {} + + void print(ExprHandle); + void print(const Expr&); + void print(const Stmt&); + void visit(const Add* v) override; + void visit(const Sub* v) override; + void visit(const Mul* v) override; + void visit(const Div* v) override; + void visit(const Mod* v) override; + void visit(const Max* v) override; + void visit(const Min* v) override; + void visit(const And* v) override; + void visit(const Xor* v) override; + void visit(const Lshift* v) override; + void visit(const Rshift* v) override; + void visit(const CompareSelect* v) override; +#define IMM_PRINT_VISIT(Type, Name) void visit(const Name##Imm* v) override; + 
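  // Declares a visit override for each scalar immediate type (Bool and Half
  // included), mirroring the IMM_DECLARE list in ir.h.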
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT); +#undef IMM_PRINT_VISIT + void visit(const Cast* v) override; + void visit(const Var* v) override; + void visit(const Let* v) override; + void visit(const LetStmt* v) override; + void visit(const Ramp* v) override; + void visit(const Load* v) override; + void visit(const For* v) override; + void visit(const Block* v) override; + void visit(const Store* v) override; + void visit(const OpaqueCall* v) override; + void visit(const Broadcast* v) override; + void visit(const IfThenElse* v) override; + void visit(const BaseCallNode* v) override; + void visit(const CallExternal* v) override; + void visit(const Allocate* v) override; + void visit(const Free* v) override; + void visit(const Cond* v) override; + + std::ostream& os() { + return printer_os_; + } + + class PrinterStream : public std::ostream { + public: + PrinterStream(IRPrinter* printer, std::ostream& os) + : std::ostream(os.rdbuf()), printer_(printer) {} + + IRPrinter* printer() { + return printer_; + } + + private: + IRPrinter* printer_ = nullptr; + }; + + protected: + UniqueNameManager* name_manager() { + return &name_manager_; + } + + private: + void emitIndent(); + int indent_ = 0; + PrinterStream printer_os_; + UniqueNameManager name_manager_; +}; + +TORCH_API std::ostream& operator<<(std::ostream& stream, const Expr&); +TORCH_API std::ostream& operator<<(std::ostream& stream, const ExprHandle&); +TORCH_API std::ostream& operator<<(std::ostream& stream, const Stmt&); +TORCH_API std::ostream& operator<<(std::ostream& stream, Stmt*); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +namespace std { + +using torch::jit::tensorexpr::ExprHandle; +using torch::jit::tensorexpr::Stmt; + +inline std::string to_string(const ExprHandle& expr) { + std::ostringstream oss; + oss << expr; + return oss.str(); +} + +inline std::string to_string(Stmt* stmt) { + std::ostringstream oss; + oss << stmt; + return oss.str(); +} + +}; // namespace std diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.cpp b/torch/csrc/jit/tensorexpr/ir_visitor.cpp new file mode 100644 index 0000000000000..27910a23fee3d --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_visitor.cpp @@ -0,0 +1,191 @@ +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" + +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +template +static void visit_binary_op(const BinaryOpNode* v, IRVisitor* visitor) { + v->lhs()->accept(visitor); + v->rhs()->accept(visitor); +} + +void IRVisitor::visit(const Add* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Sub* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Mul* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Div* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Mod* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Max* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Min* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const And* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Xor* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Lshift* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const Rshift* v) { + visit_binary_op(v, this); +} + +void IRVisitor::visit(const CompareSelect* v) { + v->lhs()->accept(this); + v->rhs()->accept(this); + v->ret_val1()->accept(this); + v->ret_val2()->accept(this); +} + 
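// Illustrative sketch (not part of this diff): because the default visits
// above recurse into every child expression, a custom traversal only needs
// to override the nodes it cares about. A hypothetical visitor that counts
// Add nodes reachable from an expression could look like this:
class AddCounter : public IRVisitor {
 public:
  void visit(const Add* v) override {
    count_++;
    IRVisitor::visit(v);  // keep recursing into lhs/rhs
  }
  int count() const {
    return count_;
  }

 private:
  int count_ = 0;
};
// Usage sketch: AddCounter counter; expr->accept(&counter); counter.count()
// then holds the number of Add nodes in 'expr'.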
+#define IMM_VISIT(Type, Name) \ + void IRVisitor::visit(const Name##Imm* v) {} +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT); +#undef IMM_VISIT + +void IRVisitor::visit(const Cast* v) { + v->src_value()->accept(this); +} +void IRVisitor::visit(const Var* v) {} +void IRVisitor::visit(const Let* v) { + v->var()->accept(this); + v->value()->accept(this); + v->body()->accept(this); +} + +void IRVisitor::visit(const LetStmt* v) { + v->var()->accept(this); + v->value()->accept(this); + v->body()->accept(this); +} + +void IRVisitor::visit(const Ramp* v) { + v->base()->accept(this); + v->stride()->accept(this); +} + +void IRVisitor::visit(const Load* v) { + v->base_handle()->accept(this); + v->index()->accept(this); + v->mask()->accept(this); +} + +void IRVisitor::visit(const Store* v) { + v->base_handle()->accept(this); + v->index()->accept(this); + v->value()->accept(this); + v->mask()->accept(this); +} + +void IRVisitor::visit(const OpaqueCall* v) { + v->output_handle()->accept(this); + for (auto& ih : v->input_handles()) { + ih->accept(this); + } + for (auto& a : v->arguments()) { + a->accept(this); + } +} + +void IRVisitor::visit(const Block* v) { + for (int i = 0; i < v->nstmts(); i++) { + v->stmt(i)->accept(this); + } +} + +void IRVisitor::visit(const For* v) { + v->var()->accept(this); + v->start()->accept(this); + v->stop()->accept(this); + if (v->body()) { + v->body()->accept(this); + } +} + +void IRVisitor::visit(const Broadcast* v) { + v->value()->accept(this); +} + +void IRVisitor::visit(const IfThenElse* v) { + v->condition()->accept(this); + v->true_value()->accept(this); + v->false_value()->accept(this); +} + +void IRVisitor::visit(const BaseCallNode* v) { + for (int i = 0; i < v->nparams(); i++) { + v->param(i)->accept(this); + } +} + +void IRVisitor::visit(const Intrinsics* v) { + const BaseCallNode* base = v; + this->visit(base); +} + +void IRVisitor::visit(const CallExternal* v) { + const BaseCallNode* base = v; + this->visit(base); +} + +void IRVisitor::visit(const FunctionCall* v) { + const BaseCallNode* base = v; + this->visit(base); +} + +void IRVisitor::visit(const Allocate* v) { + const Var* buffer_var = v->buffer_var(); + buffer_var->accept(this); + std::vector dims = v->dims(); + for (const Expr* dim : dims) { + dim->accept(this); + } +} + +void IRVisitor::visit(const Free* v) { + const Var* buffer_var = v->buffer_var(); + buffer_var->accept(this); +} + +void IRVisitor::visit(const Cond* v) { + const Expr* condition = v->condition(); + Stmt* true_stmt = v->true_stmt(); + Stmt* false_stmt = v->false_stmt(); + condition->accept(this); + if (true_stmt) { + true_stmt->accept(this); + } + if (false_stmt) { + false_stmt->accept(this); + } +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/ir_visitor.h b/torch/csrc/jit/tensorexpr/ir_visitor.h new file mode 100644 index 0000000000000..523ae95f35156 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/ir_visitor.h @@ -0,0 +1,98 @@ +#pragma once +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +class Add; +class Sub; +class Mul; +class Div; +class Mod; +class Max; +class Min; +class And; +class Xor; +class Lshift; +class Rshift; +class CompareSelect; + +#define IMM_DECLARE(Type, Name) class Name##Imm; + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_DECLARE) +#undef IMM_DECLARE + +class Cast; +class Var; +class Let; +class LetStmt; +class Ramp; +class Load; +class For; +class Block; +class Store; +class OpaqueCall; +class Broadcast; 
+class IfThenElse; +class BaseCallNode; +class Intrinsics; +class CallExternal; +class FunctionCall; +class Allocate; +class Free; +class Cond; + +class TORCH_API IRVisitor { + public: + virtual ~IRVisitor() {} + virtual void visit(const Add* v); + virtual void visit(const Sub* v); + virtual void visit(const Mul* v); + virtual void visit(const Div* v); + virtual void visit(const Mod* v); + virtual void visit(const Max* v); + virtual void visit(const Min* v); + virtual void visit(const And* v); + virtual void visit(const Xor* v); + virtual void visit(const Lshift* v); + virtual void visit(const Rshift* v); + virtual void visit(const CompareSelect* v); + +#define IMM_PRINT_VISIT(Type, Name) virtual void visit(const Name##Imm* v); + + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_PRINT_VISIT) +#undef IMM_PRINT_VISIT + + virtual void visit(const Cast* v); + virtual void visit(const Var* v); + virtual void visit(const Let* v); + virtual void visit(const LetStmt* v); + virtual void visit(const Ramp* v); + virtual void visit(const Load* v); + virtual void visit(const For* v); + virtual void visit(const Block* v); + virtual void visit(const Store* v); + virtual void visit(const OpaqueCall* v); + virtual void visit(const Broadcast* v); + virtual void visit(const IfThenElse* v); + + // BaseCallNode is the base class for all call nodes. + // For any visitors that only needs the common behavior, only override this + // function is enough. This is because all derived class handlers will call + // this function by default. + // Override the derived class handler only if the logic is more specific to + // that. + virtual void visit(const BaseCallNode* v); + virtual void visit(const Intrinsics* v); + virtual void visit(const CallExternal* v); + virtual void visit(const FunctionCall* v); + virtual void visit(const Allocate* v); + virtual void visit(const Free* v); + virtual void visit(const Cond* v); +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp new file mode 100644 index 0000000000000..6e1d1eb1855f6 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -0,0 +1,1309 @@ +#include +#include +#include +#include "torch/csrc/jit/tensorexpr/native.h" + +using namespace torch::jit; +using namespace torch::jit::tensorexpr; + +namespace torch { +namespace jit { +namespace tensorexpr { + +static int te_cuda_pointwise_loop_levels = -1; +static int te_cuda_pointwise_block_count = -1; +static int te_cuda_pointwise_block_size = -1; + +int& GetTECudaPointwiseLoopLevels() { + return te_cuda_pointwise_loop_levels; +} + +int& GetTECudaPointwiseBlockCount() { + return te_cuda_pointwise_block_count; +} + +int& GetTECudaPointwiseBlockSize() { + return te_cuda_pointwise_block_size; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +static at::ScalarType tensorType(Tensor* t) { + return static_cast(t->body()->dtype().scalar_type()); +} + +static std::vector texprSizes(const c10::VaryingShape& shape) { + std::vector dims; + for (size_t i = 0; i < *shape.size(); i++) { + dims.push_back(IntImm::make(*shape[i])); + } + return dims; +} + +namespace torch { +namespace jit { +namespace tensorexpr { + +std::vector texprDims(const torch::jit::Value* v) { + CHECK(v->type()->kind() == TypeKind::TensorType); + auto tt = v->type()->cast(); + std::vector dimArgs; + int i = 0; + for (auto const& s : texprSizes(tt->sizes())) { + dimArgs.push_back({s, "i" + std::to_string(i++)}); + } + return 
dimArgs; +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +template +int64_t bufferSize(T t) { + int64_t size = 1; + for (int i = 0; i < t.ndim(); i++) { + size *= t.dim(i).template AsNode()->value(); + } + return size; +} + +ExprHandle TensorExprKernel::constant(const torch::jit::Value* v) { + if (v->node()->kind() == prim::Constant) { + const auto val = toIValue(v).value(); + if (val.isDouble()) { + return FloatImm::make(val.toDouble()); + } else if (val.isInt()) { + return IntImm::make(val.toInt()); + } else if (val.isNone()) { + // This is just a placeholder so we don't throw. None-handling + // is operator-specific and should be handled properly in + // the operator-specific lowering code. + return IntImm::make(0); + } else { + LOG(FATAL) << "Unhandled constant datatype"; + } + } + CHECK(scalars_.count(v->unique())) << "Couldn't find scalar value"; + return scalars_.at(v->unique()); +} + +void TensorExprKernel::promoteInputs(std::vector& inputs) { + if (inputs.empty()) { + return; + } + + // Find the highest type among the inputs. + ScalarType highType = inputs[0].dtype().scalar_type(); + for (int i = 0; i < inputs.size(); ++i) { + ScalarType iType = inputs[i].dtype().scalar_type(); + if (iType == ScalarType::Bool) { + continue; + } + highType = promoteTypes(highType, iType); + } + + for (ExprHandle& e : inputs) { + if (e.dtype().scalar_type() == ScalarType::Bool) { + continue; + } + + if (e.dtype().scalar_type() == highType) { + continue; + } + + switch (highType) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + e = cast(e); \ + break; + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unsupported datatype: " << highType; + } + } +} + +ExprHandle TensorExprKernel::demoteOutput( + const ExprHandle& e, + const torch::jit::Value* v) { + CHECK(v->type()->kind() == TypeKind::TensorType); + auto tt = *v->type()->cast()->scalarType(); + + if (tt == static_cast(e.dtype().scalar_type())) { + return e; + } + + switch (tt) { +#define TYPE_CASE(Type, Name) \ + case at::ScalarType::Name: \ + return cast(e); + AT_FORALL_SCALAR_TYPES_AND(Half, TYPE_CASE); +#undef TYPE_CASE + case at::ScalarType::Bool: + return e; + default: + LOG(FATAL) << "Unsupported datatype"; + } + + return e; +} + +static bool isOne(ExprHandle e) { + auto const& n = e.AsNode(); + if (!n) { + return false; + } + return n->value() == 1; +} + +static std::vector broadcastShapes( + const std::vector& a, + const std::vector& b) { + auto at = a.rbegin(); + auto bt = b.rbegin(); + std::vector ret; + while (at != a.rend() || bt != b.rend()) { + if (at == a.rend()) { + ret.push_back(*bt++); + continue; + } + if (bt == b.rend()) { + ret.push_back(*at++); + continue; + } + // TODO: if neither *at nor *bt is 1, ensure they are identical + // expressions. Nb: `==` doesn't work since that simply produces a new + // ExprHandle. + ExprHandle dim = isOne(*at) ? *bt : *at; + ret.push_back(dim); + at++; + bt++; + } + std::reverse(ret.begin(), ret.end()); + return ret; +} + +template +static std::vector broadcastShapes( + const std::vector& a, + const std::vector& b, + Args... 
args) { + return broadcastShapes(broadcastShapes(a, b), args...); +} + +std::vector TensorExprKernel::valueShape( + const torch::jit::Value* v) { + auto it = tensors_.find(v->unique()); + if (it == tensors_.end()) { + return {1}; + } + return ExprVectorToExprHandleVector(it->second->dims()); +} + +Tensor* TensorExprKernel::ComputeOneOperand( + const std::string& name, + const torch::jit::Value* v, + std::function inner_expr) { + auto const& n = v->node(); + auto const& shape = valueShape(n->inputs()[0]); + return Compute( + name, + c10::fmap(shape), + [this, v, inner_expr](const std::vector& axes) { + auto const& n = v->node(); + std::vector inputs = { + tensorOrConstant(n->inputs()[0], axes)}; + + promoteInputs(inputs); + ExprHandle compute = inner_expr(inputs[0]); + return demoteOutput(compute, n->output()); + }); +} + +Tensor* TensorExprKernel::ComputeTwoOperand( + const std::string& name, + const torch::jit::Value* v, + std::function + inner_expr) { + auto const& n = v->node(); + auto const& shape = + broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1])); + return Compute( + name, + c10::fmap(shape), + [this, v, inner_expr](const std::vector& axes) { + auto const& n = v->node(); + std::vector inputs = { + tensorOrConstant(n->inputs()[0], axes), + tensorOrConstant(n->inputs()[1], axes), + }; + + promoteInputs(inputs); + ExprHandle compute = inner_expr(inputs[0], inputs[1]); + return demoteOutput(compute, n->output()); + }); +} + +Tensor* TensorExprKernel::ComputeTwoOperandWithAlpha( + const std::string& name, + const torch::jit::Value* v, + std::function + inner_expr) { + auto const& n = v->node(); + auto const& shape = + broadcastShapes(valueShape(n->inputs()[0]), valueShape(n->inputs()[1])); + return Compute( + name, + c10::fmap(shape), + [this, v, inner_expr](const std::vector& axes) { + auto const& n = v->node(); + std::vector inputs = { + tensorOrConstant(n->inputs()[0], axes), + tensorOrConstant(n->inputs()[1], axes), + tensorOrConstant(n->inputs()[2], axes), + }; + + promoteInputs(inputs); + ExprHandle compute = inner_expr(inputs[0], inputs[2] * inputs[1]); + return demoteOutput(compute, n->output()); + }); +} + +Tensor* TensorExprKernel::ComputeConditionWithTwoOperand( + const std::string& name, + const torch::jit::Value* v, + std::function< + ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)> + inner_expr) { + auto const& n = v->node(); + auto const& shape = broadcastShapes( + valueShape(n->inputs()[0]), + valueShape(n->inputs()[1]), + valueShape(n->inputs()[2])); + return Compute( + name, + c10::fmap(shape), + [this, v, inner_expr](const std::vector& axes) { + auto const& n = v->node(); + std::vector inputs = { + tensorOrConstant(n->inputs()[1], axes), + tensorOrConstant(n->inputs()[2], axes), + }; + + promoteInputs(inputs); + // First expr is the condition, which we don't promote + inputs.emplace(inputs.begin(), tensorOrConstant(n->inputs()[0], axes)); + ExprHandle compute = inner_expr(inputs[0], inputs[1], inputs[2]); + return demoteOutput(compute, n->output()); + }); +} + +Tensor* TensorExprKernel::ComputeThreeOperand( + const std::string& name, + const torch::jit::Value* v, + std::function< + ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)> + inner_expr) { + auto const& n = v->node(); + auto const& shape = broadcastShapes( + valueShape(n->inputs()[0]), + valueShape(n->inputs()[1]), + valueShape(n->inputs()[2])); + return Compute( + name, + c10::fmap(shape), + [this, v, inner_expr](const std::vector& axes) { + auto 
const& n = v->node(); + std::vector inputs = { + tensorOrConstant(n->inputs()[0], axes), + tensorOrConstant(n->inputs()[1], axes), + tensorOrConstant(n->inputs()[2], axes), + }; + + promoteInputs(inputs); + ExprHandle compute = inner_expr(inputs[0], inputs[1], inputs[2]); + return demoteOutput(compute, n->output()); + }); +} + +Tensor* TensorExprKernel::ComputeFourOperand( + const std::string& name, + const torch::jit::Value* v, + std::function inner_expr) { + auto const& n = v->node(); + auto const& shape = broadcastShapes( + valueShape(n->inputs()[0]), + valueShape(n->inputs()[1]), + valueShape(n->inputs()[2]), + valueShape(n->inputs()[3])); + return Compute( + name, + c10::fmap(shape), + [this, v, inner_expr](const std::vector& axes) { + auto const& n = v->node(); + std::vector inputs = { + tensorOrConstant(n->inputs()[0], axes), + tensorOrConstant(n->inputs()[1], axes), + tensorOrConstant(n->inputs()[2], axes), + tensorOrConstant(n->inputs()[3], axes), + }; + + promoteInputs(inputs); + ExprHandle compute = + inner_expr(inputs[0], inputs[1], inputs[2], inputs[3]); + return demoteOutput(compute, n->output()); + }); +} + +Tensor* TensorExprKernel::ComputeValue(const torch::jit::Value* v) { + switch (v->node()->kind()) { + case aten::add: { + return ComputeTwoOperandWithAlpha( + "aten_add", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs + rhs; + }); + } break; + + case aten::_cast_Float: { + return ComputeOneOperand("aten_cast_float", v, [](const ExprHandle& a) { + return cast(a); + }); + } break; + + case aten::sub: { + return ComputeTwoOperandWithAlpha( + "aten_sub", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs - rhs; + }); + } break; + + case aten::mul: { + return ComputeTwoOperand( + "aten_mul", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs * rhs; + }); + } break; + + case aten::div: { + return ComputeTwoOperand( + "aten_div", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs / rhs; + }); + } break; + + case aten::__and__: { + return ComputeTwoOperand( + "aten_and", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs & rhs; + }); + } break; + + case aten::__xor__: { + return ComputeTwoOperand( + "aten_xor", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs ^ rhs; + }); + } break; + + case aten::__lshift__: { + return ComputeTwoOperand( + "aten_lshift", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs << rhs; + }); + } break; + + case aten::__rshift__: { + return ComputeTwoOperand( + "aten_rshift", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs >> rhs; + }); + } break; + + case aten::addcmul: { + return ComputeFourOperand( + "aten_addcmul", + v, + [](const ExprHandle& a0, + const ExprHandle& a1, + const ExprHandle& a2, + const ExprHandle& a3) { return a0 + a3 * a1 * a2; }); + } break; + + case aten::eq: { + return ComputeTwoOperand( + "aten_eq", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs == rhs; + }); + } break; + + case aten::ne: { + return ComputeTwoOperand( + "aten_ne", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs != rhs; + }); + } break; + case aten::ge: { + return ComputeTwoOperand( + "aten_ge", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs >= rhs; + }); + } break; + + case aten::gt: { + return ComputeTwoOperand( + "aten_gt", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs > rhs; + }); + } break; + + case aten::le: { + return ComputeTwoOperand( + 
"aten_le", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs <= rhs; + }); + } break; + + case aten::lt: { + return ComputeTwoOperand( + "aten_lt", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs < rhs; + }); + } break; + + case aten::min: { + return ComputeTwoOperand( + "aten_min", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return Min::make(lhs, rhs, false); + }); + } break; + + case aten::max: { + return ComputeTwoOperand( + "aten_max", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return Max::make(lhs, rhs, false); + }); + } break; + + case aten::clamp: { + bool no_min = false; + bool no_max = false; + if (v->node()->input(1)->node()->kind() == prim::Constant) { + const auto val = toIValue(v->node()->input(1)).value(); + if (val.isNone()) { + no_min = true; + } + } + + if (v->node()->input(2)->node()->kind() == prim::Constant) { + const auto val = toIValue(v->node()->input(2)).value(); + if (val.isNone()) { + no_max = true; + } + } + + return ComputeThreeOperand( + "aten_clamp", + v, + [no_min, no_max]( + const ExprHandle& in, + const ExprHandle& min, + const ExprHandle& max) { + if (no_min && no_max) { + return in; + } else if (no_min) { + return Min::make(in, max, false); + } else if (no_max) { + return Max::make(in, min, false); + } else { + return Max::make(Min::make(in, max, false), min, false); + } + }); + } break; + + case aten::sigmoid: { + return ComputeOneOperand("aten_sigmoid", v, [](const ExprHandle& a) { + return ExprHandle(1.0f) / + (ExprHandle(1.0f) + exp(ExprHandle(-0.0f) - a)); + }); + } break; + + case aten::reciprocal: { + return ComputeOneOperand("aten_reciprocal", v, [](const ExprHandle& a) { + return ExprHandle(1.0f) / a; + }); + } break; + + case aten::neg: { + return ComputeOneOperand("aten_neg", v, [](const ExprHandle& a) { + return ExprHandle(-0) - a; + }); + } break; + + case aten::relu: { + return ComputeOneOperand("aten_relu", v, [](const ExprHandle& a) { + return Max::make(a, 0, false); + }); + } break; + + case aten::log: { + return ComputeOneOperand( + "aten_log", v, [](const ExprHandle& a) { return log(a); }); + } break; + + case aten::log10: { + return ComputeOneOperand( + "aten_log10", v, [](const ExprHandle& a) { return log10(a); }); + } break; + + case aten::log2: { + return ComputeOneOperand( + "aten_log2", v, [](const ExprHandle& a) { return log2(a); }); + } break; + + case aten::exp: { + return ComputeOneOperand( + "aten_exp", v, [](const ExprHandle& a) { return exp(a); }); + } break; + + case aten::expm1: { + return ComputeOneOperand( + "aten_expm1", v, [](const ExprHandle& a) { return expm1(a); }); + } break; + + case aten::erf: { + return ComputeOneOperand( + "aten_erf", v, [](const ExprHandle& a) { return erf(a); }); + } break; + + case aten::erfc: { + return ComputeOneOperand( + "aten_erfc", v, [](const ExprHandle& a) { return erfc(a); }); + } break; + + case aten::cos: { + return ComputeOneOperand( + "aten_cos", v, [](const ExprHandle& a) { return cos(a); }); + } break; + + case aten::sin: { + return ComputeOneOperand( + "aten_sin", v, [](const ExprHandle& a) { return sin(a); }); + } break; + + case aten::tan: { + return ComputeOneOperand( + "aten_tan", v, [](const ExprHandle& a) { return tan(a); }); + } break; + + case aten::type_as: { + return ComputeTwoOperand( + "aten_type_as", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return Cast::make(rhs.dtype(), lhs); + }); + } break; + + case aten::rand_like: { + return ComputeOneOperand("aten_rand_like", v, [](const 
ExprHandle& a) { + return Intrinsics::make(IntrinsicsOp::kRand, a.dtype()); + }); + } break; + + case aten::pow: { + return ComputeTwoOperand( + "aten_pow", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + const FloatImm* float_imm = rhs.AsNode(); + if (float_imm) { + float imm = float_imm->value(); + if (imm == 1.0f) { + return lhs; + } else if (imm == 2.0f) { + return lhs * lhs; + } else if (imm == 3.0f) { + return (lhs * lhs) * lhs; + } else if (imm == 4.0f) { + ExprHandle tmp = lhs * lhs; + return tmp * tmp; + } else if (imm == 0.5f) { + return sqrt(lhs); + } else if (imm == 0.0f) { + return ExprHandle(1.0f); + } else if (imm == -0.5f) { + return rsqrt(lhs); + } else if (imm == -1.0f) { + return ExprHandle(1.0f) / lhs; + } else if (imm == -2.0f) { + return ExprHandle(1.0f) / (lhs * lhs); + } + } + + const Cast* float_cast = rhs.AsNode(); + if (float_cast) { + const IntImm* int_imm = + dynamic_cast(float_cast->src_value()); + if (int_imm) { + float imm = int_imm->value(); + if (imm == 1) { + return lhs; + } else if (imm == 2) { + return lhs * lhs; + } else if (imm == 3) { + return (lhs * lhs) * lhs; + } else if (imm == 4) { + ExprHandle tmp = lhs * lhs; + return tmp * tmp; + } else if (imm == 0) { + return ExprHandle(1.0f); + } else if (imm == -1) { + return ExprHandle(1.0f) / lhs; + } else if (imm == -2) { + return ExprHandle(1.0f) / (lhs * lhs); + } + } + } + return pow(lhs, rhs); + }); + } break; + + case aten::fmod: { + return ComputeTwoOperand( + "aten_fmod", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return fmod(lhs, rhs); + }); + } break; + + case aten::lerp: { + return ComputeThreeOperand( + "aten_lerp", + v, + [](const ExprHandle& a, + const ExprHandle& end, + const ExprHandle& weight) { return a + weight * (end - a); }); + } break; + case aten::remainder: { + return ComputeTwoOperand( + "aten_remainder", + v, + [](const ExprHandle& lhs, const ExprHandle& rhs) { + return fmod((rhs + fmod(lhs, rhs)), rhs); + }); + + } break; + + case aten::acos: { + return ComputeOneOperand( + "aten_acos", v, [](const ExprHandle& a) { return acos(a); }); + } break; + + case aten::asin: { + return ComputeOneOperand( + "aten_asin", v, [](const ExprHandle& a) { return asin(a); }); + } break; + + case aten::cosh: { + return ComputeOneOperand( + "aten_cosh", v, [](const ExprHandle& a) { return cosh(a); }); + } break; + + case aten::sinh: { + return ComputeOneOperand( + "aten_sinh", v, [](const ExprHandle& a) { return sinh(a); }); + } break; + + case aten::atan: { + return ComputeOneOperand( + "aten_atan", v, [](const ExprHandle& a) { return atan(a); }); + } break; + + case aten::atan2: { + return ComputeTwoOperand( + "aten_atan2", v, [](const ExprHandle& lhs, const ExprHandle& rhs) { + return atan2(lhs, rhs); + }); + } break; + + case aten::tanh: { + return ComputeOneOperand("aten_tanh", v, [](const ExprHandle& a) { + // return + // (ExprHandle(-.67436811832e-5f)+(ExprHandle(.2468149110712040f)+(ExprHandle(.583691066395175e-1f)+ExprHandle(.3357335044280075e-1f)*a)*a)*a)/(ExprHandle(.2464845986383725f)+(ExprHandle(.609347197060491e-1f)+(ExprHandle(.1086202599228572f)+ExprHandle(.2874707922475963e-1f)*a)*a)*a); + return tanh(a); + }); + } break; + + case aten::sqrt: { + return ComputeOneOperand( + "aten_sqrt", v, [](const ExprHandle& a) { return sqrt(a); }); + } break; + + case aten::rsqrt: { + return ComputeOneOperand( + "aten_rsqrt", v, [](const ExprHandle& a) { return rsqrt(a); }); + } break; + + case aten::abs: { + return ComputeOneOperand( + "aten_abs", v, [](const ExprHandle& 
a) { return fabs(a); }); + } break; + + case aten::ceil: { + return ComputeOneOperand( + "aten_ceil", v, [](const ExprHandle& a) { return ceil(a); }); + } break; + + case aten::floor: { + return ComputeOneOperand( + "aten_floor", v, [](const ExprHandle& a) { return floor(a); }); + } break; + + case aten::round: { + return ComputeOneOperand( + "aten_round", v, [](const ExprHandle& a) { return round(a); }); + } break; + + case aten::trunc: { + return ComputeOneOperand( + "aten_trunc", v, [](const ExprHandle& a) { return trunc(a); }); + } break; + + case aten::threshold: { + return ComputeThreeOperand( + "aten_threshold", + v, + [](const ExprHandle& a, + const ExprHandle& threshold, + const ExprHandle& value) { + return ifThenElse(CompareSelect::make(a, threshold, kGT), a, value); + }); + } break; + + case aten::where: { + return ComputeConditionWithTwoOperand( + "aten_where", + v, + [](const ExprHandle& a0, const ExprHandle& a1, const ExprHandle& a2) { + return ifThenElse(a0, a1, a2); + }); + } break; + + case aten::frac: { + return ComputeOneOperand( + "aten_frac", v, [](const ExprHandle& a) { return a - floor(a); }); + } break; + + case aten::lgamma: { + return ComputeOneOperand( + "aten_lgamma", v, [](const ExprHandle& a) { return lgamma(a); }); + } break; + + case prim::ConstantChunk: { + return Compute( + "prim_constantchunk", + texprDims(v), + [this, v](const std::vector& axes) { + auto const& n = v->node(); + int64_t dim = n->i(attr::dim); + int64_t chunks = n->i(attr::chunks); + return chunk( + tensors_.at(n->inputs()[0]->unique()), + v->offset(), + dim, + chunks, + axes); + }); + } + + case aten::cat: { + return Compute( + "aten_cat", + texprDims(v), + [this, v](const std::vector& axes) { + auto const& n = v->node(); + auto inputs = n->inputs()[0]->node()->inputs(); + size_t dim = n->inputs()[1]->node()->i(attr::value); + + std::vector new_axes(axes.begin(), axes.end()); + ExprHandle load = tensorOrConstant(inputs[0], new_axes); + size_t offset = bufferSizes(tensors_.at(inputs[0]->unique()))[dim]; + new_axes[dim] = new_axes[dim] - IntImm::make(offset); + + for (int ii = 1; ii < inputs.size(); ++ii) { + load = ifThenElse( + CompareSelect::make(axes[dim], IntImm::make(offset), kLT), + load, + tensorOrConstant(inputs[ii], new_axes)); + offset += bufferSizes(tensors_.at(inputs[ii]->unique()))[dim]; + new_axes[dim] = new_axes[dim] - IntImm::make(offset); + } + + return load; + }); + } + + case aten::slice: { + return Compute( + "aten_slice", + texprDims(v), + [this, v](const std::vector& axes) { + auto const& n = v->node(); + int dim = constant(n->inputs()[1]).AsNode()->value(); + ExprHandle start = constant(n->inputs()[2]); + ExprHandle stride = constant(n->inputs()[4]); + + std::vector new_axes(axes.begin(), axes.end()); + new_axes[dim] = stride * new_axes[dim] + start; + return tensorOrConstant(n->inputs()[0], new_axes); + }); + } + + case aten::unsqueeze: { + return Compute( + "aten_unsqueeze", + texprDims(v), + [this, v](const std::vector& axes) { + auto const& n = v->node(); + int dim = constant(n->inputs()[1]).AsNode()->value(); + if (dim < 0) { + dim += axes.size() - 1; + } + + std::vector new_axes(axes.begin(), axes.end()); + new_axes.erase(new_axes.begin() + dim); + return tensorOrConstant(n->inputs()[0], new_axes); + }); + } + + case aten::_sigmoid_backward: { + return ComputeTwoOperand( + "aten_sigmoid_backward", + v, + [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs * rhs * (ExprHandle(1.0f) - rhs); + }); + } + + case aten::_tanh_backward: { + return 
ComputeTwoOperand( + "aten_tanh_backward", + v, + [](const ExprHandle& lhs, const ExprHandle& rhs) { + return lhs * (ExprHandle(1.0f) - rhs * rhs); + }); + } + + default: { + auto& nfr = getNativeFunctionRegistry(); + auto qs = v->node()->kind().toQualString(); + if (nfr.count(qs)) { + return nfr.at(qs).second(this, v); + } + throw std::runtime_error("Unhandled node kind"); + } + } +} + +void TensorExprKernel::addNoInline(int64_t unique_id) { + no_inline_.insert(unique_id); +} + +void TensorExprKernel::LowerToBackend(BackendType backend_type) { + std::vector tensor_outputs(tensor_outputs_); + + if (backend_type == BackendType::kCudaCodeGen) { + for (int i = 0; i < tensor_outputs_.size(); i++) { + Tensor* tensor = tensor_outputs_[i]; + ExprHandle total_count = ExprHandle(tensor->dim(0)); + for (int i = 1; i < tensor->ndim(); i++) { + total_count = total_count * ExprHandle(tensor->dim(i)); + } + // Flatten the index for GPU kernels. + // TODO: move this to fusing axis when it is ready. + Tensor* new_out = Compute( + tensor->func_var()->name_hint() + "_flat", + {total_count}, + [tensor](const VarHandle& index) -> ExprHandle { + std::vector dims; + ExprHandle value = index; + for (int i = tensor->ndim() - 1; i >= 0; i--) { + ExprHandle idx = value; + if (i > 0) { + idx = Mod::make(value, ExprHandle(tensor->dim(i))); + } + dims.push_back(idx); + value = value / ExprHandle(tensor->dim(i)); + } + std::reverse(dims.begin(), dims.end()); + return tensor->call(dims); + }); + tensor_outputs[i] = new_out; + } + } + + torch::jit::tensorexpr::schedule::Schedule sch(tensor_outputs); + + for (auto& p : tensors_) { + if (dynamic_cast(p.second->body()) != nullptr) { + addNoInline(p.first); + } + } + + // Compute non-output tensors_ inline + for (auto& p : tensors_) { + if (no_inline_.find(p.first) != no_inline_.end()) { + continue; + } + p.second->ComputeInline(); + } + + if (backend_type == kCudaCodeGen) { + for (int i = 0; i < tensor_outputs_.size(); i++) { + // TODO: audit this logic in the presence of external calls + tensor_outputs_[i]->ComputeInline(); + + Tensor* tensor = tensor_outputs[i]; + const Var* index = tensor->arg(0); + int loop_levels = GetTECudaPointwiseLoopLevels(); + const int kDefaultLoopLevels = 2; + loop_levels = (loop_levels > 0) ? loop_levels : kDefaultLoopLevels; + int block_count = GetTECudaPointwiseBlockCount(); + int block_size = GetTECudaPointwiseBlockSize(); + + if (loop_levels == 2) { + VarHandle outer; + VarHandle inner; + int kDefaultBlockSize = 512; + if (block_size < 0) { + block_size = kDefaultBlockSize; + } + tensor->SplitWithMask( + VarHandle(index), block_size, true, &outer, &inner); + tensor->GPUExecConfig({outer}, {inner}); + } else if (loop_levels == 3) { + VarHandle outer; + VarHandle inner; + VarHandle inner_1; + VarHandle inner_2; + // TODO: change the number of microprocessors + const int kDefaultBlockCount = 1280; + const int kDefaultBlockSize = 256; + block_count = (block_count > 0) ? block_count : kDefaultBlockCount; + block_size = (block_size > 0) ? block_size : kDefaultBlockSize; + tensor->SplitWithMask( + VarHandle(index), block_count * block_size, true, &outer, &inner); + tensor->SplitWithMask(inner, block_size, true, &inner_1, &inner_2); + tensor->GPUExecConfig({inner_1}, {inner_2}); + } else { + throw std::runtime_error( + "Invalid loop-level: " + std::to_string(loop_levels)); + } + } + } + + Stmt* stmt = sch.Lower(); + + // Set up formal params (inputs, then outputs) for kernel. 
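+  // Each input contributes its buffer pointer followed by any dynamic size
+  // and stride variables recorded for it; output tensors are appended after
+  // all inputs. run() builds its CallArg list in the same order, so the two
+  // must stay in sync.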
+ std::vector params; + for (auto const& arg : kernelArgs_) { + params.push_back(arg.buffer()); + for (auto const& size : arg.sizes()) { + params.push_back(size.var); + } + for (auto const& stride : arg.strides()) { + params.push_back(stride.var); + } + } + for (auto& o : tensor_outputs) { + params.push_back(o); + } + + // Generate code. + std::string codegen_name; + switch (backend_type_) { + case kCudaCodeGen: + codegen_name = "cuda_codegen"; + break; + case kLLVMCodeGen: + codegen_name = "llvm_codegen"; + break; + case kSimpleIREval: + codegen_name = "simple_ir_eval"; + break; + default: + throw std::runtime_error( + "invalid backend type: " + + std::to_string(static_cast(backend_type_))); + } + + codegen_ = CreateCodeGen(codegen_name, stmt, params); +} + +void TensorExprKernel::PickAndCheckBackendType( + const at::ArrayRef& inputs) { + at::Device device = [&inputs]() { + for (auto const& input : inputs) { + if (input.isTensor()) { + return input.toTensor().device(); + } + } + throw std::runtime_error("No tensor inputs"); + }(); + BackendType backend_type = BackendType::kUninitialized; + if (device.type() == at::kCUDA) { + backend_type = kCudaCodeGen; + } else if (device.type() == at::kCPU) { +#ifdef ENABLE_LLVM + backend_type = kLLVMCodeGen; +#else + backend_type = kSimpleIREval; + ; +#endif + } else { + throw std::runtime_error("Invalid device type"); + } + + if (backend_type_ == kUninitialized) { + backend_type_ = backend_type; + device_ = device; + LowerToBackend(backend_type); + } else if (backend_type_ != backend_type) { + // TODO: if we have to support muliptole backends with the same subgraph, + // we need to add kernel caching. + throw std::runtime_error( + "Inconsistent backend_type: " + std::to_string(backend_type_) + " vs " + + std::to_string(backend_type)); + } +} + +void TensorExprKernel::CodeGenRun( + const std::vector& run_args) { + switch (backend_type_) { + case kSimpleIREval: + case kLLVMCodeGen: + case kCudaCodeGen: + codegen_->call(run_args); + break; + default: + throw std::runtime_error( + "Invalid backend type: " + std::to_string(backend_type_)); + } +} + +ExprHandle TensorExprKernel::createInputIndexExpr( + const Buffer& buffer, + const std::vector& axes, + const c10::VaryingShape& sizes, + const c10::VaryingStrides& strides, + const c10::VaryingStrides& contiguity, + const std::unordered_map& sizeVars) { + TORCH_CHECK( + axes.size() == strides.size(), "strides and axes are not the same size"); + + std::vector strideArgs; + std::vector sizeArgs; + ExprHandle stride = 1; + ExprHandle index = 0; + int n = axes.size() - 1; + + for (int i = 0; i < axes.size(); i++) { + // For discontiguous tensors, create a parameter to represent stride. + if (!*contiguity[i]) { + VarHandle v = VarHandle{ + "stride_" + buffer.data()->name_hint() + "_" + std::to_string(i), + kInt}; + strideArgs.emplace_back(n - i, v); + stride = v; + } + + // If size is dynamic (indicated by negative value) create a size param. 
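+    // (Either way the size feeds the running stride below: the loop walks
+    // dimensions innermost-first, accumulating index += axes[d] * stride,
+    // which for a fully contiguous buffer reduces to the usual row-major
+    // flattening axes[n] + size[n] * (axes[n-1] + size[n-1] * (...)).)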
+ ExprHandle size; + auto sizeVal = *sizes[n - i]; + if (sizeVal < 0) { + auto it = sizeVars.find(sizeVal); + TORCH_CHECK(it != sizeVars.end()); + auto const& v = it->second; + sizeArgs.emplace_back(n - i, v); + size = v; + } else { + size = int32_t{sizeVal}; + } + + index = index + axes[n - i] * stride; + stride = stride * size; + } + + kernelArgs_.emplace_back(buffer, std::move(sizeArgs), std::move(strideArgs)); + return buffer(index); +} + +void TensorExprKernel::bindInput(const torch::jit::Value* input) { + auto const& t = input->type(); + switch (t->kind()) { + case TypeKind::TensorType: { + auto tt = input->type()->cast(); + Buffer in_buffer( + "t" + input->debugName(), + ToDtype(static_cast(*tt->scalarType())), + {0}); + std::vector inputTensorDims; + std::unordered_map sizeVars; + for (int i = 0; i < *tt->sizes().size(); i++) { + auto const& size = *tt->sizes()[i]; + if (size < 0) { + VarHandle v( + "size_" + std::to_string(input->unique()) + "_" + + std::to_string(i), + kInt); + sizeVars.emplace(size, v); + inputTensorDims.push_back(v); + } else { + inputTensorDims.push_back({int32_t{size}, "i" + std::to_string(i)}); + } + } +#ifdef DYNAMIC_SHAPES + tensors_.emplace( + input->unique(), + Compute( + "input", + inputTensorDims, + [&](const std::vector& axes) { + return createInputIndexExpr( + in_buffer, + axes, + tt->sizes(), + tt->strides(), + tt->contiguity(), + sizeVars); + })); +#else + auto const& strides = tt->strides(); + tensors_.emplace( + input->unique(), + Compute( + "input", + inputTensorDims, + [&](const std::vector& axes) { + std::vector idxs; + idxs.push_back(axes[0] * (int32_t)*strides[0]); + for (int i = 1; i < axes.size(); i++) { + idxs.push_back(idxs[i - 1] + axes[i] * (int32_t)*strides[i]); + } + return in_buffer(idxs.back()); + })); + kernelArgs_.emplace_back( + in_buffer, std::vector(), std::vector()); +#endif + break; + } + case TypeKind::FloatType: { + VarHandle v("v" + input->debugName(), kFloat); + kernelArgs_.push_back(v); + scalars_.emplace(input->unique(), v); + break; + } + case TypeKind::IntType: { + VarHandle v("v" + input->debugName(), kInt); + kernelArgs_.push_back(v); + scalars_.emplace(input->unique(), v); + break; + } + default: { + LOG(FATAL) << "Unhandled input type: " << *t; + break; + } + } +} + +TensorExprKernel::TensorExprKernel(const Graph& subgraph) { + KernelScope kernel_scope(&kernel_arena_); + + // Bind inputs to buffers. + n_inputs_ = subgraph.inputs().size(); + for (auto const& input : subgraph.inputs()) { + bindInput(input); + } + + // Bind nodes to tensor compute expressions. + for (auto const& n : subgraph.nodes()) { + if (n->kind() == prim::Constant || n->kind() == prim::ListConstruct) { + continue; + } else { + for (auto const& output : n->outputs()) { + if (output->hasUses()) { + tensors_.emplace(output->unique(), ComputeValue(output)); + } + } + } + } + + // Move output operands from `tensors_` to `tensor_outputs_` + for (const auto& output : subgraph.outputs()) { + CHECK(tensors_.count(output->unique())) << "Output must be a tensor"; + tensor_outputs_.emplace_back(tensors_.at(output->unique())); + tensors_.erase(output->unique()); + } +} + +void TensorExprKernel::run(Stack& stack) { + KernelScope kernel_scope(&kernel_arena_); + // Set up arguments (inputs, then outputs) for kernel call. 
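+  // Scalar ints are narrowed to int32_t and doubles to float to match the
+  // parameter types bound in bindInput(); each tensor contributes its data
+  // pointer plus the concrete size/stride values for its recorded ShapeArgs.
+  // Observed sizes are also cached in varToSize so outputs with dynamic
+  // dimensions can be allocated below.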
+ auto inputs = last(stack, n_inputs_); + PickAndCheckBackendType(inputs); + + std::map varToSize; + + std::vector run_args; + for (int i = 0; i < inputs.size(); i++) { + auto const& input = inputs[i]; + if (input.isInt()) { + run_args.push_back((int32_t)input.toInt()); + } else if (input.isDouble()) { + run_args.push_back((float)input.toDouble()); + } else if (input.isTensor()) { + auto const& tensor = input.toTensor(); + run_args.push_back(tensor.data_ptr()); + for (auto const& size : kernelArgs_[i].sizes()) { + int32_t s = tensor.sizes()[size.idx]; + run_args.push_back(s); + varToSize[size.var.node()] = s; + } + for (auto const& stride : kernelArgs_[i].strides()) { + int32_t s = tensor.strides()[stride.idx]; + run_args.push_back(s); + } + } + } + + std::vector outputs; + for (auto& o : tensor_outputs_) { + std::vector tensorSize; + for (const Expr* dim : o->dims()) { + auto it = varToSize.find(dim); + if (it != varToSize.end()) { + tensorSize.push_back(it->second); + } else { + const IntImm* s = dynamic_cast(dim); + TORCH_CHECK(s); + tensorSize.push_back(s->value()); + } + } + + outputs.push_back(at::empty( + tensorSize, c10::TensorOptions(tensorType(o)).device(device_))); + run_args.push_back(outputs.back().data_ptr()); + } + + // Call the kernel. + CodeGenRun(run_args); + + // Update the stack. + drop(stack, n_inputs_); + for (auto& o : outputs) { + push_one(stack, std::move(o)); + } +} diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h new file mode 100644 index 0000000000000..a87b776b9a865 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -0,0 +1,227 @@ +#pragma once + +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +TORCH_API std::vector texprDims(const torch::jit::Value* v); + +template +inline std::vector bufferSizes(const T& t) { + std::vector sizes; + for (int i = 0; i < t->function()->ndim(); i++) { + sizes.push_back( + dynamic_cast(t->function()->dim(i))->value()); + } + return sizes; +} + +template +inline std::vector computeIndicesToBroadcast( + const std::vector& output_axes, + const std::vector& input_sizes) { + TORCH_CHECK( + output_axes.size() >= input_sizes.size(), + "Cannot broadcast to a lower rank tensor"); + std::vector bcast; + auto axis_it = output_axes.rbegin(); + auto size_it = input_sizes.rbegin(); + while (size_it != input_sizes.rend()) { + auto const& size = size_it->AsNode(); + if (size && size->value() == 1) { + bcast.push_back(0); + } else { + bcast.push_back(*axis_it); + } + ++axis_it; + ++size_it; + } + std::reverse(bcast.begin(), bcast.end()); + return bcast; +} + +class TensorExprKernel { + public: + explicit TensorExprKernel(const Graph& subgraph); + + void run(Stack& stack); + + private: + enum BackendType { + kUninitialized, + kSimpleIREval, + kLLVMCodeGen, + kCudaCodeGen, + }; + + ExprHandle constant(const torch::jit::Value* v); + + template + ExprHandle broadcast(const T& t, const std::vector& axes) { + return t->call(computeIndicesToBroadcast( + axes, ExprVectorToExprHandleVector(t->function()->dims()))); + } + + template + ExprHandle chunk( + const T& t, + size_t chunk_idx, + size_t dim, + size_t chunks, + const std::vector& axes) { + auto sizes = bufferSizes(t); + size_t step = sizes[dim] / chunks; + + std::vector indices; + for (size_t i = 0; i < axes.size(); ++i) { + if (i == dim) { + indices.push_back(axes[i] + IntImm::make(chunk_idx * step)); + } else { + indices.push_back(axes[i]); + } + } + + return t->call(indices); + } + + std::vector 
valueShape(const torch::jit::Value* v); + + void promoteInputs(std::vector& inputs); + + ExprHandle demoteOutput(const ExprHandle& e, const torch::jit::Value* v); + + public: + template + ExprHandle tensorOrConstant( + const torch::jit::Value* v, + const std::vector& axes) { + auto ti = tensors_.find(v->unique()); + if (ti != tensors_.end()) { + return broadcast(ti->second, axes); + } + return constant(v); + } + + void addNoInline(int64_t unique_id); + inline Tensor* getTensor(int64_t unique_id) { + return tensors_.at(unique_id); + } + + private: + Tensor* ComputeOneOperand( + const std::string& name, + const torch::jit::Value* v, + std::function inner_expr); + + Tensor* ComputeTwoOperand( + const std::string& name, + const torch::jit::Value* v, + std::function + inner_expr); + + Tensor* ComputeTwoOperandWithAlpha( + const std::string& name, + const torch::jit::Value* v, + std::function + inner_expr); + + Tensor* ComputeThreeOperand( + const std::string& name, + const torch::jit::Value* v, + std::function< + ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)> + inner_expr); + + Tensor* ComputeConditionWithTwoOperand( + const std::string& name, + const torch::jit::Value* v, + std::function< + ExprHandle(const ExprHandle&, const ExprHandle&, const ExprHandle&)> + inner_expr); + + Tensor* ComputeFourOperand( + const std::string& name, + const torch::jit::Value* v, + std::function inner_expr); + + Tensor* ComputeValue(const torch::jit::Value* v); + + void LowerToBackend(BackendType backend_type); + + void PickAndCheckBackendType(const at::ArrayRef& inputs); + + void CodeGenRun(const std::vector& run_args); + + void bindInput(const torch::jit::Value* input); + + ExprHandle createInputIndexExpr( + const Buffer& buffer, + const std::vector& axes, + const c10::VaryingShape& sizes, + const c10::VaryingStrides& strides, + const c10::VaryingStrides& contiguity, + const std::unordered_map& sizeVars); + + private: + struct ShapeArg { + size_t idx; + VarHandle var; + + ShapeArg(size_t i, VarHandle v) : idx(i), var(v) {} + }; + + struct KernelArg { + template + KernelArg(B&& b) : bufferArg_(std::forward(b)) {} + + template + KernelArg(B&& b, T&& sizes, T&& strides) + : bufferArg_(b), + sizeArgs_(std::forward(sizes)), + strideArgs_(std::forward(strides)) {} + + const CodeGen::BufferArg& buffer() const { + return bufferArg_; + } + + const std::vector& sizes() const { + return sizeArgs_; + } + + const std::vector& strides() const { + return strideArgs_; + } + + CodeGen::BufferArg bufferArg_; + std::vector sizeArgs_; + std::vector strideArgs_; + }; + + int64_t n_inputs_ = 0; + std::vector kernelArgs_; + std::vector tensor_outputs_; + std::unordered_map tensors_; + std::unordered_set no_inline_; + std::unordered_map scalars_; + std::unique_ptr codegen_; + KernelArena kernel_arena_; + BackendType backend_type_ = BackendType::kUninitialized; + at::Device device_ = at::kCPU; +}; + +TORCH_API int& GetTECudaPointwiseLoopLevels(); +TORCH_API int& GetTECudaPointwiseBlockCount(); +TORCH_API int& GetTECudaPointwiseBlockSize(); +TORCH_API void SetTexprFuserEnabled(bool val); + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp new file mode 100644 index 0000000000000..d9d677fff8fea --- /dev/null +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -0,0 +1,1158 @@ +#ifdef ENABLE_LLVM + +#include "torch/csrc/jit/tensorexpr/llvm_codegen.h" +#include "torch/csrc/jit/tensorexpr/native.h" + 
+#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "torch/csrc/jit/tensorexpr/buffer.h" +#include "torch/csrc/jit/tensorexpr/execution_counter.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_printer.h" +#include "torch/csrc/jit/tensorexpr/types.h" + +using namespace torch::jit::tensorexpr; + +DEFINE_TRIGGER(llvm_codegen_created); +DEFINE_TRIGGER(llvm_codegen_executed); + +static llvm::orc::JITTargetMachineBuilder makeTargetMachineBuilder() { +#if 0 + // FIXME: Switch to using detectHost() rather than setting up the JTMB manually + // once LLVM 10 is available. + return llvm::cantFail(llvm::orc::JITTargetMachineBuilder::detectHost()); +#else + llvm::orc::JITTargetMachineBuilder JTMB( + (llvm::Triple(llvm::sys::getProcessTriple()))); + + // Retrieve host CPU name and sub-target features and add them to builder. + // Relocation model, code model and codegen opt level are kept to default + // values. + llvm::SubtargetFeatures SubtargetFeatures; + llvm::StringMap FeatureMap; + llvm::sys::getHostCPUFeatures(FeatureMap); + for (auto& Feature : FeatureMap) { + SubtargetFeatures.AddFeature(Feature.first(), Feature.second); + } + + JTMB.setCodeGenOptLevel(llvm::CodeGenOpt::Default); + JTMB.setCPU(llvm::sys::getHostCPUName()); + JTMB.addFeatures(SubtargetFeatures.getFeatures()); + + return JTMB; +#endif +} + +LLVMCodeGen::LLVMCodeGen(Stmt* stmt) + : LLVMCodeGen(stmt, std::vector()) {} + +LLVMCodeGen::LLVMCodeGen( + Stmt* stmt, + const std::vector& args, + Dtype dtype) + : CodeGen(stmt, args), + context_(std::make_unique()), + irb_(getContext()) { + // Manually map types to LLVM types. + VoidTy_ = llvm::Type::getVoidTy(getContext()); + ByteTy_ = llvm::Type::getInt8Ty(getContext()); + CharTy_ = llvm::Type::getInt8Ty(getContext()); + ShortTy_ = llvm::Type::getInt16Ty(getContext()); + IntTy_ = llvm::Type::getInt32Ty(getContext()); + LongTy_ = llvm::Type::getInt64Ty(getContext()); + HalfTy_ = llvm::Type::getHalfTy(getContext()); + FloatTy_ = llvm::Type::getFloatTy(getContext()); + DoubleTy_ = llvm::Type::getDoubleTy(getContext()); + + llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + + auto JTMB = makeTargetMachineBuilder(); + TM_ = llvm::cantFail(JTMB.createTargetMachine()); + + jit_ = std::make_unique(); + module_ = std::make_unique("pytorch", getContext()); + + module_->setDataLayout(cantFail(JTMB.getDefaultDataLayoutForTarget())); + module_->setTargetTriple(JTMB.getTargetTriple().str()); + + // Emit prototype and bind argument Vars to parameter indices. 
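+  // Scalar (Var) arguments are passed by value, buffer arguments as pointers
+  // marked noalias. varToArg_ records each Var's parameter index so that
+  // visit(const Var*) can later resolve it to the corresponding function
+  // argument.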
+ llvm::Type* retTy = dtypeToLLVM(dtype); + std::vector params; + for (int i = 0; i < args.size(); i++) { + auto const& arg = args[i]; + if (arg.isVar()) { + params.push_back(dtypeToLLVM(arg.dtype())); + } else { + params.push_back(dtypeToLLVMPtr(arg.dtype())); + } + varToArg_[arg.var()] = i; + } + llvm::FunctionType* fntype = llvm::FunctionType::get(retTy, params, false); + fn_ = llvm::Function::Create( + fntype, llvm::Function::PrivateLinkage, "pytorch", module_.get()); + for (int i = 0; i < args.size(); i++) { + if (!args[i].isVar()) { + fn_->addParamAttr(i, llvm::Attribute::NoAlias); + } + } + + emitWrapper(params); + emitKernel(stmt, params); + + cantFail(jit_->addModule( + llvm::orc::ThreadSafeModule(std::move(module_), context_))); + auto sym = jit_->findSymbol("wrapper"); + kernelAddress_ = cantFail(sym.getAddress()); + + USE_TRIGGER(llvm_codegen_created); +} + +llvm::LLVMContext& LLVMCodeGen::getContext() { + return *context_.getContext(); +} + +llvm::Type* LLVMCodeGen::dtypeToLLVM(Dtype dtype) { + switch (dtype.scalar_type()) { +#define TYPE_CASE(_1, n) \ + case ScalarType::n: \ + return n##Ty_; \ + break; + + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + LOG(FATAL) << "Unhandled dtype: " << dtype; + } + return nullptr; +} + +llvm::Type* LLVMCodeGen::dtypeToLLVMPtr(Dtype dtype) { + return dtypeToLLVM(dtype)->getPointerTo(); +} + +void LLVMCodeGen::emitWrapper(const std::vector& params) { + auto voidPtrPtrTy = llvm::Type::getInt8PtrTy(getContext())->getPointerTo(); + auto wrapper = llvm::Function::Create( + llvm::FunctionType::get(IntTy_, {voidPtrPtrTy}, false), + llvm::Function::ExternalLinkage, + "wrapper", + module_.get()); + auto wrapBB = llvm::BasicBlock::Create(getContext(), "wrapBB", wrapper); + irb_.SetInsertPoint(wrapBB); + llvm::SmallVector wrappedArgs; + for (size_t i = 0; i < params.size(); i++) { + auto argp = irb_.CreateGEP( + wrapper->arg_begin(), llvm::ConstantInt::getSigned(IntTy_, i)); + if (params[i]->isPointerTy()) { + auto arg = irb_.CreatePointerCast(irb_.CreateLoad(argp), params[i]); + wrappedArgs.push_back(arg); + } else { + auto p = irb_.CreatePointerCast( + irb_.CreateLoad(argp), params[i]->getPointerTo()); + auto arg = irb_.CreateLoad(p); + wrappedArgs.push_back(arg); + } + } + auto cc = irb_.CreateCall(fn_, wrappedArgs); + irb_.CreateRet(cc); +} + +void LLVMCodeGen::emitKernel( + Stmt* stmt, + const std::vector& params) { + // Set insert point to the real function. + bb_ = llvm::BasicBlock::Create(getContext(), "entry", fn_); + irb_.SetInsertPoint(bb_); + + // Compile the kernel. 
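+  // Each visit() leaves its result in value_, so after visiting the root
+  // statement value_ holds the function's return value. The function is then
+  // verified and the module optimized; under DEBUG_PRINT the LLVM IR and the
+  // generated assembly are dumped.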
+ stmt->accept(this); + irb_.CreateRet(value_); + +#if DEBUG_PRINT + llvm::errs() << *module_; +#endif + CHECK(!llvm::verifyFunction(*fn_, &llvm::outs())) + << "Function verification failed"; + optimize(*module_); + +#if DEBUG_PRINT + llvm::errs() << *module_; + llvm::SmallVector asmBuffer; + llvm::raw_svector_ostream asmStream(asmBuffer); + llvm::legacy::PassManager PM; + TM_->addPassesToEmitFile( + PM, + asmStream, + nullptr, + llvm::TargetMachine::CodeGenFileType::CGFT_AssemblyFile); + PM.run(*module_); + llvm::errs() << asmStream.str(); +#endif +} + +static void* argToPtr( + const CodeGen::BufferArg& bufferArg, + const CodeGen::CallArg& callArg) { + if (!bufferArg.isVar()) { + return callArg.data(); + } + + switch (bufferArg.dtype().scalar_type()) { +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + return callArg.Name##Ptr(); + break; + + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + + default: + LOG(FATAL) << "Unhandled dtype for arg: " << bufferArg.var()->name_hint() + << "dtype=" << bufferArg.var()->dtype(); + } + return nullptr; +} + +void LLVMCodeGen::call(const std::vector& args) { + CHECK_EQ(args.size(), buffer_args().size()) + << "args: " << args.size() << ", buffers: " << buffer_args().size(); + for (size_t i = 0; i < buffer_args().size(); i++) { + auto const& bufferArg = buffer_args()[i]; + auto const& callArg = args[i]; + args_.push_back(argToPtr(bufferArg, callArg)); + } + value(args_); + args_.clear(); + USE_TRIGGER(llvm_codegen_executed); +} + +// TODO: The binary ops are copypasta. + +void LLVMCodeGen::visit(const Add* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + // TODO: Handle arg promotion. + if (lfp && rfp) { + value_ = irb_.CreateFAdd(lhs, rhs); + } else if (!lfp && !rfp) { + value_ = irb_.CreateAdd(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch add arg types"; + } +} + +void LLVMCodeGen::visit(const Sub* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + // TODO: Handle arg promotion. + if (lfp && rfp) { + value_ = irb_.CreateFSub(lhs, rhs); + } else if (!lfp && !rfp) { + value_ = irb_.CreateSub(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch sub arg types"; + } +} + +void LLVMCodeGen::visit(const Mul* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + // TODO: Handle arg promotion. + if (lfp && rfp) { + value_ = irb_.CreateFMul(lhs, rhs); + } else if (!lfp && !rfp) { + value_ = irb_.CreateMul(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch mul arg types, lhs is " + << (lfp ? "" : "not ") << "floating point, whereas rhs is " + << (rfp ? "" : "not "); + } +} + +void LLVMCodeGen::visit(const Div* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + // TODO: Handle arg promotion. 
+ if (lfp && rfp) { + value_ = irb_.CreateFDiv(lhs, rhs); + } else if (!lfp && !rfp) { + value_ = irb_.CreateSDiv(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch div arg types"; + } +} + +void LLVMCodeGen::visit(const And* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + if (!lfp && !rfp) { + value_ = irb_.CreateAnd(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch And arg types"; + } +} + +void LLVMCodeGen::visit(const Xor* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + if (!lfp && !rfp) { + value_ = irb_.CreateXor(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch And arg types"; + } +} + +void LLVMCodeGen::visit(const Lshift* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + if (!lfp && !rfp) { + value_ = irb_.CreateShl(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch And arg types"; + } +} + +void LLVMCodeGen::visit(const Rshift* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + bool lfp = lhs->getType()->isFloatingPointTy(); + v->rhs()->accept(this); + auto rhs = this->value_; + bool rfp = rhs->getType()->isFloatingPointTy(); + + if (!lfp && !rfp) { + value_ = irb_.CreateLShr(lhs, rhs); + } else { + LOG(FATAL) << "Unhandled mismatch And arg types"; + } +} + +void LLVMCodeGen::visit(const Mod* v) { + throw std::runtime_error("Mod unsupported in LLVM codegen yet"); +} + +void LLVMCodeGen::visit(const Max* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + v->rhs()->accept(this); + auto rhs = this->value_; + + if (v->dtype() == kInt) { + auto icmp = irb_.CreateICmpSGT(lhs, rhs); + value_ = irb_.CreateSelect(icmp, lhs, rhs); + return; + } + + if (v->propagate_nans()) { + value_ = irb_.CreateBinaryIntrinsic(llvm::Intrinsic::maximum, lhs, rhs); + return; + } + + value_ = irb_.CreateSelect( + irb_.CreateFCmp(llvm::FCmpInst::FCMP_OGT, lhs, rhs), lhs, rhs); +} + +void LLVMCodeGen::visit(const Min* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + v->rhs()->accept(this); + auto rhs = this->value_; + + if (v->dtype() == kInt) { + auto icmp = irb_.CreateICmpSLT(lhs, rhs); + value_ = irb_.CreateSelect(icmp, lhs, rhs); + return; + } + + if (v->propagate_nans()) { + value_ = irb_.CreateBinaryIntrinsic(llvm::Intrinsic::minimum, lhs, rhs); + return; + } + + value_ = irb_.CreateSelect( + irb_.CreateFCmp(llvm::FCmpInst::FCMP_OLT, lhs, rhs), lhs, rhs); +} + +void LLVMCodeGen::visit(const CompareSelect* v) { + v->lhs()->accept(this); + auto lhs = this->value_; + v->rhs()->accept(this); + auto rhs = this->value_; + v->ret_val1()->accept(this); + auto retval1 = this->value_; + v->ret_val2()->accept(this); + auto retval2 = this->value_; + + auto type_used = v->lhs()->dtype().scalar_type(); + + llvm::Value* cmp_; + CompareSelectOperation cmp_op_ = v->compare_select_op(); + + if (is_integral(type_used)) { + switch (cmp_op_) { + case CompareSelectOperation::kEQ: + cmp_ = irb_.CreateICmpEQ(lhs, rhs); + break; + case CompareSelectOperation::kNE: + cmp_ = irb_.CreateICmpNE(lhs, rhs); + break; + case CompareSelectOperation::kGT: + cmp_ = irb_.CreateICmpSGT(lhs, 
rhs); + break; + case CompareSelectOperation::kGE: + cmp_ = irb_.CreateICmpSGE(lhs, rhs); + break; + case CompareSelectOperation::kLT: + cmp_ = irb_.CreateICmpSLT(lhs, rhs); + break; + case CompareSelectOperation::kLE: + cmp_ = irb_.CreateICmpSLE(lhs, rhs); + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } else if (is_floating_point(type_used)) { // FP32 + switch (cmp_op_) { + case CompareSelectOperation::kEQ: + cmp_ = irb_.CreateFCmpOEQ(lhs, rhs); + break; + case CompareSelectOperation::kNE: + cmp_ = irb_.CreateFCmpONE(lhs, rhs); + break; + case CompareSelectOperation::kGT: + cmp_ = irb_.CreateFCmpOGT(lhs, rhs); + break; + case CompareSelectOperation::kGE: + cmp_ = irb_.CreateFCmpOGE(lhs, rhs); + break; + case CompareSelectOperation::kLT: + cmp_ = irb_.CreateFCmpOLT(lhs, rhs); + break; + case CompareSelectOperation::kLE: + cmp_ = irb_.CreateFCmpOLE(lhs, rhs); + break; + default: + // TODO: change to a proper error report + throw std::runtime_error("invalid operator type"); + } + } else { + throw std::runtime_error("invalid type for CompareSelect"); + } + + value_ = irb_.CreateSelect(cmp_, retval1, retval2); + return; +} + +template +typename std::enable_if::value, llvm::Value*>::type +getFromType(llvm::Type* type, T value) { + return llvm::ConstantInt::get(type, value, std::is_signed::value); +} + +template +typename std::enable_if::value, llvm::Value*>::type +getFromType(llvm::Type* type, T value) { + return llvm::ConstantFP::get(type, value); +} + +#define IMM_VISIT_DECLARE(Type, Name) \ + void LLVMCodeGen::visit(const Name##Imm* v) { \ + value_ = getFromType(Name##Ty_, v->value()); \ + } +AT_FORALL_SCALAR_TYPES(IMM_VISIT_DECLARE); +#undef IMM_VISIT_DECLARE + +void LLVMCodeGen::visit(const HalfImm* v) { + value_ = llvm::ConstantFP::get(HalfTy_, v->value()); +} + +void LLVMCodeGen::visit(const BoolImm* v) { + value_ = llvm::ConstantInt::get(BoolTy_, v->value()); +} + +void LLVMCodeGen::visit(const Cast* v) { + v->src_value()->accept(this); + + llvm::Type* dstType = dtypeToLLVM(v->dtype()); + if (v->dtype().lanes() > 1) { + dstType = llvm::VectorType::get(dstType, v->dtype().lanes()); + } + llvm::Type* srcType = dtypeToLLVM(v->src_value()->dtype()); + + if (srcType == dstType) { + // do nothing. 
+ return; + } + + bool destUnsigned = v->dtype().scalar_type() == ScalarType::Byte; + + // Scalar casts + if (srcType->isFloatingPointTy()) { + if (dstType->isFloatingPointTy()) { + value_ = irb_.CreateFPCast(value_, dstType); + } else if (dstType->isIntegerTy()) { + if (destUnsigned) { + value_ = irb_.CreateFPToUI(value_, dstType); + } else { + value_ = irb_.CreateFPToSI(value_, dstType); + } + } else { + LOG(FATAL) << "Unsupported cast!"; + } + } else if (srcType->isIntegerTy()) { + if (dstType->isFloatingPointTy()) { + if (destUnsigned) { + value_ = irb_.CreateUIToFP(value_, dstType); + } else { + value_ = irb_.CreateSIToFP(value_, dstType); + } + } else if (dstType->isIntegerTy()) { + value_ = irb_.CreateIntCast(value_, dstType, !destUnsigned); + } else { + LOG(FATAL) << "Unsupported cast!"; + } + } +} + +void LLVMCodeGen::visit(const Var* v) { + if (varToArg_.count(v)) { + auto idx = varToArg_.at(v); + auto arg = fn_->arg_begin() + idx; + value_ = arg; + } else if (varToVal_.count(v)) { + value_ = varToVal_.at(v); + } else { + LOG(FATAL) << "Unable to resolve Variable " << *v << "\n"; + } +} + +void LLVMCodeGen::visit(const Let* v) { + const Var* var = dynamic_cast(v->var()); + CHECK(var != nullptr); + v->value()->accept(this); + auto value = value_; + if (!varToVal_.count(var)) { + varToVal_.emplace(var, value); + } else { + throw std::runtime_error("var should not exist before"); + } + v->body()->accept(this); + if (varToVal_.count(var)) { + varToVal_.erase(var); + } else { + throw std::runtime_error("erasing var that doesn't exist"); + } +} + +// TODO: refactor this and merge with Let +void LLVMCodeGen::visit(const LetStmt* v) { + const Var* var = v->var(); + CHECK(var != nullptr); + v->value()->accept(this); + auto value = value_; + if (!varToVal_.count(var)) { + varToVal_.emplace(var, value); + } else { + throw std::runtime_error("var should not exist before"); + } + v->body()->accept(this); + if (varToVal_.count(var)) { + varToVal_.erase(var); + } else { + throw std::runtime_error("erasing var that doesn't exist"); + } +} + +void LLVMCodeGen::visit(const Ramp* v) { + v->base()->accept(this); + auto base = this->value_; + v->stride()->accept(this); + auto stride = this->value_; + int lanes = v->lanes(); + + llvm::Type* vecType = nullptr; + switch (v->dtype().scalar_type()) { +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + vecType = llvm::VectorType::get(Name##Ty_, lanes); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw std::runtime_error("invalid dtype in Ramp"); + } + + value_ = llvm::UndefValue::get(vecType); + for (int i = 0; i < lanes; ++i) { + value_ = irb_.CreateInsertElement(value_, base, i); + base = irb_.CreateAdd(base, stride); + } +} + +llvm::Value* LLVMCodeGen::emitUnmaskedLoad( + llvm::Value* base, + llvm::Value* idx) { + auto addr = irb_.CreateGEP(base, idx); + return irb_.CreateLoad(addr); +} + +llvm::Value* LLVMCodeGen::emitMaskedLoad( + llvm::Value* base, + llvm::Value* idx, + llvm::Value* mask) { + // Create block structure for the masked load. 
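+  // The preheader branches on (mask == 1): the taken edge performs the load
+  // in condblock, the not-taken edge jumps straight to tailblock, and a phi
+  // in tailblock merges the loaded value with undef for the masked-off path.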
+ auto preheader = irb_.GetInsertBlock(); + auto condblock = llvm::BasicBlock::Create(getContext(), "cond", fn_); + auto tailblock = llvm::BasicBlock::Create(getContext(), "tail", fn_); + + // Test the mask + auto cond = irb_.CreateICmpEQ(mask, llvm::ConstantInt::get(IntTy_, 1)); + irb_.CreateCondBr(cond, condblock, tailblock); + + // Do the load + irb_.SetInsertPoint(condblock); + auto addr = irb_.CreateGEP(base, idx); + auto load = irb_.CreateLoad(addr); + irb_.CreateBr(tailblock); + + // Merge the masked and unmasked CFG edges + irb_.SetInsertPoint(tailblock); + auto phi = irb_.CreatePHI(load->getType(), 2); + phi->addIncoming(llvm::UndefValue::get(load->getType()), preheader); + phi->addIncoming(load, condblock); + + return phi; +} + +void LLVMCodeGen::visit(const Load* v) { + v->base_handle()->accept(this); + auto base = this->value_; + v->index()->accept(this); + auto idx = this->value_; + v->mask()->accept(this); + auto mask = this->value_; + + if (v->dtype().lanes() == 1) { + auto* maskimm = dynamic_cast(v->mask()); + if (maskimm && maskimm->value() == 1) { + value_ = emitUnmaskedLoad(base, idx); + } else { + value_ = emitMaskedLoad(base, idx, mask); + } + return; + } + + llvm::Type* loadType = nullptr; + + switch (v->dtype().scalar_type()) { +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + loadType = llvm::VectorType::get(Name##Ty_, v->dtype().lanes()); \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw std::runtime_error("invalid dtype in Load"); + } + + // Detect whether the vector mask is all true + bool unmasked_load = false; + auto* mask_broadcast = dynamic_cast(v->mask()); + if (mask_broadcast) { + auto* broadcast_imm = dynamic_cast(mask_broadcast->value()); + if (broadcast_imm && broadcast_imm->value() == 1) { + unmasked_load = true; + } + } + + // Handle the case where the load is contiguous and unmasked efficiently + auto* idx_ramp = dynamic_cast(v->index()); + if (unmasked_load && idx_ramp) { + auto* stride_imm = dynamic_cast(idx_ramp->stride()); + if (stride_imm && stride_imm->value() == 1) { + auto first_idx = irb_.CreateExtractElement(idx, uint64_t{0ULL}); + auto addr = irb_.CreateGEP(base, first_idx); + auto vaddr = irb_.CreateBitOrPointerCast( + addr, llvm::PointerType::get(loadType, 0)); + value_ = irb_.CreateAlignedLoad(loadType, vaddr, 4); + return; + } + } + + // Fallback to a scalar implementation + llvm::Value* load = llvm::UndefValue::get(loadType); + for (int i = 0; i < v->dtype().lanes(); ++i) { + auto sub_idx = irb_.CreateExtractElement(idx, i); + llvm::Value* sub_load = nullptr; + if (unmasked_load) { + sub_load = emitUnmaskedLoad(base, sub_idx); + } else { + auto sub_mask = irb_.CreateExtractElement(mask, i); + sub_load = emitMaskedLoad(base, sub_idx, sub_mask); + } + load = irb_.CreateInsertElement(load, sub_load, i); + } + + value_ = load; +} + +void LLVMCodeGen::visit(const For* v) { + // Create "start" value. + v->start()->accept(this); + auto start = this->value_; + + // Create loop preheader and body. + auto preheader = irb_.GetInsertBlock(); + auto loop = llvm::BasicBlock::Create(getContext(), "loop", fn_); + irb_.CreateBr(loop); + irb_.SetInsertPoint(loop); + + // Set up phi node for index variable. + auto idx = irb_.CreatePHI(IntTy_, 2); + idx->addIncoming(start, preheader); + varToVal_.emplace(v->var(), idx); + + // Codegen the body. + if (v->body()) { + v->body()->accept(this); + } + + // Create the stop condition. and "after" block. 
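+  // Note: the loop is emitted in do-while form; the body above runs once
+  // before the (idx + 1) < stop condition is first evaluated.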
+ auto inc = irb_.CreateAdd(idx, llvm::ConstantInt::getSigned(IntTy_, 1)); + v->stop()->accept(this); + auto stop = this->value_; + auto cond = irb_.CreateICmpSLT(inc, stop); + + // Branch back to top of loop and finish phi for index variable. + auto end_loop = irb_.GetInsertBlock(); + auto after = llvm::BasicBlock::Create(getContext(), "after", fn_); + irb_.CreateCondBr(cond, loop, after); + irb_.SetInsertPoint(after); + idx->addIncoming(inc, end_loop); + value_ = llvm::ConstantInt::get(IntTy_, 0); +} + +void LLVMCodeGen::visit(const Block* v) { + for (int i = 0; i < v->nstmts(); i++) { + v->stmt(i)->accept(this); + } +} + +void LLVMCodeGen::emitUnmaskedStore( + llvm::Value* base, + llvm::Value* idx, + llvm::Value* val) { + auto addr = irb_.CreateGEP(base, idx); + irb_.CreateStore(val, addr); +} + +void LLVMCodeGen::emitMaskedStore( + llvm::Value* base, + llvm::Value* idx, + llvm::Value* mask, + llvm::Value* val) { + // Create block structure for the masked store. + auto preheader = irb_.GetInsertBlock(); + auto condblock = llvm::BasicBlock::Create(getContext(), "cond", fn_); + auto tailblock = llvm::BasicBlock::Create(getContext(), "tail", fn_); + + // Test the mask + auto cond = irb_.CreateICmpEQ(mask, llvm::ConstantInt::get(IntTy_, 1)); + irb_.CreateCondBr(cond, condblock, tailblock); + + // Do the store + irb_.SetInsertPoint(condblock); + auto addr = irb_.CreateGEP(base, idx); + irb_.CreateStore(val, addr); + irb_.CreateBr(tailblock); + + // Merge the masked and unmasked CFG edges + irb_.SetInsertPoint(tailblock); +} + +void LLVMCodeGen::visit(const Store* v) { + v->base_handle()->accept(this); + auto base = this->value_; + v->index()->accept(this); + auto idx = this->value_; + v->mask()->accept(this); + auto mask = this->value_; + v->value()->accept(this); + auto val = this->value_; + + value_ = llvm::ConstantInt::get(IntTy_, 0); + + if (v->value()->dtype().lanes() == 1) { + auto* maskimm = dynamic_cast(v->mask()); + if (maskimm && maskimm->value() == 1) { + emitUnmaskedStore(base, idx, val); + } else { + emitMaskedStore(base, idx, mask, val); + } + return; + } + + // Detect whether the vector mask is all true + bool unmasked_store = false; + auto* mask_broadcast = dynamic_cast(v->mask()); + if (mask_broadcast) { + auto* broadcast_imm = dynamic_cast(mask_broadcast->value()); + if (broadcast_imm && broadcast_imm->value() == 1) { + unmasked_store = true; + } + } + + // Handle the case where the store is contiguous and unmasked efficiently + auto* idx_ramp = dynamic_cast(v->index()); + if (unmasked_store && idx_ramp) { + auto* stride_imm = dynamic_cast(idx_ramp->stride()); + if (stride_imm && stride_imm->value() == 1) { + auto first_idx = irb_.CreateExtractElement(idx, uint64_t{0}); + auto addr = irb_.CreateGEP(base, first_idx); + auto vaddr = irb_.CreateBitOrPointerCast( + addr, llvm::PointerType::get(val->getType(), 0)); + irb_.CreateAlignedStore(val, vaddr, 4); + return; + } + } + + // Fallback to a scalar implementation + for (int i = 0; i < v->value()->dtype().lanes(); ++i) { + auto sub_idx = irb_.CreateExtractElement(idx, i); + auto sub_val = irb_.CreateExtractElement(val, i); + if (unmasked_store) { + emitUnmaskedStore(base, sub_idx, sub_val); + } else { + auto sub_mask = irb_.CreateExtractElement(mask, i); + emitMaskedStore(base, sub_idx, sub_mask, sub_val); + } + } +} + +void LLVMCodeGen::visit(const Broadcast* v) { + v->value()->accept(this); + int lanes = v->lanes(); + value_ = irb_.CreateVectorSplat(lanes, value_); +} + +void LLVMCodeGen::visit(const IfThenElse* v) { + 
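+  // Lowered as a value-producing diamond: evaluate the condition, branch to
+  // the "then"/"else" blocks, and merge the two results with a PHI node.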
v->condition()->accept(this); + llvm::Value* condition = value_; + llvm::Value* c = + irb_.CreateICmpNE(condition, llvm::ConstantInt::get(IntTy_, 0)); + + auto then_block = llvm::BasicBlock::Create(getContext(), "then", fn_); + auto else_block = llvm::BasicBlock::Create(getContext(), "else", fn_); + auto end_block = llvm::BasicBlock::Create(getContext(), "block", fn_); + irb_.CreateCondBr(c, then_block, else_block); + + irb_.SetInsertPoint(then_block); + v->true_value()->accept(this); + llvm::Value* then_val = value_; + then_block = irb_.GetInsertBlock(); + irb_.CreateBr(end_block); + + irb_.SetInsertPoint(else_block); + v->false_value()->accept(this); + llvm::Value* else_val = value_; + else_block = irb_.GetInsertBlock(); + irb_.CreateBr(end_block); + + irb_.SetInsertPoint(end_block); + llvm::PHINode* phi = irb_.CreatePHI(then_val->getType(), 2); + phi->addIncoming(then_val, then_block); + phi->addIncoming(else_val, else_block); + value_ = phi; +} + +void LLVMCodeGen::visit(const BaseCallNode* v) { + LOG(FATAL) << "Unimplemented: BaseCall"; +} + +static void applyMathFunctionAttributes(llvm::Function* f) { + f->addFnAttr(llvm::Attribute::ReadNone); + f->addFnAttr(llvm::Attribute::NoFree); + f->addFnAttr(llvm::Attribute::NoUnwind); + f->addFnAttr(llvm::Attribute::Speculatable); + f->addFnAttr(llvm::Attribute::WillReturn); +} + +void LLVMCodeGen::visit(const CallExternal* v) { + LOG(FATAL) << "CallExternal needs to be lowered to OpaqueCall"; +} + +void LLVMCodeGen::visit(const Intrinsics* v) { + llvm::FunctionType* call_ty = nullptr; + llvm::Value* call_fn = nullptr; + + switch (v->op_type()) { +#define UNARY_INTRIN_CASE(enum, intrin) \ + case enum: { \ + v->params().front()->accept(this); \ + value_ = irb_.CreateUnaryIntrinsic(intrin, value_); \ + return; \ + } break; + UNARY_INTRIN_CASE(kLog10, llvm::Intrinsic::log10) + UNARY_INTRIN_CASE(kLog, llvm::Intrinsic::log) + UNARY_INTRIN_CASE(kLog2, llvm::Intrinsic::log2) + UNARY_INTRIN_CASE(kExp, llvm::Intrinsic::exp) + UNARY_INTRIN_CASE(kCos, llvm::Intrinsic::cos) + UNARY_INTRIN_CASE(kSin, llvm::Intrinsic::sin) + UNARY_INTRIN_CASE(kSqrt, llvm::Intrinsic::sqrt) + UNARY_INTRIN_CASE(kFabs, llvm::Intrinsic::fabs) + UNARY_INTRIN_CASE(kFloor, llvm::Intrinsic::floor) + UNARY_INTRIN_CASE(kCeil, llvm::Intrinsic::ceil) + UNARY_INTRIN_CASE(kTrunc, llvm::Intrinsic::trunc) + UNARY_INTRIN_CASE(kRound, llvm::Intrinsic::round) +#undef UNARY_INTRIN_CASE + + case kRsqrt: { + v->params().front()->accept(this); + value_ = irb_.CreateUnaryIntrinsic(llvm::Intrinsic::sqrt, value_); + llvm::Value* constant = llvm::ConstantFP::get(FloatTy_, 1.0); + if (v->dtype().lanes() > 1) { + constant = irb_.CreateVectorSplat(v->dtype().lanes(), constant); + } + value_ = irb_.CreateFDiv(constant, value_); + return; + } break; + +#define UNARY_MATH_CASE(enum, name, type) \ + case enum: { \ + auto callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type}, false), {}); \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ + } break; + UNARY_MATH_CASE(kErf, "erff", FloatTy_) + UNARY_MATH_CASE(kErfc, "erfcf", FloatTy_) + UNARY_MATH_CASE(kTan, "tanf", FloatTy_) + UNARY_MATH_CASE(kAcos, "acosf", FloatTy_) + UNARY_MATH_CASE(kAsin, "asinf", FloatTy_) + UNARY_MATH_CASE(kAtan, "atanf", FloatTy_) + UNARY_MATH_CASE(kCosh, "coshf", FloatTy_) + UNARY_MATH_CASE(kSinh, "sinhf", FloatTy_) + UNARY_MATH_CASE(kTanh, "tanhf", FloatTy_) + UNARY_MATH_CASE(kExpm1, "expm1f", FloatTy_) + 
UNARY_MATH_CASE(kLgamma, "lgammaf", FloatTy_) +#undef UNARY_MATH_CASE + +#define BINARY_MATH_CASE(enum, name, type) \ + case enum: { \ + auto callee = module_->getOrInsertFunction( \ + name, llvm::FunctionType::get(type, {type, type}, false), {}); \ + call_ty = callee.getFunctionType(); \ + call_fn = callee.getCallee(); \ + applyMathFunctionAttributes(llvm::cast(call_fn)); \ + } break; + BINARY_MATH_CASE(kRemainder, "remainderf", FloatTy_) + BINARY_MATH_CASE(kAtan2, "atan2f", FloatTy_) + BINARY_MATH_CASE(kPow, "powf", FloatTy_) + BINARY_MATH_CASE(kFmod, "fmodf", FloatTy_) +#undef BINARY_MATH_CASE + + default: { + LOG(FATAL) << "Unimplemented: Intrinsics: " << ExprHandle(v); + } break; + } + + std::vector params; + for (auto& p : v->params()) { + p->accept(this); + params.push_back(value_); + } + + if (v->dtype().lanes() == 1) { + value_ = irb_.CreateCall(call_ty, call_fn, params); + } else { + llvm::Type* vecType = llvm::VectorType::get(FloatTy_, v->dtype().lanes()); + value_ = llvm::UndefValue::get(vecType); + for (int i = 0; i < v->dtype().lanes(); ++i) { + std::vector call_operands; + for (auto p : params) { + call_operands.push_back(irb_.CreateExtractElement(p, i)); + } + + llvm::Value* val = irb_.CreateCall(call_ty, call_fn, call_operands); + value_ = irb_.CreateInsertElement(value_, val, i); + } + } +} + +void LLVMCodeGen::visit(const FunctionCall* v) { + LOG(FATAL) << "Unimplemented: FunctionCall"; +} + +void LLVMCodeGen::visit(const Allocate* v) { + const Var* buffer_var = v->buffer_var(); + std::vector dims = v->dims(); + auto total_byte_size = ExprHandle(IntImm::make(v->dtype().byte_size())); + + for (size_t i = 0; i < dims.size(); i++) { + total_byte_size = total_byte_size * ExprHandle(dims[i]); + } + total_byte_size.node()->accept(this); + auto byte_size = irb_.CreateZExt(value_, LongTy_); + auto f = module_->getOrInsertFunction( + "malloc", + llvm::FunctionType::get( + llvm::PointerType::getUnqual(CharTy_), {LongTy_}, false)); + TORCH_INTERNAL_ASSERT(f); + auto call_ty = f.getFunctionType(); + auto call_fn = f.getCallee(); + value_ = irb_.CreateCall(call_ty, call_fn, {byte_size}); + llvm::Type* loadType = nullptr; + + switch (v->dtype().scalar_type()) { +#define TYPE_CASE(_1, Name) \ + case ScalarType::Name: \ + loadType = Name##Ty_; \ + break; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw std::runtime_error("invalid dtype in Load"); + } + + auto vaddr = + irb_.CreateBitOrPointerCast(value_, llvm::PointerType::get(loadType, 0)); + + varToVal_.emplace(buffer_var, vaddr); + return; +} + +void LLVMCodeGen::visit(const OpaqueCall* v) { + auto nfr = getNativeFunctionRegistry(); + TORCH_CHECK( + nfr.find(v->name()) != nfr.end(), + v->name(), + " never registered with native function registry. 
See tensorexpr/native.h"); + auto sym = jit_->findSymbol(jit_->mangle(v->name())); + + std::vector params; + std::vector types; + + for (auto& p : v->input_handles()) { + p->accept(this); + auto t = value_->getType(); + types.push_back(t); + params.push_back(value_); + } + + for (auto& p : v->arguments()) { + p->accept(this); + auto t = value_->getType(); + types.push_back(t); + params.push_back(value_); + } + + v->output_handle()->accept(this); + params.push_back(value_); + types.push_back(llvm::PointerType::getUnqual(FloatTy_)); + + auto f = module_->getOrInsertFunction( + jit_->mangle(v->name()), llvm::FunctionType::get(VoidTy_, types, false)); + TORCH_INTERNAL_ASSERT(f); + auto call_ty = f.getFunctionType(); + auto call_fn = f.getCallee(); + value_ = irb_.CreateCall(call_ty, call_fn, params); +} + +void LLVMCodeGen::visit(const Free* v) { + const Var* buffer_var = v->buffer_var(); + auto f = module_->getOrInsertFunction( + "free", + llvm::FunctionType::get( + VoidTy_, {llvm::PointerType::getUnqual(CharTy_)}, false)); + TORCH_INTERNAL_ASSERT(f); + auto call_ty = f.getFunctionType(); + auto call_fn = f.getCallee(); + auto addr = varToVal_.at(buffer_var); + addr = + irb_.CreateBitOrPointerCast(addr, llvm::PointerType::getUnqual(CharTy_)); + irb_.CreateCall(call_ty, call_fn, {addr}); + return; +} + +void LLVMCodeGen::visit(const Cond* v) { + LOG(FATAL) << "Unimplemented: Cond"; +} + +void LLVMCodeGen::optimize(llvm::Module& M) { + llvm::legacy::FunctionPassManager FPM(&M); + llvm::legacy::PassManager PM; + + // Add internal analysis passes from the target machine. + PM.add( + llvm::createTargetTransformInfoWrapperPass(TM_->getTargetIRAnalysis())); + FPM.add( + llvm::createTargetTransformInfoWrapperPass(TM_->getTargetIRAnalysis())); + + llvm::PassManagerBuilder PMB; + PMB.OptLevel = 3; + PMB.LoopVectorize = true; + PMB.SLPVectorize = true; + TM_->adjustPassManager(PMB); + + PMB.populateFunctionPassManager(FPM); + PMB.populateModulePassManager(PM); + FPM.doInitialization(); + PM.run(M); + for (auto& FF : M) { + FPM.run(FF); + } + FPM.doFinalization(); + PM.run(M); +} + +RegisterCodeGen reg("llvm_codegen"); + +#endif // ENABLE_LLVM diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.h b/torch/csrc/jit/tensorexpr/llvm_codegen.h new file mode 100644 index 0000000000000..6fa511db310ad --- /dev/null +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.h @@ -0,0 +1,137 @@ +#pragma once + +#ifdef ENABLE_LLVM +#include + +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "torch/csrc/jit/tensorexpr/codegen.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/ir_visitor.h" +#include "torch/csrc/jit/tensorexpr/llvm_jit.h" + +#include +#include +#include +#include +#include + +#define DEBUG_PRINT 0 + +#if DEBUG_PRINT +#include +#endif + +namespace torch { +namespace jit { +namespace tensorexpr { + +class TORCH_API LLVMCodeGen : public CodeGen, public IRVisitor { + private: + llvm::orc::ThreadSafeContext context_; + llvm::IRBuilder<> irb_; + std::unique_ptr TM_; + std::unique_ptr jit_; + std::unique_ptr module_; + llvm::Function* fn_; + llvm::BasicBlock* bb_; + llvm::Value* value_; + llvm::JITTargetAddress kernelAddress_; + + llvm::Type* VoidTy_; +#define LLVM_TYPE_DECLARE(_1, Name) llvm::Type* Name##Ty_; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, LLVM_TYPE_DECLARE); +#undef LLVM_TYPE_DECLARE + + std::unordered_map varToArg_; + std::unordered_map varToVal_; + + std::vector args_; + + private: + llvm::LLVMContext& getContext(); + llvm::Type* dtypeToLLVM(Dtype dtype); + 
llvm::Type* dtypeToLLVMPtr(Dtype dtype); + void emitWrapper(const std::vector& params); + void emitKernel(Stmt* stmt, const std::vector& params); + + public: + explicit LLVMCodeGen( + Stmt* stmt, + const std::vector& args, + Dtype dtype = kInt); + explicit LLVMCodeGen(Stmt* stmt); + + ~LLVMCodeGen() override {} + + TORCH_API void call(const std::vector& args) override; + + void visit(const Add* v) override; + void visit(const Sub* v) override; + void visit(const Mul* v) override; + void visit(const Div* v) override; + void visit(const Mod* v) override; + void visit(const Max* v) override; + void visit(const Min* v) override; + void visit(const And* v) override; + void visit(const Xor* v) override; + void visit(const Lshift* v) override; + void visit(const Rshift* v) override; + void visit(const CompareSelect* v) override; + +#define IMM_VISIT_DECLARE(_1, Name) void visit(const Name##Imm* v) override; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, IMM_VISIT_DECLARE); +#undef IMM_VISIT_DECLARE + + void visit(const Cast* v) override; + void visit(const Var* v) override; + void visit(const Let* v) override; + void visit(const LetStmt* v) override; + void visit(const Ramp* v) override; + void visit(const Load* v) override; + void visit(const For* v) override; + void visit(const Block* v) override; + void visit(const Store* v) override; + void visit(const Broadcast* v) override; + void visit(const IfThenElse* v) override; + void visit(const BaseCallNode* v) override; + void visit(const Intrinsics* v) override; + void visit(const FunctionCall* v) override; + void visit(const CallExternal* v) override; + void visit(const OpaqueCall* v) override; + void visit(const Allocate* v) override; + void visit(const Free* v) override; + void visit(const Cond* v) override; + + llvm::Value* emitUnmaskedLoad(llvm::Value* addr, llvm::Value* idx); + llvm::Value* emitMaskedLoad( + llvm::Value* addr, + llvm::Value* idx, + llvm::Value* mask); + void emitUnmaskedStore(llvm::Value* base, llvm::Value* idx, llvm::Value* val); + void emitMaskedStore( + llvm::Value* base, + llvm::Value* idx, + llvm::Value* mask, + llvm::Value* val); + + void optimize(llvm::Module& M); + + template + T value() { + std::vector args; + return value(args); + } + + template + T value(std::vector& args) { + T (*fp)(void**) = (T(*)(void**))kernelAddress_; + T rv = fp(args.data()); + return rv; + } +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +#endif // ENABLE_LLVM diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.cpp b/torch/csrc/jit/tensorexpr/llvm_jit.cpp new file mode 100644 index 0000000000000..3392c2a47ef1d --- /dev/null +++ b/torch/csrc/jit/tensorexpr/llvm_jit.cpp @@ -0,0 +1,133 @@ +#ifdef ENABLE_LLVM + +#include "torch/csrc/jit/tensorexpr/llvm_jit.h" +#include "torch/csrc/jit/tensorexpr/native.h" + +#include +#include +#include +#include +#include "llvm/ExecutionEngine/Orc/LLJIT.h" + +namespace llvm { +namespace orc { + +// Lightly modified implementation from LLVM's Kaleidoscope JIT tutorial: +// https://llvm.org/docs/tutorial/BuildingAJIT1.html +class TORCH_API PytorchLLVMJITImpl { + private: + std::unique_ptr LLJ; + MangleAndInterner Mangle; + + public: + PytorchLLVMJITImpl() + : LLJ(cantFail(LLJITBuilder().create())), + Mangle(LLJ->getExecutionSession(), LLJ->getDataLayout()) { + auto ProcSymbolsGenerator = + cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( + LLJ->getDataLayout().getGlobalPrefix())); + LLJ->getMainJITDylib().setGenerator(std::move(ProcSymbolsGenerator)); + // Handle 
platform-specific symbol mangling + + for (auto kv : getNativeFunctionRegistry()) { + auto str = kv.first; + auto func = kv.second.first; + cantFail(LLJ->defineAbsolute( + mangle(str), {llvm::pointerToJITTargetAddress(func), {}})); + } + + // Register implementations of intrinsics + cantFail(LLJ->defineAbsolute( + mangle("log10f"), {llvm::pointerToJITTargetAddress(&log10f), {}})); + cantFail(LLJ->defineAbsolute( + mangle("logf"), {llvm::pointerToJITTargetAddress(&logf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("log2f"), {llvm::pointerToJITTargetAddress(&log2f), {}})); + cantFail(LLJ->defineAbsolute( + mangle("expf"), {llvm::pointerToJITTargetAddress(&expf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("erff"), {llvm::pointerToJITTargetAddress(&erff), {}})); + cantFail(LLJ->defineAbsolute( + mangle("cosf"), {llvm::pointerToJITTargetAddress(&cosf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("sinf"), {llvm::pointerToJITTargetAddress(&sinf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("tanf"), {llvm::pointerToJITTargetAddress(&tanf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("acosf"), {llvm::pointerToJITTargetAddress(&acosf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("asinf"), {llvm::pointerToJITTargetAddress(&asinf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("atanf"), {llvm::pointerToJITTargetAddress(&atanf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("coshf"), {llvm::pointerToJITTargetAddress(&coshf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("sinhf"), {llvm::pointerToJITTargetAddress(&sinhf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("tanhf"), {llvm::pointerToJITTargetAddress(&tanhf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("sqrtf"), {llvm::pointerToJITTargetAddress(&sqrtf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("fabsf"), {llvm::pointerToJITTargetAddress(&fabsf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("floorf"), {llvm::pointerToJITTargetAddress(&floorf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("ceilf"), {llvm::pointerToJITTargetAddress(&ceilf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("roundf"), {llvm::pointerToJITTargetAddress(&roundf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("truncf"), {llvm::pointerToJITTargetAddress(&truncf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("atan2f"), {llvm::pointerToJITTargetAddress(&atan2f), {}})); + cantFail(LLJ->defineAbsolute( + mangle("fmodf"), {llvm::pointerToJITTargetAddress(&fmodf), {}})); + cantFail(LLJ->defineAbsolute( + mangle("remainderf"), + {llvm::pointerToJITTargetAddress(&remainderf), {}})); + } + + Error addModule(ThreadSafeModule M) { + if (auto Err = LLJ->addIRModule(std::move(M))) { + return Err; + } + return Error::success(); + } + + JITSymbol findSymbol(const std::string Name) { + return cantFail(LLJ->lookup(Name)); + } + + StringRef mangle(std::string S) { + return *Mangle(S); + } + + const DataLayout& getDataLayout() { + return LLJ->getDataLayout(); + } +}; + +PytorchLLVMJIT::PytorchLLVMJIT() + : impl_(std::make_unique()) {} + +PytorchLLVMJIT::~PytorchLLVMJIT() = default; + +Error PytorchLLVMJIT::addModule(ThreadSafeModule M) { + return impl_->addModule(std::move(M)); +} + +JITSymbol PytorchLLVMJIT::findSymbol(const std::string Name) { + return impl_->findSymbol(std::move(Name)); +} + +StringRef PytorchLLVMJIT::mangle(std::string S) { + return impl_->mangle(S); +} + +const DataLayout& PytorchLLVMJIT::getDataLayout() { + return impl_->getDataLayout(); +} + +} // end namespace orc +} // end namespace llvm + +#endif // ENABLE_LLVM 
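A minimal sketch of how a caller is expected to drive this JIT wrapper; the entry-point name "wrapper", the pre-built llvm::orc::ThreadSafeModule `tsm`, and the use of llvm::cantFail / JITSymbol::getAddress are illustrative assumptions, not taken from this diff:

    // Hypothetical driver, for illustration only.
    llvm::orc::PytorchLLVMJIT jit;
    llvm::cantFail(jit.addModule(std::move(tsm)));           // hand the IR module to LLJIT
    auto sym = jit.findSymbol(jit.mangle("wrapper").str());  // resolve the kernel entry point
    auto addr = llvm::cantFail(sym.getAddress());            // JITTargetAddress of the kernel
    auto fn = reinterpret_cast<int (*)(void**)>(addr);       // signature matching LLVMCodeGen::value<T>()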
diff --git a/torch/csrc/jit/tensorexpr/llvm_jit.h b/torch/csrc/jit/tensorexpr/llvm_jit.h new file mode 100644 index 0000000000000..bc9fae3f49df2 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/llvm_jit.h @@ -0,0 +1,40 @@ +#pragma once + +#ifdef ENABLE_LLVM +#include + +#include "llvm/ExecutionEngine/JITSymbol.h" +#include "llvm/ExecutionEngine/Orc/Core.h" +#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h" +#include "llvm/Target/TargetMachine.h" + +#include +#include + +namespace llvm { +namespace orc { + +class PytorchLLVMJITImpl; + +class TORCH_API PytorchLLVMJIT { + public: + PytorchLLVMJIT(); + ~PytorchLLVMJIT(); + + Error addModule(ThreadSafeModule M); + StringRef mangle(std::string S); + + JITSymbol findSymbol(const std::string Name); + + TargetMachine& getTargetMachine(); + const DataLayout& getDataLayout(); + + private: + // Use PImpl idiom here to hide the no-rtti parts of the JIT structure. + std::unique_ptr impl_; +}; + +} // end namespace orc +} // end namespace llvm + +#endif // ENABLE LLVM diff --git a/torch/csrc/jit/tensorexpr/mem_arena.cpp b/torch/csrc/jit/tensorexpr/mem_arena.cpp new file mode 100644 index 0000000000000..c011c659306a7 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/mem_arena.cpp @@ -0,0 +1,56 @@ +#include "torch/csrc/jit/tensorexpr/mem_arena.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +namespace { +// Define in an anonymous namespace to hide this symbol from other compilation +// units +thread_local KernelArena* current_arena = nullptr; +} + +KernelArena::~KernelArena() { + for (KernelScopedObject* p : kernel_objects_) { + delete p; + } +} + +KernelScopedObject::KernelScopedObject() { + KernelArena* kernel = KernelArena::GetCurrentKernelArena(); + kernel->kernel_objects_.push_back(this); +} + +static std::vector& GetKernelArenaStack() { + thread_local std::vector kernel_arena_stack; + return kernel_arena_stack; +} + +void KernelArena::SetCurrentKernelArena(KernelArena *new_kernel_arena) { + current_arena = new_kernel_arena; +} + +KernelArena* KernelArena::GetCurrentKernelArena() { + return current_arena; +} + +KernelScope::KernelScope() : owning_(true) { + old_kernel_arena_ = KernelArena::GetCurrentKernelArena(); + KernelArena::SetCurrentKernelArena(new KernelArena); +} + +KernelScope::KernelScope(KernelArena* arena_) : owning_(false) { + old_kernel_arena_ = KernelArena::GetCurrentKernelArena(); + KernelArena::SetCurrentKernelArena(arena_); +} + +KernelScope::~KernelScope() { + if (owning_) { + delete KernelArena::GetCurrentKernelArena(); + } + KernelArena::SetCurrentKernelArena(old_kernel_arena_); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/mem_arena.h b/torch/csrc/jit/tensorexpr/mem_arena.h new file mode 100644 index 0000000000000..121bdb60e02ae --- /dev/null +++ b/torch/csrc/jit/tensorexpr/mem_arena.h @@ -0,0 +1,61 @@ +#pragma once +#include +#include "torch/csrc/WindowsTorchApiMacro.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +class KernelScopedObject; + +// An arena that manages all the underlying kernel-scoped objects. 
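+// Objects derived from KernelScopedObject register themselves with the arena
+// that is current at construction time; destroying the arena deletes them all.
+// Typical usage is through KernelScope (a sketch, not taken from this diff):
+//
+//   {
+//     KernelScope kernel_scope;   // pushes a fresh KernelArena
+//     // ... construct kernel IR objects here ...
+//   }                             // arena and all registered objects are freed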
+class KernelArena { + public: + static KernelArena* GetCurrentKernelArena(); + static void SetCurrentKernelArena(KernelArena* new_arena); + TORCH_API KernelArena() {} + TORCH_API ~KernelArena(); + + private: + KernelArena(const KernelArena&) = delete; + KernelArena& operator=(const KernelArena&) = delete; + friend class KernelScopedObject; + std::vector kernel_objects_; // owned +}; + +// A RAII convenience wrapper on top of a kernel. +// It either creates or takes an existing Kernel and sets it as the current +// Kernel. When this object is destroyed, the previous Kernel is set as current, +// and the created kernel is freed. If the kernel was passed, it stays alive. +class KernelScope { + public: + TORCH_API KernelScope(); + TORCH_API explicit KernelScope(KernelArena* arena_); + TORCH_API ~KernelScope(); + + private: + KernelScope(const KernelScope&) = delete; + KernelScope& operator=(const KernelScope&) = delete; + KernelArena* kernel_arena_ = nullptr; // arena to be used in this scope + KernelArena* old_kernel_arena_ = + nullptr; // previous arena, will be restored in destructor + bool owning_ = false; // determines whether the arena will be freed along with + // the scope object +}; + +// The base object managed by the Kernel. +// The object must be created through "new", and when the Kernel is destroyed, +// All its registered objects are destroyed through "delete". +class TORCH_API KernelScopedObject { + public: + KernelScopedObject(); + virtual ~KernelScopedObject() = default; + + private: + KernelScopedObject(const KernelScopedObject&) = delete; + KernelScopedObject& operator=(const KernelScopedObject&) = delete; +}; + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/native.cpp b/torch/csrc/jit/tensorexpr/native.cpp new file mode 100644 index 0000000000000..ecf2110b965b1 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/native.cpp @@ -0,0 +1,54 @@ +#ifdef ENABLE_LLVM + +#include "torch/csrc/jit/tensorexpr/native.h" +#include +#include "ATen/NativeFunctions.h" + +std::unordered_map>& +getNativeFunctionRegistry() { + static std::unordered_map> nfr_; + return nfr_; +} + +void matmul(float* a, float* b, size_t N, size_t M, size_t K, float* c) { + for (auto i = 0; i < N * M; ++i) { + c[i] = 0; + } + + for (auto j = 0; j < N; ++j) { + for (auto i = 0; i < M; ++i) { + for (auto k = 0; k < K; ++k) { + c[j * M + i] += a[j * K + k] * b[k * M + i]; + } + } + } +} + +using namespace torch::jit::tensorexpr; + +static RegisterNativeFunction f( + "aten::matmul", + &matmul, + [](TensorExprKernel* tek, const torch::jit::Value* v) { + return Compute( + "aten_matmul", + texprDims(v), + [tek, v](const std::vector& axes) -> ExprHandle { + const torch::jit::Node* n = v->node(); + TORCH_CHECK(n->inputs().size() == 2); + + tek->addNoInline(n->inputs()[0]->unique()); + tek->addNoInline(n->inputs()[1]->unique()); + // TODO This is totally broken + const Expr* e0 = tek->tensorOrConstant(n->inputs()[0], axes).node(); + auto t0 = tek->getTensor(n->inputs()[0]->unique())->function(); + const Expr* e1 = tek->tensorOrConstant(n->inputs()[1], axes).node(); + auto t1 = tek->getTensor(n->inputs()[1]->unique())->function(); + // N, M, K + std::vector inputs = { + e0, e1, t0->dim(0), t1->dim(1), t0->dim(1)}; + return ExprHandle(CallExternal::make("aten::matmul", inputs)); + }); + }); + +#endif // ENABLE_LLVM diff --git a/torch/csrc/jit/tensorexpr/native.h b/torch/csrc/jit/tensorexpr/native.h new file mode 100644 index 0000000000000..636cbf49de5e7 --- /dev/null 
+++ b/torch/csrc/jit/tensorexpr/native.h @@ -0,0 +1,32 @@ +#ifdef ENABLE_LLVM +#pragma once + +//#include +#include +#include + +namespace torch { +namespace jit { +class Value; +namespace tensorexpr { +class Tensor; +class TensorExprKernel; +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +using TensorCreator = std::function; +std::unordered_map>& +getNativeFunctionRegistry(); + +struct RegisterNativeFunction { + template + RegisterNativeFunction(std::string name, T* fn, TensorCreator cv) { + getNativeFunctionRegistry()[name] = + std::make_pair(reinterpret_cast(fn), cv); + } +}; + +#endif // ENABLE_LLVM diff --git a/torch/csrc/jit/tensorexpr/schedule.cpp b/torch/csrc/jit/tensorexpr/schedule.cpp new file mode 100644 index 0000000000000..7a91264247b1c --- /dev/null +++ b/torch/csrc/jit/tensorexpr/schedule.cpp @@ -0,0 +1,890 @@ +#include "torch/csrc/jit/tensorexpr/schedule.h" + +#include +#include +#include +#include +#include + +#include "torch/csrc/jit/tensorexpr/eval.h" +#include "torch/csrc/jit/tensorexpr/ir_mutator.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace schedule { + +namespace { + +// Evaluates a constant expression and returns its value. +template +static T EvalConstExpr(const ExprHandle& expr) { + ExprEval eval(expr); + return eval.value(); +} + +} // namespace + +ScheduleNode::~ScheduleNode() { + for (ScheduleObject* p : schedule_objects_) { + delete p; + } +} + +class ScheduleNode::DependencyTracker : public IRVisitor { + public: + virtual ~DependencyTracker() = default; + DependencyTracker(const std::vector& output_tensors) { + for (size_t i = 0; i < output_tensors.size(); i++) { + const Tensor* node = output_tensors[i]; + to_process_.push(node); + encountered_.insert(node); + given_tensors_.insert(node); + } + + // Extract all the consumer-producer relationship. + while (!to_process_.empty()) { + Tensor* tensor = const_cast(to_process_.front()); + to_process_.pop(); + current_consumer_ = tensor; + tensor->body()->accept(this); + } + + // Topologically sorted all the tensors in encountered_ + while (!encountered_.empty()) { + sort_tensor_node(*encountered_.begin()); + } + } + + std::vector GetTopologicallySorted() const { + return topologically_sorted_; + } + + bool is_internal(const Tensor* tensor_node) const { + return (given_tensors_.count(tensor_node) == 0); + } + + private: + void visit(const FunctionCall* v) override { + const Tensor* producer = v->tensor(); + add_producer_consumer_pair(current_consumer_, producer); + } + + void add_producer_consumer_pair( + const Tensor* consumer, + const Tensor* producer) { + producers_[consumer].insert(producer); + consumers_[producer].insert(consumer); + if (encountered_.count(producer) == 0) { + encountered_.insert(producer); + to_process_.push(producer); + } + } + + // topoligically sort the sub tensors under the current node + void sort_tensor_node(const Tensor* tensor_node) { + encountered_.erase(tensor_node); + auto iter = producers_.find(tensor_node); + if (iter != producers_.end()) { + for (const Tensor* producer_node : iter->second) { + if (encountered_.count(producer_node) != 0) { + sort_tensor_node(producer_node); + } + } + } + topologically_sorted_.push_back(tensor_node); + } + + std::unordered_map> + producers_; + std::unordered_map> + consumers_; + + // the tensors given in the constructors. They are either the input or the + // output of the entire schedule. 
+ std::unordered_set given_tensors_; + + const Tensor* current_consumer_ = nullptr; + std::unordered_set encountered_; + std::queue to_process_; + std::vector topologically_sorted_; +}; + +ScheduleNode::ScheduleNode(const std::vector& tensors) + : output_tensors_(tensors) { + dependency_tracker_.reset(new DependencyTracker(tensors)); + root_node_ = this->NewTensorExprNode(); + TensorExprNode* current_func = nullptr; + std::vector sorted_tensors = + dependency_tracker_->GetTopologicallySorted(); + for (const Tensor* tensor_node : sorted_tensors) { + Function* func = tensor_node->function(); + if (current_func == nullptr) { + current_func = root_node_->NewFirstChild(); + } else { + current_func = current_func->NewNextSibling(); + } + // TODO: handles the scalar case where ndims == 0 + TensorExprNode* expr_node = current_func; + if (dynamic_cast(tensor_node->body()) == nullptr) { + for (int i = 0; i < func->ndim(); i++) { + expr_node = expr_node->NewFirstChild(); + LoopAxis* loop_axis = this->NewAxis( + VarHandle(func->arg(i)), Range(0, ExprHandle(func->dim(i)))); + expr_node->set_loop_axis(loop_axis); + } + expr_node = expr_node->NewFirstChild(); + } + TensorExprOp* tensor_expr_op = this->NewTensorExprOp(func); + expr_node->set_tensor_expr_op(tensor_expr_op); + + // attach the node to the user provided tensors. + Tensor* tensor_mutable = const_cast(tensor_node); + tensor_mutable->expr_node_ = expr_node; + + if (dependency_tracker_->is_internal(tensor_node)) { + internal_tensors_.push_back(const_cast(tensor_node)); + } + } +} + +void ScheduleNode::ComputeInline(TensorExprNode* expr_node) { + if (!expr_node->is_tensor_expr_op()) { + throw std::runtime_error("expr_node must be tensor_expr_op"); + } + + TensorExprOp* texpr_op = expr_node->tensor_expr_op(); + inlined_functions_.push_back(texpr_op->func()); +} + +void ScheduleNode::GPUExecConfig( + TensorExprNode* expr_node, + const std::vector& blockIdx, + const std::vector& threadIdx) { + // Extract all the ancestors into a var* to loop-axis lookup table + std::unordered_map var_to_loop; + TensorExprNode* node = expr_node; + while (node != nullptr) { + if (node->is_loop_axis()) { + LoopAxis* loop_axis = node->loop_axis(); + const VarHandle& loop_var = loop_axis->var(); + var_to_loop[loop_var.node()] = loop_axis; + } + node = node->parent(); + } + + // Set the blockIndex attr. + for (int i = 0; i < blockIdx.size(); i++) { + auto iter = var_to_loop.find(blockIdx[i].node()); + if (iter == var_to_loop.end()) { + throw std::runtime_error( + "Invalid blockIdx: " + std::to_string(i) + ", " + + blockIdx[i].name_hint()); + } + iter->second->set_gpu_block_index(i); + } + + // Set the threadIdx attr. 
+ for (int i = 0; i < threadIdx.size(); i++) { + auto iter = var_to_loop.find(threadIdx[i].node()); + if (iter == var_to_loop.end()) { + throw std::runtime_error( + "Invalid threadIdx: " + std::to_string(i) + ", " + + threadIdx[i].name_hint()); + } + iter->second->set_gpu_thread_index(i); + } +} + +void ScheduleNode::SplitWithTail( + TensorExprNode* expr_node, + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var, + VarHandle* tail_var, + TensorExprNode** tail_op) { + // find the loop_axis that contains loop_var in the ancestor + TensorExprNode* loop_node = expr_node; + while (loop_node != nullptr) { + if (loop_node->is_loop_axis()) { + LoopAxis* loop_axis = loop_node->loop_axis(); + if (loop_axis->var() == loop_var) { + break; + } + } + loop_node = loop_node->parent(); + } + + if (loop_node == nullptr) { + // TODO: change to a recoverable error. + LOG(FATAL) << "loop var cannot be found in the ancestors of node"; + } + + // create the new loop_axis + SplitAxisWithTail* split_transform = this->NewSplitAxisWithTail( + loop_node->loop_axis(), factor, factor_on_inner); + CHECK(split_transform->output_group_count() >= 1); + CHECK(split_transform->output_group_size(0) == 2); + LoopAxis* outer_axis = split_transform->output(0, 0); + LoopAxis* inner_axis = split_transform->output(0, 1); + LoopAxis* tail_axis = nullptr; + if (split_transform->output_group_count() >= 2) { + tail_axis = split_transform->output(1, 0); + } + + // replace loop_node with the new loop_axis + TensorExprNode* outer_node = this->NewTensorExprNode(); + outer_node->set_loop_axis(outer_axis); + *outer_var = outer_axis->var(); + TensorExprNode* inner_node = outer_node->NewFirstChild(); + inner_node->set_loop_axis(inner_axis); + *inner_var = inner_axis->var(); + TensorExprNode* loop_sibling = loop_node->next_sibling(); + TensorExprNode* loop_child = loop_node->first_child(); + inner_node->SetFirstChild(loop_child); + if (tail_axis != nullptr) { + TensorExprNode* tail_node = outer_node->NewNextSibling(); + tail_node->set_loop_axis(tail_axis); + TensorExprNode* loop_child_clone = nullptr; + { + ScopedCloneMap clone_map_scope(this); + loop_child_clone = CloneObject(loop_child); + CloneMap& clone_map = clone_map_scope.clone_map(); + CloneMap::iterator iter = clone_map.find(expr_node); + if (iter == clone_map.end()) { + LOG(FATAL) << "cannot find node in the clone-map"; + } + TensorExprNode* expr_node_clone = + dynamic_cast(iter->second); + CHECK(!expr_node || expr_node_clone) + << "expr_node is not null, but its clone is"; + *tail_op = expr_node_clone; + DCHECK(expr_node_clone->is_tensor_expr_op()); + expr_node_clone->tensor_expr_op()->ApplyLoopTransform(split_transform, 1); + } + tail_node->SetFirstChild(loop_child_clone); + tail_node->SetNextSibling(loop_sibling); + *tail_var = tail_axis->var(); + } else { + outer_node->SetNextSibling(loop_sibling); + } + CHECK(expr_node->is_tensor_expr_op()); + // This transform is left after the tail axis is cloned, so it doesn't affect + // the tail axis. 
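+  // Output group 0 is the main (non-tail) split; the cloned tail op was
+  // already rewritten against output group 1 above.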
+ expr_node->tensor_expr_op()->ApplyLoopTransform(split_transform, 0); + TensorExprNode::ReplaceSubtree(loop_node, outer_node); +} + +// TODO: Merge with SplitWithTail +void ScheduleNode::SplitWithMask( + TensorExprNode* expr_node, + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var) { + // find the loop_axis that contains loop_var in the ancestor + TensorExprNode* loop_node = expr_node; + while (loop_node != nullptr) { + if (loop_node->is_loop_axis()) { + LoopAxis* loop_axis = loop_node->loop_axis(); + if (loop_axis->var() == loop_var) { + break; + } + } + loop_node = loop_node->parent(); + } + + if (loop_node == nullptr) { + // TODO: change to a recoverable error. + LOG(FATAL) << "loop var cannot be found in the ancestors of node"; + } + + // create the new loop_axis + SplitAxisWithMask* split_transform = this->NewSplitAxisWithMask( + loop_node->loop_axis(), factor, factor_on_inner); + CHECK(split_transform->output_group_count() == 1); + CHECK(split_transform->output_group_size(0) == 2); + LoopAxis* outer_axis = split_transform->output(0, 0); + LoopAxis* inner_axis = split_transform->output(0, 1); + + // replace loop_node with the new loop_axis + TensorExprNode* outer_node = this->NewTensorExprNode(); + outer_node->set_loop_axis(outer_axis); + *outer_var = outer_axis->var(); + TensorExprNode* inner_node = outer_node->NewFirstChild(); + inner_node->set_loop_axis(inner_axis); + *inner_var = inner_axis->var(); + TensorExprNode* loop_sibling = loop_node->next_sibling(); + TensorExprNode* loop_child = loop_node->first_child(); + inner_node->SetFirstChild(loop_child); + outer_node->SetNextSibling(loop_sibling); + + CHECK(expr_node->is_tensor_expr_op()); + expr_node->tensor_expr_op()->AddPredicate( + split_transform->predicate().node()); + expr_node->tensor_expr_op()->ApplyLoopTransform(split_transform, 0); + TensorExprNode::ReplaceSubtree(loop_node, outer_node); +} + +void TensorExprNode::SetParent(TensorExprNode* parent) { + TensorExprNode* n = this; + while (n != nullptr) { + n->parent_ = parent; + n = n->next_sibling(); + } +} + +void TensorExprNode::SetNextSibling(TensorExprNode* node) { + TensorExprNode* old_sibling = this->next_sibling_; + this->next_sibling_ = node; + // reset all the parent links for the siblings + if (node) { + node->SetParent(this->parent()); + } + // detach the parents in the previous next_sibling to prevent dangling + // pointers. + if (old_sibling) { + old_sibling->SetParent(nullptr); + } +} + +void TensorExprNode::SetFirstChild(TensorExprNode* node) { + TensorExprNode* old_child = this->first_child_; + this->first_child_ = node; + // reset all the parent links + if (node) { + node->SetParent(this); + } + if (old_child) { + old_child->SetParent(nullptr); + } +} + +void ScheduleObject::AddClonePair(ScheduleObject* new_obj) { + ScheduleNode* schedule = this->schedule(); + schedule->clone_map().insert(std::make_pair(this, new_obj)); +} + +ScheduleObject* ScheduleNode::CloneScheduleObject(ScheduleObject* object) { + if (object == nullptr) + return nullptr; + + bool map_initialized = false; + if (!clone_map_) { + map_initialized = true; + clone_map_.reset(new CloneMap()); + } + + CloneMap::iterator iter = clone_map_->find(object); + if (iter != clone_map_->end()) { + return iter->second; + } + + ScheduleObject* new_object = object->Clone(); + // TODO: Clone may have inseretd into the map. Only one insertion is needed. 
+ clone_map_->insert(std::make_pair(object, new_object)); + + if (map_initialized) { + clone_map_.reset(); + } + + return new_object; +} + +class Flattener : public IRMutator { + private: + Expr* mutate(const FunctionCall* v) override { + const Tensor* t = v->tensor(); + Buffer buffer( + VarHandle(t->func_var()), + t->body()->dtype(), + ExprVectorToExprHandleVector(t->dims())); + const std::vector& params = v->params(); + std::vector params_expr(params.size()); + for (size_t i = 0; i < params.size(); i++) { + params_expr[i] = ExprHandle(params[i]); + } + return buffer(params_expr).node(); + } +}; + +class FunctionInliner : public IRMutator { + public: + FunctionInliner(const std::vector& funcs) : funcs_(funcs) { + for (Function* func : funcs) { + // TODO: Support multiple-output functions + CHECK(func->func_vars().size() == 1); + func_var_set_.insert(func->func_var(0)); + } + } + + private: + // For the target function, insert the caller/callee pair into the replacement + // mapping. + const Expr* mutate(const FunctionCall* v) override { + Function* func = v->tensor()->function(); + // TODO: Support multiple-output functions + CHECK(func->func_vars().size() == 1); + if (func_var_set_.count(func->func_var(0)) > 0) { + // Insert the caller/callee pair into the mapping. + for (int i = 0; i < func->ndim(); i++) { + const Var* func_callee_arg = dynamic_cast(func->arg(i)); + const Expr* func_caller_param = v->param(i); + auto iter = inline_mapping_.find(func_callee_arg); + if (iter != inline_mapping_.end()) { + throw std::runtime_error( + "Duplicated variables: " + func_callee_arg->name_hint()); + } + inline_mapping_[func_callee_arg] = func_caller_param; + } + + // Call the actual replacement. + const Expr* body = func->body(v->tensor()->output_index()); + const Expr* result = body->accept_mutator(this); + + // Remove the caller/callee relationship. + for (int i = 0; i < func->ndim(); i++) { + const Var* func_callee_arg = dynamic_cast(func->arg(i)); + auto iter = inline_mapping_.find(func_callee_arg); + if (iter == inline_mapping_.end()) { + throw std::runtime_error( + "Var already removed: " + func_callee_arg->name_hint()); + } + inline_mapping_.erase(iter); + } + return result; + } else { + return IRMutator::mutate(v); + } + } + + // Replace the target variable with the caller expressions. + const Expr* mutate(const Var* v) { + auto iter = inline_mapping_.find(v); + if (iter == inline_mapping_.end()) { + return IRMutator::mutate(v); + } else { + const Expr* expr = iter->second; + // Continue to transform the value from the lookup table. + return expr->accept_mutator(this); + } + } + + // Remove the buffer write the inlined function. 
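+  // Stores into an inlined function's buffer are dropped entirely: the value
+  // is substituted directly at each FunctionCall use site instead.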
+ Stmt* mutate(const Store* v) override { + if (func_var_set_.count(v->base_handle()) > 0) { + return nullptr; + } else { + return IRMutator::mutate(v); + } + } + + std::unordered_map inline_mapping_; + std::vector funcs_; + std::unordered_set func_var_set_; +}; + +static Stmt* InjectInlines( + Stmt* stmt, + const std::vector& inlined_funcs) { + FunctionInliner inliner(inlined_funcs); + Stmt* stmt_old = stmt; + Stmt* stmt_new = stmt_old->accept_mutator(&inliner); + return stmt_new; +} + +ScheduleObject* ScheduleNode::LookUpCloneScheduleObject( + ScheduleObject* object) { + if (object == nullptr) { + return nullptr; + } + if (!clone_map_) { + return nullptr; + } + + CloneMap::iterator iter = clone_map_->find(object); + if (iter == clone_map_->end()) { + return nullptr; + } + + return iter->second; +} + +// TODO: change to a stack-based version without recursion +Stmt* ScheduleNode::Lower(TensorExprNode* node) { + if (node == nullptr) { + return nullptr; + } + if (node->next_sibling() != nullptr) { + std::vector siblings; + TensorExprNode* n = node; + while (n != nullptr) { + Stmt* stmt = LowerNoSibling(n); + siblings.push_back(stmt); + n = n->next_sibling(); + } + return Block::make(siblings); + } + return LowerNoSibling(node); +} + +Stmt* ScheduleNode::Lower() { + Stmt* core_stmt = Lower(root_node_); + + // Inject inlines + core_stmt = InjectInlines(core_stmt, inlined_functions_); + + // Flatten function calls. + Flattener flattener; + core_stmt = core_stmt->accept_mutator(&flattener); + + // Add allocs and frees for intermediate buffers at the global level. + // TODO: move allocs and frees to the imemediate areas to reuse buffers. + if (internal_tensors_.size() == 0ULL) { + return core_stmt; + } + + std::unordered_set inlined_func_set; + for (size_t i = 0; i < inlined_functions_.size(); i++) { + inlined_func_set.insert(inlined_functions_[i]); + } + std::unordered_set output_tensors_set; + for (size_t i = 0; i < output_tensors_.size(); i++) { + output_tensors_set.insert(output_tensors_[i]); + } + std::vector allocs; + std::vector frees; + for (size_t i = 0; i < internal_tensors_.size(); i++) { + Tensor* tensor = internal_tensors_[i]; + if (inlined_func_set.count(tensor->function()) > 0) { + // No need to allocation memory for intermediate tensors. + continue; + } + if (output_tensors_set.count(tensor) > 0) { + // No need to allocate memory if the tensors are given as input/output. + continue; + } + Stmt* alloc = new Allocate( + tensor->func_var(), tensor->body()->dtype(), tensor->dims()); + allocs.push_back(alloc); + Stmt* free = new Free(tensor->func_var()); + frees.push_back(free); + } + std::reverse(frees.begin(), frees.end()); + Stmt* alloc_block = Block::make(allocs); + Stmt* free_block = Block::make(frees); + Stmt* combined_stmt = Block::make({alloc_block, core_stmt, free_block}); + return combined_stmt; +} + +Stmt* ScheduleNode::LowerNoSibling(TensorExprNode* node) { + if (node == nullptr) { + return nullptr; + } + if (node->is_empty_value()) { + return Lower(node->first_child()); + } + if (node->is_tensor_expr_op()) { + CHECK(node->first_child() == nullptr); + TensorExprOp* expr_op = node->tensor_expr_op(); + Stmt* stmt = expr_op->ElementStmt(); + // TODO: the predicate should be hoisted to as high as possible in the + // acestor chain. 
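+      // For now each predicate simply wraps the element statement in a Cond
+      // at this level.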
+ const std::vector& predicates = expr_op->predicates(); + for (int i = 0; i < predicates.size(); i++) { + stmt = Cond::make(predicates[i], stmt, nullptr); + } + return stmt; + } else if (node->is_loop_axis()) { + CHECK(node->first_child() != nullptr); + LoopAxis* loop_axis = node->loop_axis(); + Stmt* body = Lower(node->first_child()); + const VarHandle& var = loop_axis->var(); + const Range& range = loop_axis->range(); + Stmt* for_stmt = For::make( + var, range.start(), range.stop(), body, loop_axis->loop_options()); + return for_stmt; + } else if (node->is_empty_value()) { + return Lower(node->first_child()); + } else { + LOG(FATAL) << "Unsupported node type"; + return nullptr; + } +} + +void LoopAxis::CloneFrom(const LoopAxis* other) { + this->loop_var_ = other->loop_var_; + this->loop_range_ = other->loop_range_; + this->axis_type_ = other->axis_type_; + this->is_leaf_ = other->is_leaf_; + this->output_group_index_ = other->output_group_index_; + this->loop_options_ = other->loop_options_; + + this->loop_axis_transform_ = CloneObject(other->loop_axis_transform_); +} + +void LoopAxisTransform::CloneFrom(const LoopAxisTransform* other) { + inputs_.resize(other->inputs_.size()); + outputs_.resize(other->outputs_.size()); + + for (size_t i = 0; i < inputs_.size(); i++) { + inputs_[i] = CloneObject(other->inputs_[i]); + } + for (size_t i = 0; i < outputs_.size(); i++) { + std::vector& output = outputs_[i]; + const std::vector& other_output = other->outputs_[i]; + output.resize(other_output.size()); + for (size_t j = 0; j < other_output.size(); j++) { + output[j] = CloneObject(other_output[j]); + } + } +} + +void SplitAxisTransform::CloneFrom(const SplitAxisTransform* other) { + this->LoopAxisTransform::CloneFrom(other); + this->factor_on_inner_ = other->factor_on_inner_; + this->factor_ = other->factor_; + this->start_ = other->start_; + this->stop_ = other->stop_; +} + +void SplitAxisWithTail::CloneFrom(const SplitAxisWithTail* other) { + this->SplitAxisTransform::CloneFrom(other); +} + +void SplitAxisWithMask::CloneFrom(const SplitAxisWithMask* other) { + this->SplitAxisTransform::CloneFrom(other); +} + +void TensorExprNode::CloneFrom(const TensorExprNode* other) { + this->next_sibling_ = CloneObject(other->next_sibling_); + this->first_child_ = CloneObject(other->first_child_); + this->node_value_.CloneFrom(&other->node_value_); + + // the parent_ link is valid at this point, since it was updated within + // Cloneable when the parent object. If the parent link points outside what + // was cloned so far, it points to NULL. 
+ this->parent_ = LookUpCloneObject(other->parent_); +} + +void TensorExprNode::NodeValue::CloneFrom( + const TensorExprNode::NodeValue* other) { + this->node_type = other->node_type; + if (this->node_type == NodeType::kOperation) { + this->tensor_expr_op = CloneObject(other->tensor_expr_op); + } else if (node_type == NodeType::kAxis) { + this->loop_axis = CloneObject(other->loop_axis); + } else if (node_type == NodeType::kEmptyValue) { + // no actdion taken + } else { + LOG(FATAL) << "Invalid node type: " << static_cast(this->node_type); + } +} + +void TensorExprNode::ReplaceSubtree( + TensorExprNode* old_node, + TensorExprNode* new_node) { + CHECK(old_node->parent() != nullptr) << "cannot replace a root node"; + + TensorExprNode* parent = old_node->parent_; + if (parent->first_child() == old_node) { + parent->SetFirstChild(new_node); + } else { + TensorExprNode* n = parent->first_child(); + while (n != nullptr && n->next_sibling() != new_node) { + n = n->next_sibling(); + } + if (n == nullptr) { + LOG(FATAL) << "Cannot find node as a child of its parent"; + } + n->SetNextSibling(new_node); + } +} + +TensorExprNode* TensorExprNode::NewNextSibling() { + DCHECK(next_sibling_ == nullptr); + TensorExprNode* sibling = schedule()->NewTensorExprNode(); + sibling->parent_ = this->parent_; + this->next_sibling_ = sibling; + return sibling; +} + +TensorExprNode* TensorExprNode::NewFirstChild() { + DCHECK(first_child_ == nullptr); + TensorExprNode* first_child = schedule()->NewTensorExprNode(); + first_child->parent_ = this; + this->first_child_ = first_child; + return first_child; +} + +SplitAxisTransform::SplitAxisTransform( + LoopAxis* loop_axis, + int factor, + bool factor_on_inner) + : BaseClass(std::vector({loop_axis})), + factor_(factor), + factor_on_inner_(factor_on_inner) { + const Range& loop_range = loop_axis->range(); + const ExprHandle& start_expr = loop_range.start(); + const ExprHandle& stop_expr = loop_range.stop(); + + start_ = start_expr; + stop_ = stop_expr; +} + +SplitAxisWithTail::SplitAxisWithTail( + LoopAxis* loop_axis, + int factor, + bool factor_on_inner) + : BaseClass(loop_axis, factor, factor_on_inner) { + // TODO: support factor_on_inner == false; + CHECK(factor_on_inner) << "only factor_on_inner = True is supported for now"; + + auto const& size = this->stop() - this->start(); + int output_group_count = 2; + if (this->stop().AsNode() && this->start().AsNode()) { + int startVal = this->start().AsNode()->value(); + int stopVal = this->stop().AsNode()->value(); + int sizeVal = stopVal - startVal; + int tail_size = sizeVal % factor; + if (tail_size == 0) { + output_group_count = 1; + } + } + auto const& split_count = size / factor; + auto const& tail_size = size % factor; + + this->set_output_group_count(output_group_count); + // The main group + const std::string& loop_var_name = loop_axis->var().name_hint(); + Dtype loop_var_dtype = loop_axis->var().dtype(); + LoopAxis* outer = this->NewAxis( + VarHandle(loop_var_name + "_outer", loop_var_dtype), + Range(0, split_count)); + LoopAxis* inner = this->NewAxis( + VarHandle(loop_var_name + "_inner", loop_var_dtype), Range(0, factor)); + this->set_output_group(0, {outer, inner}); + + // The tail group + if (output_group_count == 2) { + LoopAxis* tail = this->NewAxis( + VarHandle(loop_var_name + "_tail", loop_var_dtype), + Range(0, tail_size)); + this->set_output_group(1, {tail}); + } +} + +// TODO: merge with SplitAxisWithTail +SplitAxisWithMask::SplitAxisWithMask( + LoopAxis* loop_axis, + int factor, + bool factor_on_inner) + : 
BaseClass(loop_axis, factor, factor_on_inner) { + // TODO: support factor_on_inner == false; + CHECK(factor_on_inner) << "only factor_on_inner = True is supported for now"; + + // TODO: Support dynamic shapes + auto const& sizeExpr = this->stop() - this->start(); + bool needsPredicate = true; + if (this->stop().AsNode() && this->start().AsNode()) { + int size = + stop().AsNode()->value() - start().AsNode()->value(); + if ((size % factor) == 0) { + needsPredicate = false; + } + } + if (needsPredicate) { + IntImm* start = this->start().AsNode(); + CHECK(start && start->value() == 0) + << "Non-zero start is not implemented yet"; + predicate_ = CompareSelect::make(loop_axis->var(), this->stop(), kLT); + } + auto const& split_count = (sizeExpr + factor - 1) / factor; + + this->set_output_group_count(1); + const std::string& loop_var_name = loop_axis->var().name_hint(); + Dtype loop_var_dtype = loop_axis->var().dtype(); + LoopAxis* outer = this->NewAxis( + VarHandle(loop_var_name + "_outer", loop_var_dtype), + Range(0, split_count)); + LoopAxis* inner = this->NewAxis( + VarHandle(loop_var_name + "_inner", loop_var_dtype), Range(0, factor)); + this->set_output_group(0, {outer, inner}); +} + +ExprHandle SplitAxisWithTail::combined_loop_index(int output_group) { + LoopAxis* original_axis = this->input(0); + VarHandle original_var = original_axis->var(); + LoopAxis* outer = this->output(0, 0); + LoopAxis* inner = this->output(0, 1); + ExprHandle combined_index; + if (output_group == 0) { + // x -> x.outer * inner.size + x.inner + combined_index = outer->var() * inner->range().stop() + inner->var(); + } else if (output_group == 1) { + LoopAxis* tail = this->output(1, 0); + // x -> x.tail + outer.size * inner.size + combined_index = + tail->var() + outer->range().stop() * inner->range().stop(); + } else { + LOG(FATAL) << "invalid output_group: " << output_group; + } + return combined_index; +} + +Stmt* SplitAxisWithTail::ConvertToNewArgs(Stmt* stmt, int output_group) { + ExprHandle combined_index = combined_loop_index(output_group); + Stmt* new_stmt = Substitute(stmt, {{input(0)->var(), combined_index}}); + return new_stmt; +} + +ExprHandle SplitAxisWithTail::ConvertToNewArgs( + ExprHandle* expr, + int output_group) { + ExprHandle combined_index = combined_loop_index(output_group); + ExprHandle new_expr = Substitute(expr, {{input(0)->var(), combined_index}}); + return new_expr; +} + +ExprHandle SplitAxisWithMask::combined_loop_index(int output_group) { + DCHECK_EQ(output_group, 0) << "Ininvalid output group: " << output_group; + LoopAxis* original_axis = this->input(0); + VarHandle original_var = original_axis->var(); + LoopAxis* outer = this->output(0, 0); + LoopAxis* inner = this->output(0, 1); + ExprHandle combined_index = + outer->var() * inner->range().stop() + inner->var(); + return combined_index; +} + +Stmt* SplitAxisWithMask::ConvertToNewArgs(Stmt* stmt, int output_group) { + ExprHandle combined_index = combined_loop_index(output_group); + Stmt* new_stmt = Substitute(stmt, {{input(0)->var(), combined_index}}); + return new_stmt; +} + +ExprHandle SplitAxisWithMask::ConvertToNewArgs( + ExprHandle* expr, + int output_group) { + ExprHandle combined_index = combined_loop_index(output_group); + ExprHandle new_expr = Substitute(expr, {{input(0)->var(), combined_index}}); + return new_expr; +} + +LoopAxis* LoopAxisTransform::NewAxis( + const VarHandle& loop_var, + const Range& loop_range) { + ScheduleNode* schedule = this->schedule(); + LoopAxis* axis = schedule->NewAxis(loop_var, loop_range); + 
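+  // Record this transform as the producer of the new axis so that clones and
+  // later lowering can trace the axis back to its origin.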
axis->set_loop_axis_transform(this); + return axis; +} + +} // namespace schedule +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/schedule.h b/torch/csrc/jit/tensorexpr/schedule.h new file mode 100644 index 0000000000000..408c0a683e3aa --- /dev/null +++ b/torch/csrc/jit/tensorexpr/schedule.h @@ -0,0 +1,677 @@ +#pragma once + +#include +#include + +#include +#include "torch/csrc/jit/tensorexpr/expr.h" +#include "torch/csrc/jit/tensorexpr/ir.h" +#include "torch/csrc/jit/tensorexpr/tensor.h" + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace schedule { + +// Schedule basics + +// An object owned by a schedule. Objects from subclasses should be created +// through Schedule +// method through "new", and only released with the Schedule is destroyed +// through "delete". +class ScheduleNode; +class ScheduleObject { + public: + ScheduleObject() {} + virtual ~ScheduleObject() {} + ScheduleNode* schedule() { + return schedule_; + } + + protected: + void AddClonePair(ScheduleObject* new_obj); + void set_schedule(ScheduleNode* schedule) { + schedule_ = schedule; + } + + private: + friend class ScheduleNode; + virtual ScheduleObject* Clone() = 0; + ScheduleObject(const ScheduleObject& other) = delete; + const ScheduleObject& operator=(const ScheduleObject& other) = delete; + + ScheduleNode* schedule_ = nullptr; // not owned +}; + +// A CRTP helper class to add Clone support for an object. +template +class Cloneable : public Base { + public: + // Forward the constructor to the underlying Base class + // Note that this does not work for implicit argument conversion. + // All arguments must be an exact match for their Base class counterpart. + template + explicit Cloneable(Args... args) : Base(std::forward(args)...) {} + + Cloneable(Cloneable&& other) = delete; + + private: + // The return type is set to ScheduleObject*. Otherwise, the compiler + // complains about covariant override. + ScheduleObject* Clone() override { + Object* new_object = this->schedule()->template NewObject(); + this->AddClonePair(new_object); + new_object->CloneFrom(static_cast(this)); + return new_object; + } +}; + +/// Loop Axis +class LoopAxisTransform; + +// A loop axis in the Tensor ExprHandle trees. +// Even if two loops are identical in shapes, the should have separate loop +// axis. In other words, loop axes should be be shared among differnt loops. +class TORCH_API LoopAxis : public Cloneable { + public: + enum AxisType { + kRegular, // a regular axis such as appeared in Compute + kReduction, // a redution axis + }; + + const VarHandle& var() const { + return loop_var_; + } + const Range& range() const { + return loop_range_; + } + AxisType axis_type() const { + return axis_type_; + } + const LoopAxisTransform* loop_axis_transform() const { + return loop_axis_transform_; + } + // Whether this axis is a source axis. + bool is_source() const { + return loop_axis_transform_ == nullptr; + } + // Whether this axis is a leaf axis. Only leaf axes can be used in other axis + // transformations. Internal axes are tracked for future computation, but + // logically they disappear from users' perspective. 
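+  // For example (an illustrative sketch, not behavior verified here): after
+  // splitting a leaf axis "x" by a factor of 8, the transform produces leaf
+  // axes "x_outer" and "x_inner", while "x" itself is marked internal and is
+  // no longer available to further transformations.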
+ bool is_leaf() const { + return true; + } + + void CloneFrom(const LoopAxis* other); + + const LoopOptions& loop_options() const { + return loop_options_; + } + + private: + friend class ScheduleNode; + friend class LoopAxisTransform; + + LoopAxis( + const VarHandle& loop_var, + const Range& loop_range, + AxisType axis_type, + LoopAxisTransform* transform) + : loop_var_(loop_var), + loop_range_(loop_range), + axis_type_(axis_type), + loop_axis_transform_(transform) {} + + LoopAxis() {} + + void mark_as_internal() { + is_leaf_ = false; + } + + void set_loop_axis_transform(LoopAxisTransform* transform) { + loop_axis_transform_ = transform; + } + + void set_output_group_index(int output_group_index) { + output_group_index_ = output_group_index; + } + + void set_gpu_block_index(int block_index) { + loop_options_.set_gpu_block_index(block_index); + } + + void set_gpu_thread_index(int thread_index) { + loop_options_.set_gpu_thread_index(thread_index); + } + + VarHandle loop_var_; + Range loop_range_; + AxisType axis_type_; + // TODO: check that only leaf axis can be used in axis tranforms. + bool is_leaf_ = true; + LoopAxisTransform* loop_axis_transform_ = nullptr; + int output_group_index_ = -1; + LoopOptions loop_options_; +}; + +// Loop Axis transformations +// Base class of loop axis transform. A number of input axes were taken, and +// several output groups are generated. Each output group is responsible for +// producing a subset within the input region. Note that each input axis can be +// used in at most one transform. +class TORCH_API LoopAxisTransform + : public Cloneable { + public: + LoopAxisTransform() {} + + // One Stmt for each output group + virtual Stmt* ConvertToNewArgs(Stmt* stmt, int group_index) { + LOG(FATAL) << "unmiplemented"; + return nullptr; + } + + virtual ExprHandle ConvertToNewArgs(ExprHandle* stmt, int group_index) { + LOG(FATAL) << "unmiplemented"; + return ExprHandle(); + } + + int output_group_count() const { + return outputs_.size(); + } + int output_group_size(int group_index) const { + CHECK(group_index >= 0 && group_index < (int)outputs_.size()); + return outputs_[group_index].size(); + } + LoopAxis* output(int group_index, int index) { + CHECK(group_index >= 0 && group_index < (int)outputs_.size()); + std::vector& output_group = outputs_[group_index]; + CHECK(index >= 0 && index < (int)output_group.size()); + return output_group[index]; + } + + int input_size() const { + return inputs_.size(); + } + + LoopAxis* input(int index) { + CHECK(index >= 0 && index < (int)inputs_.size()); + return inputs_[index]; + } + + void CloneFrom(const LoopAxisTransform* other); + + protected: + friend class ScheduleNode; + explicit LoopAxisTransform(const std::vector& inputs) + : inputs_(inputs) { + // TODO: find a better way to set schedule. + if (inputs.size() > 0ULL) { + this->set_schedule(inputs_[0]->schedule()); + } + } + + void set_output_group_count(int group_count) { + outputs_.resize(group_count); + } + + void set_output_group( + int group_index, + const std::vector& outputs) { + CHECK(group_index >= 0 && group_index < (int)outputs_.size()); + outputs_[group_index] = outputs; + for (LoopAxis* output : outputs) { + output->set_output_group_index(group_index); + } + } + + void mark_loop_axis_internal(LoopAxis* axis) { + axis->mark_as_internal(); + } + + // Override Schedule::NewAxis, but also sets current transform as the source. 
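+  // A subclass constructor is expected to follow the pattern used by
+  // SplitAxisWithTail (a sketch, with hypothetical axis names):
+  //   set_output_group_count(2);
+  //   LoopAxis* outer = NewAxis(VarHandle("x_outer", kInt), Range(0, split_count));
+  //   LoopAxis* inner = NewAxis(VarHandle("x_inner", kInt), Range(0, factor));
+  //   set_output_group(0, {outer, inner});
+  // For a concrete split of range [0, 10) with factor 4 this yields
+  // outer in [0, 2) and inner in [0, 4) for group 0 (x = outer * 4 + inner),
+  // plus a tail axis of size 2 for group 1 (x = tail + 8).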
+ LoopAxis* NewAxis(const VarHandle& loop_var, const Range& loop_range); + + private: + std::vector inputs_; // not owned + std::vector> outputs_; // not owened +}; + +// Basic class for the Split Axis transforms. +class TORCH_API SplitAxisTransform + : public Cloneable { + public: + using BaseClass = Cloneable; + void CloneFrom(const SplitAxisTransform* other); + ExprHandle start() { + return start_; + } + ExprHandle stop() { + return stop_; + } + int factor() { + return factor_; + } + bool factor_on_inner() { + return factor_on_inner_; + } + SplitAxisTransform() {} + + protected: + friend class ScheduleNode; + SplitAxisTransform(LoopAxis* loop_axis, int factor, bool factor_on_inner); + + private: + int factor_ = -1; + bool factor_on_inner_ = true; + ExprHandle start_; + ExprHandle stop_; +}; + +class SplitAxisWithTail + : public Cloneable { + public: + using BaseClass = Cloneable; + void CloneFrom(const SplitAxisWithTail* other); + Stmt* ConvertToNewArgs(Stmt* stmt, int output_group) override; + ExprHandle ConvertToNewArgs(ExprHandle* stmt, int output_group) override; + SplitAxisWithTail() {} + + private: + friend class ScheduleNode; + SplitAxisWithTail(LoopAxis* loop_axis, int factor, bool factor_on_inner); + ExprHandle combined_loop_index(int output_group); +}; + +class SplitAxisWithMask + : public Cloneable { + public: + using BaseClass = Cloneable; + void CloneFrom(const SplitAxisWithMask* other); + Stmt* ConvertToNewArgs(Stmt* stmt, int output_group) override; + ExprHandle ConvertToNewArgs(ExprHandle* stmt, int output_group) override; + SplitAxisWithMask() {} + const ExprHandle& predicate() const { + return predicate_; + } + + private: + friend class ScheduleNode; + SplitAxisWithMask(LoopAxis* loop_axis, int factor, bool factor_on_inner); + ExprHandle combined_loop_index(int output_group); + + ExprHandle predicate_; // original predicate +}; + +class FuseAxisTransform; + +// Section: Tensor ExprHandle Tree + +// A tensor expr operation within the expression tree. +// This is often a leaf node that corresponds subset of the operations from a +// user-specified tensor expression. +// This operation, combined with all ancestor axis/nodes in the tree, determines +// the semantics of this operation. 
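+// For example (a sketch): a tensor defined via
+//   Compute("f", {{M, "m"}, {N, "n"}}, ...)
+// is represented by a TensorExprOp leaf whose ElementStmt() stores a single
+// element f(m, n); the loops over "m" and "n" come from the LoopAxis nodes
+// above it in the tree, not from the op itself.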
+class TORCH_API TensorExprOp : public Cloneable { + public: + const Var* expr_var() const { + // TODO: Support multiple-output functions + CHECK(func_->func_vars().size() == 1); + return func_->func_var(0); + } + + const Expr* body() const { + // TODO: Support multiple-output functions + CHECK(func_->func_vars().size() == 1); + return func_->body(0); + } + + Function* func() const { + return func_; + } + + void CloneFrom(const TensorExprOp* other) { + this->func_ = other->func_; + this->element_stmt_ = other->element_stmt_; + this->predicates_ = other->predicates_; + } + + Stmt* ElementStmt() const { + return this->element_stmt_; + } + + void ApplyLoopTransform(LoopAxisTransform* loop_transform, int group_index) { + element_stmt_ = + loop_transform->ConvertToNewArgs(element_stmt_, group_index); + for (int i = 0; i < predicates_.size(); i++) { + predicates_[i] = + loop_transform->ConvertToNewArgs(&predicates_[i], group_index); + } + } + + void AddPredicate(const Expr* predicate) { + if (predicate) { + predicates_.push_back(ExprHandle(predicate)); + } + } + + const std::vector& predicates() const { + return predicates_; + } + + private: + friend class ScheduleNode; + TensorExprOp() {} + explicit TensorExprOp(Function* func) + : func_(func), element_stmt_(func_->ElementStmt(0)) { + // TODO: Support multiple-output functions + CHECK(func_->func_vars().size() == 1); + } + + // TODO: this needs more work. + // The ancestor-axes mark the region to evaluate expression. + // We still need to know the buffer this writes to. + Function* func_; + Stmt* element_stmt_; + std::vector predicates_; +}; + +// Part of the recursive node structure in the tensor expr tree. +// This variable type node could contain one of multiple types that follows: +// * A single loop axis +// * a tensor expr op. +class TORCH_API TensorExprNode + : public Cloneable { + public: + enum NodeType { + // These could show up in the tensor expression trees. + kEmptyValue, // The value in this node is empty, but could have siblings and + // children. + kOperation, // this node records an tensor expr op. + kAxis, // this node records a loop axis + }; + + NodeType node_type() const { + return node_value_.node_type; + } + + bool is_empty_value() const { + return node_value_.node_type == kEmptyValue; + } + bool is_tensor_expr_op() const { + return node_value_.node_type == kOperation; + } + bool is_loop_axis() const { + return node_value_.node_type == kAxis; + } + + TensorExprOp* tensor_expr_op() { + DCHECK(is_tensor_expr_op()); + DCHECK(node_value_.tensor_expr_op != nullptr); + return node_value_.tensor_expr_op; + } + const TensorExprOp* tensor_expr_op() const { + return const_cast(this)->tensor_expr_op(); + } + + LoopAxis* loop_axis() { + DCHECK(is_loop_axis()); + DCHECK(node_value_.loop_axis != nullptr); + return node_value_.loop_axis; + } + const LoopAxis* loop_axis() const { + return const_cast(this)->loop_axis(); + } + + TensorExprNode* parent() { + return parent_; + } + TensorExprNode* first_child() { + return first_child_; + } + TensorExprNode* next_sibling() { + return next_sibling_; + } + + void CloneFrom(const TensorExprNode* other); + + private: + friend class ScheduleNode; + + TensorExprNode() {} + + // Create a new node under the current node. + // Initialize the node list if it is still empty. + // Set the child's parent to this node. 
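+  // For example (a sketch): a nest "for (m) { for (n) { op } }" is stored as
+  //   node(m)->first_child() == node(n), node(n)->first_child() == node(op),
+  // and a second statement inside the "m" loop would hang off
+  // node(n)->next_sibling().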
+ TensorExprNode* NewNextSibling(); + TensorExprNode* NewFirstChild(); + + void SetNextSibling(TensorExprNode* node); + void SetFirstChild(TensorExprNode* node); + // Set the parent of this node, and all its siblings + void SetParent(TensorExprNode* parent); + + // Replace the subtree in "old_node" as the new subtree in "new_node". + // All relevant sibings and parents links in the "new_node" are updated. + // "old_node" might contain dangling pointers. + static void ReplaceSubtree( + TensorExprNode* old_node, + TensorExprNode* new_node); + + void set_tensor_expr_op(TensorExprOp* expr_op) { + DCHECK_EQ(node_value_.node_type, NodeType::kEmptyValue); + node_value_.node_type = kOperation; + node_value_.tensor_expr_op = expr_op; + } + + void set_loop_axis(LoopAxis* loop_axis) { + DCHECK_EQ(node_value_.node_type, NodeType::kEmptyValue); + node_value_.node_type = kAxis; + node_value_.loop_axis = loop_axis; + } + + // A variable-type that unions different value types for this node. + // TODO: isolate this into its own class, so different stage can have + // different value types. + struct NodeValue { + // A variable-type payload with this load. + NodeType node_type = kEmptyValue; + // node_type == kOperation, + TensorExprOp* tensor_expr_op = nullptr; + // node_type_ == kAxis, + LoopAxis* loop_axis = nullptr; + + void CloneFrom(const NodeValue* other); + }; + + // Data structures maintains the tensor expr tree. + TensorExprNode* next_sibling_ = nullptr; // the next sibling of this node + TensorExprNode* first_child_ = nullptr; // the first child of this node + TensorExprNode* parent_ = nullptr; // the parent node of this node + + // Payload multi-type value in this node. + NodeValue node_value_; +}; + +class TORCH_API ScheduleNode : public KernelScopedObject { + public: + // Section: user-facing functionalities. + ~ScheduleNode(); + + // Section: for schedule related internal functions. + LoopAxis* NewAxis(const VarHandle& loop_var, const Range& loop_range) { + return NewObject( + loop_var, loop_range, LoopAxis::kRegular, nullptr); + } + + SplitAxisWithTail* NewSplitAxisWithTail( + LoopAxis* loop_axis, + int factor, + bool factor_on_inner) { + return NewObject(loop_axis, factor, factor_on_inner); + } + + SplitAxisWithMask* NewSplitAxisWithMask( + LoopAxis* loop_axis, + int factor, + bool factor_on_inner) { + return NewObject(loop_axis, factor, factor_on_inner); + } + + TensorExprOp* NewTensorExprOp(Function* func) { + return NewObject(func); + } + + TensorExprNode* NewTensorExprNode() { + return NewObject(); + } + + // Create an object + template + T* NewObject(Args... args) { + T* p = new T(std::forward(args)...); + schedule_objects_.push_back(p); + p->set_schedule(this); + return p; + } + + void SplitWithTail( + TensorExprNode* expr_node, + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var, + VarHandle* tail_var, + TensorExprNode** tail_op); + + void SplitWithMask( + TensorExprNode* expr_node, + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var); + + void ComputeInline(TensorExprNode* expr_node); + + void GPUExecConfig( + TensorExprNode* expr_node, + const std::vector& blockIdx, + const std::vector& threadIdx); + + Stmt* Lower(); + + using CloneMap = std::unordered_map; + CloneMap& clone_map() { + return *clone_map_; + } + + // An RAII object to manage the clone-map for any potential cloning. 
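+  // Presumed usage (a sketch): the outermost clone operation instantiates
+  //   ScopedCloneMap guard(this);
+  // so that nested CloneObject() calls share one old-object -> new-object
+  // map, which is released when the outermost guard goes out of scope.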
+ class ScopedCloneMap { + public: + ScopedCloneMap(ScheduleNode* schedule) : clone_map_(schedule->clone_map_) { + if (clone_map_) { + return; + } + clone_map_.reset(new CloneMap()); + map_initialized_ = true; + } + ~ScopedCloneMap() { + if (!map_initialized_) { + return; + } + clone_map_.reset(); + } + CloneMap& clone_map() { + return *clone_map_; + } + + private: + std::unique_ptr& clone_map_; + bool map_initialized_ = false; + }; + + template + friend Object* LookUpCloneObject(Object* object); + + template + friend Object* CloneObject(Object* object); + + private: + friend class Schedule; + explicit ScheduleNode(const std::vector& funcs); + ScheduleObject* CloneScheduleObject(ScheduleObject* object); + ScheduleObject* LookUpCloneScheduleObject(ScheduleObject* object); + Stmt* Lower(TensorExprNode* node); + Stmt* LowerNoSibling(TensorExprNode* node); + + std::vector output_tensors_; + std::vector internal_tensors_; + std::vector inlined_functions_; + TensorExprNode* root_node_ = nullptr; // not owned + std::vector schedule_objects_; // Owned + // a mapping between old and new objects during the clone process. + // whoever creates this map is responsible for releasing it. + std::unique_ptr clone_map_; + class DependencyTracker; + std::unique_ptr dependency_tracker_; +}; + +template +Object* LookUpCloneObject(Object* object) { + if (object == nullptr) { + return nullptr; + } + ScheduleNode* schedule = object->schedule(); + // TODO: switch to dynamic_cast + return static_cast(schedule->LookUpCloneScheduleObject(object)); +} + +template +Object* CloneObject(Object* object) { + if (object == nullptr) { + return nullptr; + } + ScheduleNode* schedule = object->schedule(); + ScheduleObject* new_object = schedule->CloneScheduleObject(object); + // TODO: switch to dynamic_cast when it becomes available. + return static_cast(new_object); +} + +class TORCH_API Schedule { + public: + static Schedule make(const std::vector& funcs) { + return Schedule(new ScheduleNode(funcs)); + } + + explicit Schedule(const std::vector& funcs) + : node_(new ScheduleNode(funcs)) {} + + Stmt* Lower() { + return node()->Lower(); + } + + Schedule(Schedule&& other) : node_(other.node_) { + other.node_ = nullptr; + } + + private: + // TODO: temporarily disable the copy. We should decide whether the semantics + // of this object. + Schedule(const Schedule&) = delete; + Schedule& operator=(const Schedule&) = delete; + Schedule(ScheduleNode* node) : node_(node) {} + ScheduleNode* node() { + return node_; + } + const ScheduleNode* node() const { + return node_; + } + + ScheduleNode* node_ = nullptr; +}; + +} // namespace schedule +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/stmt.h b/torch/csrc/jit/tensorexpr/stmt.h new file mode 100644 index 0000000000000..3f8c9a0ff194e --- /dev/null +++ b/torch/csrc/jit/tensorexpr/stmt.h @@ -0,0 +1,387 @@ +#pragma once + +#include +#include + +#include "torch/csrc/jit/tensorexpr/expr.h" +namespace torch { +namespace jit { +namespace tensorexpr { + +class Buffer; + +// The common base between all statement node. 
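+// Concrete statements derive from StmtNode<T> (CRTP), which supplies the
+// double-dispatch boilerplate. For example (a sketch), given a visitor v and
+// a Stmt* s that is really a Store, s->accept(&v) resolves to
+// v.visit(static_cast<const Store*>(s)).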
+class Stmt : public KernelScopedObject { + public: + Stmt() {} + TORCH_API virtual void accept(IRVisitor* visitor) const = 0; + virtual Stmt* accept_mutator(IRMutator* mutator) = 0; +}; + +template +class StmtNode : public Stmt { + public: + using StmtNodeBase = StmtNode; + void accept(IRVisitor* visitor) const override { + visitor->visit(static_cast(this)); + } + Stmt* accept_mutator(IRMutator* mutator) override; + StmtNode() {} +}; + +template +Stmt* StmtNode::accept_mutator(IRMutator* mutator) { + StmtNode* this_mutable = const_cast(this); + return mutator->mutate(static_cast(this_mutable)); +} + +// Concrete Stmt classes +class LetStmt : public StmtNode { + public: + const Var* var() const { + return var_; + } + + const Expr* value() const { + return value_; + } + + Stmt* body() const { + return body_; + } + + static Stmt* make(const VarHandle& var, const ExprHandle& value, Stmt* body) { + return new LetStmt(var.node(), value.node(), body); + } + + LetStmt(const Var* var, const Expr* value, Stmt* body) + : var_(var), value_(value), body_(body) {} + + private: + const Var* var_; + const Expr* value_; + Stmt* body_; +}; + +class Block : public StmtNode { + public: + static Stmt* make(const std::vector& stmts) { + std::vector valid_stmts; + for (size_t i = 0; i < stmts.size(); i++) { + if (!stmts[i]) { + continue; + } + valid_stmts.push_back(stmts[i]); + } + if (valid_stmts.empty()) { + return nullptr; + } + return new Block(valid_stmts); + } + int nstmts() const { + return stmts_.size(); + } + Stmt* stmt(int index) const { + return stmts_[index]; + } + + private: + explicit Block(const std::vector& stmts) : stmts_(stmts) {} + std::vector stmts_; +}; + +class TORCH_API Store : public StmtNode { + public: + const Var* base_handle() const { + return base_handle_; + } + const Expr* index() const { + return index_; + } + const Expr* value() const { + return value_; + } + const Expr* mask() const { + return mask_; + } + + static Stmt* make( + const Buffer& buffer, + const ExprHandle& index, + const ExprHandle& value, + const ExprHandle& mask) { + return new Store(buffer, index.node(), value.node(), mask.node()); + } + + static Stmt* make( + const VarHandle& base_handle, + const ExprHandle& index, + const ExprHandle& value, + const ExprHandle& mask) { + return new Store(base_handle.node(), index.node(), value.node(), mask.node()); + } + + static Stmt* make( + const VarHandle& base_handle, + const ExprHandle& index, + const ExprHandle& value) { + return new Store(base_handle.node(), index.node(), value.node(), ExprHandle(1).node()); + } + + // TODO: merge this with Load. + Store( + const Buffer& buffer, + const Expr* index, + const Expr* value, + const Expr* mask); + + Store( + const Var* base_handle, + const Expr* index, + const Expr* value, + const Expr* mask) + : base_handle_(base_handle), index_(index), value_(value), mask_(mask) { + CHECK_EQ(base_handle_->dtype(), kHandle); + CHECK_EQ(index->dtype().lanes(), mask->dtype().lanes()); + CHECK_EQ(index->dtype().lanes(), value->dtype().lanes()); + CHECK_EQ(index->dtype().scalar_type(), ScalarType::Int); + } + private: + + const Var* base_handle_; + const Expr* index_; + const Expr* value_; + const Expr* mask_; +}; + +// Allocate a buffer of given shapes and dtypes and bind it with the given +// buffer var. The life span is at most through the current program, until it is +// explicitly freed. An unfreed memory is likely considered an error. 
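+// A usage sketch (with hypothetical names):
+//   VarHandle buf("buf", kHandle);
+//   Stmt* alloc = Allocate::make(buf, kFloat, {ExprHandle(64)});
+//   ... stores and loads through buf ...
+//   Stmt* free_stmt = Free::make(buf);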
+class Allocate : public StmtNode { + public: + static Stmt* make( + const VarHandle& buffer_var, + Dtype dtype, + const std::vector& dims) { + std::vector dims_nodes(dims.size()); + for (size_t i = 0; i < dims.size(); i++) { + dims_nodes[i] = dims[i].node(); + } + return new Allocate(buffer_var.node(), dtype, dims_nodes); + } + + const Var* buffer_var() const { + return buffer_var_; + } + + Dtype dtype() const { + return dtype_; + } + + const std::vector& dims() const { + return dims_; + } + + Allocate(const Var* buffer_var, Dtype dtype, const std::vector& dims) + : buffer_var_(buffer_var), dtype_(dtype), dims_(dims) {} + + private: + const Var* buffer_var_; + Dtype dtype_; + std::vector dims_; + // TODO: add memory types. +}; + +// Free the specific buffer. It is an error. +class Free : public StmtNode { + public: + static Stmt* make(const VarHandle& buffer_var) { + return new Free(buffer_var.node()); + } + + const Var* buffer_var() const { + return buffer_var_; + } + + Free(const Var* buffer_var) : buffer_var_(buffer_var) {} + + private: + const Var* buffer_var_; +}; + +class Cond : public StmtNode { + public: + static Stmt* make( + const ExprHandle& condition, + Stmt* true_stmt, + Stmt* false_stmt) { + return new Cond(condition.node(), true_stmt, false_stmt); + } + + const Expr* condition() const { + return condition_; + } + + Stmt* true_stmt() const { + return true_stmt_; + } + + Stmt* false_stmt() const { + return false_stmt_; + } + + Cond(const Expr* condition, Stmt* true_stmt, Stmt* false_stmt) + : condition_(condition), true_stmt_(true_stmt), false_stmt_(false_stmt) {} + + private: + const Expr* condition_; + Stmt* true_stmt_; + Stmt* false_stmt_; +}; + +class LoopOptions { + public: + // GPU Block Index + bool is_gpu_block_index() const { + return gpu_block_index_ != -1; + } + + bool gpu_block_index() const { + return gpu_block_index_; + } + + std::string gpu_block_index_str() const { + DCHECK(is_gpu_block_index()); + static const char* kBlockIndexNames[] = { + "blockIdx.x", + "blockIdx.y", + "blockIdx.z", + "blockIdx.w", + }; + DCHECK(gpu_block_index_ >= 0 && gpu_block_index_ < 4); + return kBlockIndexNames[gpu_block_index_]; + } + + void set_gpu_block_index(int index) { + if (is_gpu_thread_index()) { + throw std::runtime_error("Cannot set both gpu block and thread index"); + } + if (is_gpu_block_index() && gpu_block_index() != index) { + throw std::runtime_error( + "Cannot set a previously set block index: " + + std::to_string(gpu_block_index()) + " vs " + std::to_string(index)); + } + gpu_block_index_ = index; + } + + // GPU Thread Index + bool is_gpu_thread_index() const { + return gpu_thread_index() != -1; + } + + int gpu_thread_index() const { + return gpu_thread_index_; + } + + std::string gpu_thread_index_str() const { + DCHECK(is_gpu_thread_index()); + static const char* kThreadIndexNames[] = { + "threadIdx.x", "threadIdx.y", "threadIdx.z", "threadIdx.w"}; + DCHECK(gpu_thread_index_ >= 0 && gpu_thread_index_ < 4); + return kThreadIndexNames[gpu_thread_index_]; + } + + void set_gpu_thread_index(int index) { + if (is_gpu_block_index()) { + throw std::runtime_error("Cannot set both gpu thread and block index"); + } + if (is_gpu_thread_index() && gpu_thread_index() != index) { + throw std::runtime_error( + "Cannot set a previously set thread index: " + + std::to_string(gpu_thread_index()) + " vs " + std::to_string(index)); + } + gpu_thread_index_ = index; + } + + std::string ToString() const { + std::ostringstream oss; + if (is_gpu_block_index()) { + oss << 
gpu_block_index_str(); + } else if (is_gpu_thread_index()) { + oss << gpu_thread_index_str(); + } + return oss.str(); + } + + private: + int gpu_block_index_ = -1; + int gpu_thread_index_ = -1; +}; + +class For : public StmtNode { + public: + const Var* var() const { + return var_; + } + const Expr* start() const { + return start_; + } + const Expr* stop() const { + return stop_; + } + Stmt* body() const { + return body_; + } + static Stmt* make( + const VarHandle& var, + const ExprHandle& start, + const ExprHandle& stop, + Stmt* body) { + if (!body) { + return nullptr; + } + return new For(var.node(), start.node(), stop.node(), body); + } + static Stmt* make( + const VarHandle& var, + const ExprHandle& start, + const ExprHandle& stop, + Stmt* body, + const LoopOptions& loop_options) { + if (!body) { + return nullptr; + } + return new For(var.node(), start.node(), stop.node(), body, loop_options); + } + const LoopOptions loop_options() const { + return loop_options_; + } + + For(const Var* var, const Expr* start, const Expr* stop, Stmt* body) + : var_(var), start_(start), stop_(stop), body_(body) { + CHECK(var && start && stop && body); + } + + For(const Var* var, + const Expr* start, + const Expr* stop, + Stmt* body, + const LoopOptions& loop_options) + : var_(var), + start_(start), + stop_(stop), + body_(body), + loop_options_(loop_options) { + CHECK(var && start && stop && body); + } + + private: + const Var* var_; + const Expr* start_; + const Expr* stop_; + Stmt* body_; + LoopOptions loop_options_; +}; +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp new file mode 100644 index 0000000000000..ae72aac5e81ba --- /dev/null +++ b/torch/csrc/jit/tensorexpr/tensor.cpp @@ -0,0 +1,78 @@ +#include "torch/csrc/jit/tensorexpr/tensor.h" +#include "torch/csrc/jit/tensorexpr/schedule.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +using schedule::TensorExprNode; +// using schedule::ScheduleNode; + +void TensorOperation::SplitWithTail( + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var, + VarHandle* tail_var, + TensorOperation** tail_op) { + check_expr_node(); + schedule::ScheduleNode* schedule = expr_node_->schedule(); + schedule::TensorExprNode* tail_expr_node = nullptr; + schedule->SplitWithTail( + expr_node_, + loop_var, + factor, + factor_on_inner, + outer_var, + inner_var, + tail_var, + &tail_expr_node); + if (!tail_expr_node) { + *tail_op = new TensorOperation(tail_expr_node); + } +} + +void TensorOperation::SplitWithMask( + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var) { + check_expr_node(); + schedule::ScheduleNode* schedule = expr_node_->schedule(); + schedule::TensorExprNode* tail_expr_node = nullptr; + schedule->SplitWithMask( + expr_node_, loop_var, factor, factor_on_inner, outer_var, inner_var); +} + +void TensorOperation::GPUExecConfig( + const std::vector& blockIdx, + const std::vector& threadIdx) { + check_expr_node(); + schedule::ScheduleNode* schedule = expr_node_->schedule(); + schedule->GPUExecConfig(expr_node_, blockIdx, threadIdx); +} + +void TensorOperation::ComputeInline() { + // TODO: find a better way to detect that no schedule might be created for this. + // Even though this operation might be used at the Torch JIT level, it might be + // still be pruned out at the expression level, such as "y = rand_like(x)". 
+ // For now, we tentatively treat as if this tensor is not part of the schedule. + if (expr_node_ == nullptr) { + return; + } + schedule::ScheduleNode* schedule = expr_node_->schedule(); + schedule->ComputeInline(expr_node_); +} + +void TensorOperation::check_expr_node() { + if (expr_node_ == nullptr) { + throw std::runtime_error( + "expr_node in this tensor is null. It is likely that no schedule is attached."); + } +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h new file mode 100644 index 0000000000000..60f7b8415b88a --- /dev/null +++ b/torch/csrc/jit/tensorexpr/tensor.h @@ -0,0 +1,202 @@ +#pragma once + +#include +#include + +#include "torch/csrc/jit/tensorexpr/expr.h" +#include "torch/csrc/jit/tensorexpr/function.h" + +namespace torch { +namespace jit { +namespace tensorexpr { +namespace schedule { +class TensorExprNode; +class ScheduleNode; +} // namespace schedule + +using schedule::TensorExprNode; + +class TORCH_API TensorOperation : public KernelScopedObject { + public: + void SplitWithTail( + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var, + VarHandle* tail_var, + TensorOperation** tail_op); + + void SplitWithMask( + const VarHandle& loop_var, + int factor, + bool factor_on_inner, + VarHandle* outer_var, + VarHandle* inner_var); + + void ComputeInline(); + + void GPUExecConfig( + const std::vector& blockIdx, + const std::vector& threadIdx); + + TensorExprNode* expr_node() { + return expr_node_; + } + + protected: + TensorOperation() {} + explicit TensorOperation(TensorExprNode* expr_node) : expr_node_(expr_node) {} + + private: + void check_expr_node(); + + friend class schedule::ScheduleNode; + TensorExprNode* expr_node_ = nullptr; +}; + +class Tensor : public TensorOperation { + public: + Function* function() const { + return function_; + } + int output_index() const { + return output_index_; + } + + + // Wrappers over accessors to fields of the underlying function + const Expr* body() const { + return function()->body(output_index()); + } + const Var* func_var() const { + return function()->func_var(output_index()); + } + int ndim() const { + return function()->dims().size(); + } + const Expr* dim(int index) const { + return function()->dim(index); + } + const std::vector& dims() const { + return function()->dims(); + } + const Var* arg(int index) const { + return function()->arg(index); + } + const std::vector& args() const { + return function()->args(); + } + + Tensor(Function* function, int output_index) + : function_(function), output_index_(output_index) {} + template + inline ExprHandle operator()(const Ts&... ts); + template + inline ExprHandle call(const std::vector& args); + template + inline ExprHandle call(const Ts&... ts); + + private: + Function* function_; + int output_index_; +}; + +// A helper structure to store the arguments to specify dimensions. In the +// Compute arugments for dim_args, all of the following is supported. For +// example: +// dim_args: {1, 2, 3, 4} +// dim_args: {{1, "x"}, {2, "y"}, {3, "z"}} +// dim_args: {1, 2, {3, "x"}} +class DimArg { + public: + // Intentionally leave out explicit to allow implicit conversions. 
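+  // This is what makes the mixed forms above legal, e.g. (a sketch):
+  //   Compute("f", {{M, "m"}, N}, ...)
+  // where N is an ExprHandle that converts to a DimArg implicitly.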
+ DimArg(const ExprHandle& dim) : dim_(dim) {} + DimArg(const ExprHandle& dim, const std::string& name_hint) + : dim_(dim), name_hint_(name_hint) {} + const ExprHandle& dim() const { + return dim_; + } + const std::string& name_hint() const { + return name_hint_; + } + + private: + ExprHandle dim_; + std::string name_hint_; +}; + +TORCH_API Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function body_func); +TORCH_API Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function body_func); +TORCH_API Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function body_func); +TORCH_API Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function + body_func); +TORCH_API Tensor* Compute( + const std::string& func_name, + const std::vector& dim_args, + std::function&)> body_func); + +class FunctionCall : public CallNode { + public: + using BaseClass = CallNode; + static ExprHandle make(Tensor* tensor, const std::vector& params) { + std::vector params_nodes(params.size()); + for (size_t i = 0; i < params.size(); i++) { + params_nodes[i] = params[i].node(); + } + return ExprHandle(new FunctionCall(tensor, params_nodes)); + } + + const Tensor* tensor() const { + return tensor_; + } + Tensor* tensor() { + return tensor_; + } + + FunctionCall(Tensor* tensor, const std::vector& params) + : BaseClass(tensor->function()->body(tensor->output_index())->dtype(), kFunctionCall, params), + tensor_(tensor) {} + private: + const Expr* DefaultMutator(const std::vector& new_params) const override { + return new FunctionCall(tensor_, new_params); + } + + std::string func_name() const { + return tensor_->func_var()->name_hint(); + } + + Tensor* tensor_; +}; +template +inline ExprHandle Tensor::operator()(const Ts&... ts) { + std::vector params({ExprHandle(ts)...}); + return FunctionCall::make(this, std::move(params)); +} + +template +inline ExprHandle Tensor::call(const Ts&... 
ts) { + std::vector params({ExprHandle(ts)...}); + return FunctionCall::make(this, std::move(params)); +} + +template +inline ExprHandle Tensor::call(const std::vector& args) { + std::vector params(args.begin(), args.end()); + return FunctionCall::make(this, params); +} +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp new file mode 100644 index 0000000000000..d4196c1c54426 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -0,0 +1,157 @@ +#include "torch/csrc/jit/tensorexpr/types.h" +#include + +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +bool is_integral(const ScalarType& type) { + switch (type) { + case ScalarType::Byte: + case ScalarType::Char: + case ScalarType::Short: + case ScalarType::Int: + case ScalarType::Long: + return true; + default: + return false; + } + + return false; +} + +bool is_floating_point(const ScalarType& type) { + switch (type) { + case ScalarType::Half: + case ScalarType::Float: + case ScalarType::Double: + return true; + default: + return false; + } + + return false; +} + +Dtype Dtype::scalar_dtype() const { + return ToDtype(scalar_type_); +} + +#define DTYPE_DEFINE(_1, n) \ + TORCH_API Dtype k##n(ScalarType::n, 1); + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, DTYPE_DEFINE) + +#undef DTYPE_DEFINE + +TORCH_API Dtype kHandle(ScalarType::Handle, 1); +TORCH_API Dtype kUninitialized(ScalarType::Uninitialized, 1); + +Dtype ToDtype(ScalarType type) { + switch (type) { +#define TYPE_CASE(_1, n) \ + case ScalarType::n: \ + return k##n; + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE) +#undef TYPE_CASE + + case ScalarType::Handle: + return kHandle; + case ScalarType::Uninitialized: + return kUninitialized; + default: + LOG(FATAL) << "invalid scalar type: " << type; + return kUninitialized; + } +} + +TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype) { + stream << dtype.scalar_type_; + if (dtype.lanes() > 1) { + stream << "x" << dtype.lanes(); + ; + } + return stream; +} + +TORCH_API std::ostream& operator<<( + std::ostream& stream, const ScalarType& type) { + switch (type) { +#define TYPE_CASE(ttt, Name) \ + case ScalarType::Name: \ + stream << #ttt; \ + break; + + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + + case ScalarType::Undefined: + stream << "Undefined"; + break; + case ScalarType::Handle: + stream << "Handle"; + break; + case ScalarType::Uninitialized: + stream << "Uninitialized"; + break; + case ScalarType::None: + stream << "None"; + break; + default: + LOG(FATAL) << "invalid scalar type: " << (int)type; + } + return stream; +} + +int Dtype::byte_size() const { + int scalar_size = -1; + switch (scalar_type_) { +#define TYPE_CASE(Type, Name) \ + case ScalarType::Name: \ + scalar_size = sizeof(Type); \ + break; + + AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, TYPE_CASE); +#undef TYPE_CASE + default: + throw std::runtime_error( + "invalid scalar type; " + std::to_string(scalar_type_)); + } + return scalar_size * lanes(); +} + +std::string Dtype::ToCppString() const { + switch (scalar_type_) { +#define TYPE_CASE(t, n) \ + case ScalarType::n: \ + return #t; + AT_FORALL_SCALAR_TYPES_AND(Bool, TYPE_CASE); +#undef TYPE_CASE + case ScalarType::Half: + return "half"; + default: + throw std::runtime_error("Invalid dtype: " + std::to_string(scalar_type_)); + } +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +namespace std { + +std::string 
to_string(const Dtype& dtype) { + std::ostringstream oss; + oss << dtype; + return oss.str(); +} + +std::string to_string(const ScalarType& type) { + std::ostringstream oss; + oss << type; + return oss.str(); +} + +} // namespace std diff --git a/torch/csrc/jit/tensorexpr/types.h b/torch/csrc/jit/tensorexpr/types.h new file mode 100644 index 0000000000000..bbd7e19064784 --- /dev/null +++ b/torch/csrc/jit/tensorexpr/types.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +using int32 = std::int32_t; + +class Dtype; +TORCH_API std::ostream& operator<<(std::ostream& stream, const Dtype& dtype); + +// Switch to PT/Aten dtypes +enum class ScalarType : int8_t { +#define DEFINE_ENUM(_1, n) n, + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX_AND_QINTS(DEFINE_ENUM) +#undef DEFINE_ENUM + // Undefined must be next to match c10::ScalarType; + Undefined, + Handle, + Uninitialized, + None, + NumOptions +}; + +TORCH_API std::ostream& operator<<( + std::ostream& stream, const ScalarType& dtype); + +TORCH_API bool is_integral(const ScalarType& type); +TORCH_API bool is_floating_point(const ScalarType& type); + +// Data types for scalar and vector elements. +class TORCH_API Dtype { + public: + explicit Dtype(int8_t type) + : scalar_type_(static_cast(type)), lanes_(1) {} + explicit Dtype(ScalarType type) + : scalar_type_(type), lanes_(1) {} + Dtype(int8_t type, int lanes) + : scalar_type_(static_cast(type)), lanes_(lanes) {} + Dtype(ScalarType type, int lanes) + : scalar_type_(type), lanes_(lanes) {} + Dtype(Dtype type, int lanes) + : scalar_type_(type.scalar_type_), lanes_(lanes) { + CHECK(type.lanes() == 1); + } + int lanes() const { + return lanes_; + } + ScalarType scalar_type() const { return scalar_type_; } + Dtype scalar_dtype() const; + bool operator==(const Dtype& other) const { + return scalar_type_ == other.scalar_type_ && lanes_ == other.lanes_; + } + bool operator!=(const Dtype& other) const { + return !(*this == other); + } + int byte_size() const; + std::string ToCppString() const; + + bool is_integral() const { return tensorexpr::is_integral(scalar_type_); } + bool is_floating_point() const { return tensorexpr::is_floating_point(scalar_type_); } + + private: + friend std::ostream& operator<<(std::ostream& stream, const Dtype& dtype); + ScalarType scalar_type_; + int lanes_; // the width of the element for a vector time +}; + +extern TORCH_API Dtype kUninitialized; +extern TORCH_API Dtype kHandle; + +#define NNC_DTYPE_DECLARATION(ctype,name) \ + extern TORCH_API Dtype k##name; + +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, NNC_DTYPE_DECLARATION) +#undef NNC_DTYPE_DECLARATION + +template +TORCH_API Dtype ToDtype(); + +#define NNC_TODTYPE_DECLARATION(ctype,name) \ + template <> \ + inline Dtype ToDtype() { \ + return k##name; \ + } +AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, NNC_TODTYPE_DECLARATION) +#undef NNC_TODTYPE_DECLARATION + +TORCH_API Dtype ToDtype(ScalarType type); + +// Call c10 type promotion directly. 
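+// For example, promoteTypes(ScalarType::Int, ScalarType::Float) follows the
+// c10 rules and yields ScalarType::Float, so BinaryOpDtype(kInt, kFloat)
+// below returns kFloat for scalar (lanes == 1) operands.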
+inline ScalarType promoteTypes(ScalarType a, ScalarType b) { + return static_cast(c10::promoteTypes( + static_cast(a), static_cast(b))); +} +inline ScalarType promoteTypes(Dtype a, Dtype b) { + return static_cast(c10::promoteTypes( + static_cast(a.scalar_type()), + static_cast(b.scalar_type()))); +} + +inline Dtype BinaryOpDtype( + Dtype op1_dtype, + Dtype op2_dtype, + ScalarType ret_type = ScalarType::None) { + if (op1_dtype == op2_dtype) { + if (ret_type == ScalarType::None) { + return op1_dtype; + } + + return ToDtype(ret_type); + } + + CHECK_EQ(op1_dtype.lanes(), op2_dtype.lanes()) << "vector lengths must match"; + int lanes = op1_dtype.lanes(); + + ScalarType resultType = promoteTypes(op1_dtype, op2_dtype); + CHECK_NE(resultType, ScalarType::Undefined) + << "Invalid dtypes: " << op1_dtype << ", " << op2_dtype; + + if (lanes == 1) { + // Use the fixed scalar Dtypes. + return ToDtype(resultType); + } + + return Dtype(resultType, lanes); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch + +namespace std { + +using torch::jit::tensorexpr::Dtype; +std::string to_string(const Dtype& dtype); +using torch::jit::tensorexpr::ScalarType; +std::string to_string(const ScalarType& dtype); + +} // namespace std diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.cpp b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp new file mode 100644 index 0000000000000..d7f333eed5b5a --- /dev/null +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.cpp @@ -0,0 +1,48 @@ +#include "torch/csrc/jit/tensorexpr/unique_name_manager.h" + +#include +#include "torch/csrc/jit/tensorexpr/ir.h" + +namespace torch { +namespace jit { +namespace tensorexpr { + +const std::string& UniqueNameManager::get_unique_name(const Var* v) { + // Find if we have already encountered this variable. + auto iter = unique_name_mapping_.find(v); + if (iter != unique_name_mapping_.end()) { + return iter->second; + } + + // First use the name_hint as a prefix to check if there is another name + // with the same prefix. + std::string name_hint = v->name_hint(); + if (name_hint == "") { + name_hint = "v"; + } else if (std::isdigit(name_hint[0])) { + name_hint = "v" + name_hint; + } + int& count = unique_name_count_[name_hint]; + while (true) { + // Even if with a new count, this name might already be used. For example + // ("x", 1) could collidewith ("x_1", 0) + int count_v = count++; + std::string unique_name = name_hint; + if (count_v > 0) { + unique_name += "_" + std::to_string(count_v); + } + if (all_unique_names_.count(unique_name) == 0) { + all_unique_names_.insert(unique_name); + auto result = unique_name_mapping_.insert(std::make_pair(v, unique_name)); + return result.first->second; + } + } +} + +const std::string& UniqueNameManager::get_unique_name(const VarHandle& v) { + return get_unique_name(v.node()); +} + +} // namespace tensorexpr +} // namespace jit +} // namespace torch diff --git a/torch/csrc/jit/tensorexpr/unique_name_manager.h b/torch/csrc/jit/tensorexpr/unique_name_manager.h new file mode 100644 index 0000000000000..6bb669e57ba5f --- /dev/null +++ b/torch/csrc/jit/tensorexpr/unique_name_manager.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +#include + +namespace torch { +namespace jit { +namespace tensorexpr { + +class VarHandle; +class Var; + +using VarNameMap = std::unordered_map; + +// A manager to get unique names from vars. +// It starts with the name hints of the var and append "_" + $counter until it +// hits a unique name. 
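+// For example: three Vars hinted "x" become "x", "x_1", "x_2"; an empty hint
+// becomes "v", and a hint starting with a digit such as "1" becomes "v1"
+// (see unique_name_manager.cpp above).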
+class TORCH_API UniqueNameManager {
+ public:
+  const std::string& get_unique_name(const VarHandle& v);
+
+  const std::string& get_unique_name(const Var* v);
+
+ private:
+  friend class ScopedVarName;
+  VarNameMap unique_name_mapping_;
+  std::unordered_map<std::string, int> unique_name_count_;
+  std::unordered_set<std::string> all_unique_names_;
+};
+
+} // namespace tensorexpr
+} // namespace jit
+} // namespace torch