import torch
import sys
from torch import jit
from torch import Tensor
from typing import Tuple


# Expand fake_uq along the leading (batch) dimension so it matches y and can
# be returned as the second element of the output tuple.
def to_tupple(y: Tensor, fake_uq: Tensor) -> Tuple[Tensor, Tensor]:
    outer_dim = y.shape[0]
    fake_uq_dim = fake_uq.shape[0]
    tmp = fake_uq.clone().detach()
    # Repeat count large enough to cover outer_dim rows (always >= ceil(outer_dim / fake_uq_dim))
    additional_dims = torch.div(outer_dim, fake_uq_dim, rounding_mode="floor") + outer_dim % fake_uq_dim
    final_shape = (additional_dims * fake_uq_dim, *fake_uq.shape[1:])
    tmp = tmp.unsqueeze(0)
    my_list = [1] * len(fake_uq.shape)
    new_dims = (additional_dims, *my_list)
    tmp = tmp.repeat(new_dims)
    tmp = tmp.reshape(final_shape)
    # Trim the tiled tensor back down to the batch size of y
    std = tmp[: y.shape[0], ...]
    return y, std


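# Minimal linear model whose forward pass returns a (prediction, uncertainty)
# tuple: the weights and bias are zeroed so the prediction is always zero, and
# the "uncertainty" is the fixed fake_uq tensor tiled to the batch size.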
class TuppleModel(torch.nn.Module):
    def __init__(self, inputSize, outputSize, fake_uq):
        super().__init__()
        self.linear = torch.nn.Linear(inputSize, outputSize)
        self.linear.weight.data.fill_(0.0)
        self.linear.bias.data.fill_(0.0)
        self.fake_uq = torch.nn.Parameter(fake_uq, requires_grad=False)

    def forward(self, x):
        y = self.linear(x)
        return to_tupple(y, self.fake_uq)


def main(args):
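    # Expected positional arguments (sys.argv):
    #   1: input dimension, 2: output dimension, 3: device ("cpu" or "cuda"),
    #   4: uq type ("mean" or "max"),
    #   5: precision ("double" for float64, anything else for float32),
    #   6: output file name for the traced TorchScript model.
    # Hypothetical example invocation (the script name is illustrative only):
    #   python make_tuple_model.py 4 3 cpu mean double tuple_model_cpu.pt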
    inputDim = int(args[1])
    outputDim = int(args[2])
    device = args[3]
    uq_type = args[4]
    precision = args[5]
    output_name = args[6]
    enable_cuda = True
    if device == "cuda":
        enable_cuda = True
        suffix = "_gpu"
    elif device == "cpu":
        enable_cuda = False
        suffix = "_cpu"
    prec = torch.float32
    if precision == "double":
        prec = torch.double

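    # Build a fixed 2 x outputDim "uncertainty" tensor. For "mean", row 0 is
    # kept below 0.5 and row 1 pushed to 0.5 or above; for "max", everything is
    # scaled below 0.49 except a single element that is set to 0.51.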
    fake_uq = torch.rand(2, outputDim, dtype=prec)
    if uq_type == "mean":
        # Row 0 (every other sample once tiled to the batch) stays below 0.5
        fake_uq[0, ...] *= 0.5
        # Row 1 is shifted to 0.5 or above
        fake_uq[1, ...] = 0.5 + 0.5 * (fake_uq[1, ...])
    elif uq_type == "max":
        max_val = torch.max(fake_uq, dim=1).values
        scale = 0.49 / max_val
        fake_uq *= scale.unsqueeze(0).T
        fake_uq[0, outputDim // 2] = 0.51
    else:
        print("Unknown uq type")
        sys.exit(1)
    if precision == "double":
        model = TuppleModel(inputDim, outputDim, fake_uq).double()
    else:
        model = TuppleModel(inputDim, outputDim, fake_uq)

    if torch.cuda.is_available() and enable_cuda:
        model = model.cuda()

    model.eval()

    data = torch.randn(1023, inputDim, dtype=prec)

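    # Trace the model with a single 1-D example input and save it as a
    # TorchScript module under output_name.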
    with torch.jit.optimized_execution(True):
        traced = torch.jit.trace(model, (torch.randn(inputDim, dtype=prec).to(device),))
        traced.save(f"{output_name}")

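    # Reload the saved module and run it on a small zero batch as a sanity check.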
    data = torch.zeros(2, inputDim, dtype=prec)
    inputs = data.to(device)
    model = jit.load(f"{output_name}")
    model.eval()
    with torch.no_grad():
        print("Output", model(inputs))


if __name__ == "__main__":
    main(sys.argv)