diff --git a/helpers/sh_functions.py b/helpers/sh_functions.py
new file mode 100644
index 0000000..21b637f
--- /dev/null
+++ b/helpers/sh_functions.py
@@ -0,0 +1,110 @@
+import math
+import numpy as np
+import time
+import torch
+import torchvision
+
+# Convolve the SH coefficients with a low-pass filter in the spatial domain
+def deringing(coeffs, window):
+    deringed_coeffs = torch.zeros_like(coeffs)
+    deringed_coeffs[:, 0] += coeffs[:, 0]
+    deringed_coeffs[:, 1:1 + 3] += \
+        coeffs[:, 1:1 + 3] * math.pow(math.sin(math.pi * 1.0 / window) / (math.pi * 1.0 / window), 4.0)
+    deringed_coeffs[:, 4:4 + 5] += \
+        coeffs[:, 4:4 + 5] * math.pow(math.sin(math.pi * 2.0 / window) / (math.pi * 2.0 / window), 4.0)
+    return deringed_coeffs
+
+# Spherical harmonics functions
+# Associated Legendre polynomial P_l^m(x)
+def P(l, m, x):
+    pmm = 1.0
+    if(m>0):
+        somx2 = np.sqrt((1.0-x)*(1.0+x))
+        fact = 1.0
+        for i in range(1,m+1):
+            pmm *= (-fact) * somx2
+            fact += 2.0
+
+    if(l==m):
+        return pmm * np.ones(np.shape(x))
+
+    pmmp1 = x * (2.0*m+1.0) * pmm
+
+    if(l==m+1):
+        return pmmp1
+
+    pll = np.zeros(np.shape(x))
+    for ll in range(m+2, l+1):
+        pll = ( (2.0*ll-1.0)*x*pmmp1-(ll+m-1.0)*pmm ) / (ll-m)
+        pmm = pmmp1
+        pmmp1 = pll
+
+    return pll
+
+def factorial(x):
+    if(x == 0):
+        return 1.0
+    return x * factorial(x-1)
+
+# Normalization constant for the SH basis function of degree l, order m
+def K(l, m):
+    return np.sqrt( ((2 * l + 1) * factorial(l-m)) / (4*np.pi*factorial(l+m)) )
+
+# Real spherical harmonic basis function Y_l^m(theta, phi)
+def SH(l, m, theta, phi):
+    sqrt2 = np.sqrt(2.0)
+    if(m==0):
+        if np.isscalar(phi):
+            return K(l,m)*P(l,m,np.cos(theta))
+        else:
+            return K(l,m)*P(l,m,np.cos(theta))*np.ones(phi.shape)
+    elif(m>0):
+        return sqrt2*K(l,m)*np.cos(m*phi)*P(l,m,np.cos(theta))
+    else:
+        return sqrt2*K(l,-m)*np.sin(-m*phi)*P(l,-m,np.cos(theta))
+
+# Evaluate all SH basis functions up to degree lmax on the given (theta, phi) grid
+def shEvaluate(theta, phi, lmax):
+    if np.isscalar(theta):
+        coeffsMatrix = np.zeros((1,1,shTerms(lmax)))
+    else:
+        coeffsMatrix = np.zeros((theta.shape[0],phi.shape[0],shTerms(lmax)))
+
+    for l in range(0,lmax+1):
+        for m in range(-l,l+1):
+            index = shIndex(l, m)
+            coeffsMatrix[:,:,index] = SH(l, m, theta, phi)
+    return coeffsMatrix
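+
+# Usage sketch (illustrative, not part of the original pipeline): for lmax = 2 the
+# basis has shTerms(2) == 9 functions, so on an equirectangular grid
+#   theta = np.linspace(0, np.pi, 256).reshape(256, 1)
+#   phi = np.linspace(0, 2 * np.pi, 512)
+#   basis = shEvaluate(theta, phi, 2)   # -> shape (256, 512, 9)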
+
+# Build the SH basis matrix for an equirectangular panorama of width xres
+def getCoeeficientsMatrix(xres,lmax=2):
+    yres = int(xres/2)
+    # setup fast vectorisation
+    x = np.arange(0,xres)
+    y = np.arange(0,yres).reshape(yres,1)
+
+    # Setup polar coordinates
+    latLon = xy2ll(x,y,xres,yres)
+
+    # Compute spherical harmonics (spherical coordinates follow the EXR convention)
+    Ylm = shEvaluate(latLon[0], latLon[1], lmax)
+    return Ylm
+
+# Reconstruct a (height x width x 3) signal from the 9 x 3 SH coefficients
+def shReconstructSignal(coeffs, sh_basis_matrix=None, width=512):
+    if sh_basis_matrix is None:
+        lmax = sh_lmax_from_terms(coeffs.shape[0])
+        sh_basis_matrix = getCoeeficientsMatrix(width,lmax)
+    sh_basis_matrix_t = torch.from_numpy(sh_basis_matrix).to(coeffs).float()
+    return (torch.matmul(sh_basis_matrix_t,coeffs)).to(coeffs).float()
+
+def shTerms(lmax):
+    return (lmax + 1) * (lmax + 1)
+
+def sh_lmax_from_terms(terms):
+    return int(np.sqrt(terms)-1)
+
+def shIndex(l, m):
+    return l*l+l+m
+
+# Map pixel coordinates to latitude/longitude for an equirectangular image
+def xy2ll(x,y,width,height):
+    def yLocToLat(yLoc, height):
+        return (yLoc / (float(height)/np.pi))
+    def xLocToLon(xLoc, width):
+        return (xLoc / (float(width)/(np.pi * 2)))
+    # returned as a (lat, lon) pair; the two arrays have different shapes
+    return yLocToLat(y, height), xLocToLon(x, width)
+
diff --git a/images/input.jpg b/images/input.jpg
new file mode 100644
index 0000000..911952e
Binary files /dev/null and b/images/input.jpg differ
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..1ca982d
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,74 @@
+import argparse
+import os
+import cv2
+import sys
+import time
+import numpy as np
+import torch
+import torchvision
+from helpers.sh_functions import *
+from loaders.Illum_loader import IlluminationModule, Inference_Data
+from loaders.autoenc_ldr2hdr import LDR2HDR
+from torch.utils.data import DataLoader
+
+def parse_arguments(args):
+    usage_text = (
+        "Inference script for Deep Lighting Environment Map Estimation from Spherical Panoramas."
+        "\nUsage: python3 inference.py --input_path "
+    )
+    parser = argparse.ArgumentParser(description=usage_text)
+    parser.add_argument('--input_path', type=str, default='./images/input.jpg', help="Input panorama color image file")
+    parser.add_argument('--out_path', type=str, default='./output/', help='Output folder for the predicted environment map panorama')
+    parser.add_argument('-g','--gpu', type=str, default='0', help='GPU id of the device to use. Use -1 for CPU.')
+    parser.add_argument("--chkpnt_path", default='./models/model.pth', type=str, help='Pre-trained checkpoint file for the lighting regression module')
+    parser.add_argument('--ldr2hdr_model', type=str, default='./models/ldr2hdr.pth', help='Pre-trained checkpoint file for the LDR2HDR image translation module')
+    parser.add_argument("--width", type=int, default=512, help="Spherical panorama image width.")
+    parser.add_argument('--deringing', type=int, default=0, help='Enable low-pass deringing filter for the predicted SH coefficients')
+    parser.add_argument('--dr_window', type=float, default=6.0, help='Deringing low-pass filter window size')
+    return parser.parse_known_args(args)
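+
+# Example invocation (using the defaults above; adjust paths as needed):
+#   python3 inference.py --input_path ./images/input.jpg --out_path ./output/ --gpu 0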
+
+def evaluate(
+    illumination_module: torch.nn.Module,
+    ldr2hdr_module: torch.nn.Module,
+    args: argparse.Namespace,
+    device: torch.device
+):
+    if not os.path.isdir(args.out_path):
+        os.mkdir(args.out_path)
+
+    in_filename, in_file_extension = os.path.splitext(args.input_path)
+    assert in_file_extension in ['.png','.jpg']
+    inference_data = Inference_Data(args.input_path)
+    out_path = args.out_path + os.path.basename(args.input_path)
+    out_filename, out_file_extension = os.path.splitext(out_path)
+    out_file_extension = '.exr'
+    out_path = out_filename + out_file_extension
+    dataloader = DataLoader(inference_data, batch_size=1, shuffle=False, num_workers=1)
+    for i, data in enumerate(dataloader):
+        input_img = data.to(device).float()
+        with torch.no_grad():
+            start_time = time.time()
+            right_rgb = ldr2hdr_module(input_img)
+            p_coeffs = illumination_module(right_rgb).view(1,9,3).to(device).float()
+            if args.deringing:
+                p_coeffs = deringing(p_coeffs, args.dr_window).to(device).float()
+            elapsed_time = time.time() - start_time
+            print("Elapsed inference time: %2.4fsec" % elapsed_time)
+            pred_env_map = shReconstructSignal(p_coeffs.squeeze(0), width=args.width)
+            cv2.imwrite(out_path, pred_env_map.cpu().detach().numpy())
+
+def main(args):
+    device = torch.device("cuda:" + str(args.gpu) if (torch.cuda.is_available() and int(args.gpu) >= 0) else "cpu")
+    # load lighting module
+    illumination_module = IlluminationModule(batch_size=1).to(device)
+    illumination_module.load_state_dict(torch.load(args.chkpnt_path, map_location=device))
+    illumination_module.eval()
+    print("Lighting module loaded")
+    # load LDR2HDR module
+    ldr2hdr_module = LDR2HDR()
+    ldr2hdr_module.load_state_dict(torch.load(args.ldr2hdr_model, map_location=device)['state_dict_G'])
+    ldr2hdr_module = ldr2hdr_module.to(device)
+    ldr2hdr_module.eval()
+    print("LDR2HDR module loaded")
+    evaluate(illumination_module, ldr2hdr_module, args, device)
+
+if __name__ == '__main__':
+    args, unknown = parse_arguments(sys.argv)
+    main(args)
\ No newline at end of file
diff --git a/loaders/Illum_loader.py b/loaders/Illum_loader.py
new file mode 100644
index 0000000..1df41c4
--- /dev/null
+++ b/loaders/Illum_loader.py
@@ -0,0 +1,61 @@
+from skimage import io, transform
+import numpy as np
+import cv2
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from torchvision import transforms, utils
+'''
+    Input (256,512,3)
+'''
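+# Architecture note (added for clarity): with a (B, 3, 256, 512) panorama as input, the
+# five stride-2 convolutions below reduce the feature map to (B, 256, 8, 16); the fully
+# connected layer maps it to 2048 features and the regression head to 27 values,
+# i.e. 9 SH coefficients x 3 color channels.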
+class IlluminationModule(nn.Module):
+    def __init__(self, batch_size):
+        super().__init__()
+        self.cv_block1 = conv_bn_elu(3, 64, kernel_size=7, stride=2)
+        self.cv_block2 = conv_bn_elu(64, 128, kernel_size=5, stride=2)
+        self.cv_block3 = conv_bn_elu(128, 256, stride=2)
+        self.cv_block4 = conv_bn_elu(256, 256)
+        self.cv_block5 = conv_bn_elu(256, 256, stride=2)
+        self.cv_block6 = conv_bn_elu(256, 256)
+        self.cv_block7 = conv_bn_elu(256, 256, stride=2)
+        self.fc = nn.Linear(256*16*8, 2048)
+        '''One head regression'''
+        self.sh_fc = nn.Linear(2048, 27)
+
+    def forward(self, x):
+        x = self.cv_block1(x)
+        x = self.cv_block2(x)
+        x = self.cv_block3(x)
+        x = self.cv_block4(x)
+        x = self.cv_block5(x)
+        x = self.cv_block6(x)
+        x = self.cv_block7(x)
+        x = x.view(-1, 256*8*16)
+        x = F.elu(self.fc(x))
+        return((self.sh_fc(x)))
+
+def conv_bn_elu(in_, out_, kernel_size=3, stride=1, padding=True):
+    # conv layer with ELU activation function (note: despite the name, no batch norm is applied)
+    pad = int(kernel_size/2)
+    if padding is False:
+        pad = 0
+    return nn.Sequential(
+        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=pad),
+        nn.ELU(),
+    )
+
+class Inference_Data(Dataset):
+    def __init__(self, img_path):
+        self.input_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+        self.input_img = cv2.resize(self.input_img, (512,256), interpolation=cv2.INTER_CUBIC)
+        self.to_tensor = transforms.ToTensor()
+        self.data_len = 1
+
+    def __getitem__(self, index):
+        self.tensor_img = self.to_tensor(self.input_img)
+        return self.tensor_img
+
+    def __len__(self):
+        return self.data_len
\ No newline at end of file
diff --git a/loaders/autoenc_ldr2hdr.py b/loaders/autoenc_ldr2hdr.py
new file mode 100644
index 0000000..97cfd7a
--- /dev/null
+++ b/loaders/autoenc_ldr2hdr.py
@@ -0,0 +1,80 @@
+# Autoencoder for LDR to HDR image mapping
+
+from torch import nn
+import torch
+from torchvision import models
+import torchvision
+from torch.nn import functional as F
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find('BatchNorm') != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+class LDR2HDR(nn.Module):
+    def __init__(self,
+        n_filters: int=64,
+        n_channel_input: int=3,
+        n_channel_output: int=3
+    ):
+        super(LDR2HDR, self).__init__()
+        self.conv1 = nn.Conv2d(n_channel_input, n_filters, 4, 2, 1)
+        self.conv2 = nn.Conv2d(n_filters, n_filters * 2, 4, 2, 1)
+        self.conv3 = nn.Conv2d(n_filters * 2, n_filters * 4, 4, 2, 1)
+        self.conv4 = nn.Conv2d(n_filters * 4, n_filters * 8, 4, 2, 1)
+        self.conv5 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
+        self.conv6 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
+        self.conv7 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
+        self.conv8 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
+
+        self.deconv1 = nn.ConvTranspose2d(n_filters * 8, n_filters * 8, 4, 2, 1)
+        self.deconv2 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 8, 4, 2, 1)
+        self.deconv3 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 8, 4, 2, 1)
+        self.deconv4 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 8, 4, 2, 1)
+        self.deconv5 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 4, 4, 2, 1)
+        self.deconv6 = nn.ConvTranspose2d(n_filters * 4 * 2, n_filters * 2, 4, 2, 1)
+        self.deconv7 = nn.ConvTranspose2d(n_filters * 2 * 2, n_filters, 4, 2, 1)
+        self.deconv8 = nn.ConvTranspose2d(n_filters * 2, n_channel_output, 4, 2, 1)
+
+        self.batch_norm = nn.BatchNorm2d(n_filters)
+        self.batch_norm2 = nn.BatchNorm2d(n_filters * 2)
+        self.batch_norm4 = nn.BatchNorm2d(n_filters * 4)
+        self.batch_norm8 = nn.BatchNorm2d(n_filters * 8)
+
+        self.leaky_relu = nn.LeakyReLU(0.2, True)
+        self.relu = nn.ReLU(True)
+
+        self.dropout = nn.Dropout(0.5)
+
+        self.tanh = nn.Tanh()
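+
+    # Note (added for clarity): this is a U-Net-style encoder-decoder in the spirit of
+    # pix2pix. For a (B, 3, 256, 512) input, the eight stride-2 convolutions shrink the
+    # feature map to (B, 512, 1, 2); each decoder stage is concatenated with the matching
+    # encoder feature map (skip connection) before the next transposed convolution.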
+    def forward(self, input):
+        # Encoder path
+        encoder1 = self.conv1(input)
+        encoder2 = self.batch_norm2(self.conv2(self.leaky_relu(encoder1)))
+        encoder3 = self.batch_norm4(self.conv3(self.leaky_relu(encoder2)))
+        encoder4 = self.batch_norm8(self.conv4(self.leaky_relu(encoder3)))
+        encoder5 = self.batch_norm8(self.conv5(self.leaky_relu(encoder4)))
+        encoder6 = self.batch_norm8(self.conv6(self.leaky_relu(encoder5)))
+        encoder7 = self.batch_norm8(self.conv7(self.leaky_relu(encoder6)))
+        encoder8 = self.conv8(self.leaky_relu(encoder7))
+
+        # Decoder path with skip connections to the corresponding encoder outputs
+        decoder1 = self.dropout(self.batch_norm8(self.deconv1(self.relu(encoder8))))
+        decoder1 = torch.cat((decoder1, encoder7), 1)
+        decoder2 = self.dropout(self.batch_norm8(self.deconv2(self.relu(decoder1))))
+        decoder2 = torch.cat((decoder2, encoder6), 1)
+        decoder3 = self.dropout(self.batch_norm8(self.deconv3(self.relu(decoder2))))
+        decoder3 = torch.cat((decoder3, encoder5), 1)
+        decoder4 = self.batch_norm8(self.deconv4(self.relu(decoder3)))
+        decoder4 = torch.cat((decoder4, encoder4), 1)
+        decoder5 = self.batch_norm4(self.deconv5(self.relu(decoder4)))
+        decoder5 = torch.cat((decoder5, encoder3), 1)
+        decoder6 = self.batch_norm2(self.deconv6(self.relu(decoder5)))
+        decoder6 = torch.cat((decoder6, encoder2), 1)
+        decoder7 = self.batch_norm(self.deconv7(self.relu(decoder6)))
+        decoder7 = torch.cat((decoder7, encoder1), 1)
+        decoder8 = self.deconv8(self.relu(decoder7))
+        output = self.tanh(decoder8)
+        return output
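+
+# Minimal usage sketch (illustrative; shapes assume a 256x512 panorama input):
+#   net = LDR2HDR()
+#   net.eval()
+#   with torch.no_grad():
+#       hdr = net(torch.rand(1, 3, 256, 512))   # -> (1, 3, 256, 512), tanh-bounded output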