First Commit

gkitsasv committed May 13, 2020
1 parent cd244bf commit 81ff6fa
Showing 5 changed files with 325 additions and 0 deletions.
110 changes: 110 additions & 0 deletions helpers/sh_functions.py
@@ -0,0 +1,110 @@
import math
import numpy as np
import torch

# Convolve the SH coefficients with a low-pass (sinc^4) window in the spatial
# domain to suppress ringing artifacts. Band 0 is passed through unattenuated.
def deringing(coeffs, window):
    deringed_coeffs = torch.zeros_like(coeffs)
    deringed_coeffs[:, 0] += coeffs[:, 0]
    deringed_coeffs[:, 1:1 + 3] += \
        coeffs[:, 1:1 + 3] * math.pow(math.sin(math.pi * 1.0 / window) / (math.pi * 1.0 / window), 4.0)
    deringed_coeffs[:, 4:4 + 5] += \
        coeffs[:, 4:4 + 5] * math.pow(math.sin(math.pi * 2.0 / window) / (math.pi * 2.0 / window), 4.0)
    return deringed_coeffs
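
A minimal usage sketch: the (1, 9, 3) coefficient layout and the window value of 6.0 mirror inference.py below, and the random tensor stands in for predicted coefficients.

coeffs = torch.randn(1, 9, 3)          # batch of one: 9 SH terms x RGB
smoothed = deringing(coeffs, 6.0)      # 6.0 matches the --dr_window default
assert smoothed.shape == coeffs.shape  # band 0 passes through unattenuated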

# Spherical harmonics functions. P is the associated Legendre polynomial,
# evaluated with the standard recurrence.
def P(l, m, x):
    pmm = 1.0
    if m > 0:
        somx2 = np.sqrt((1.0 - x) * (1.0 + x))
        fact = 1.0
        for i in range(1, m + 1):
            pmm *= (-fact) * somx2
            fact += 2.0

    if l == m:
        return pmm * np.ones(x.shape)

    pmmp1 = x * (2.0 * m + 1.0) * pmm

    if l == m + 1:
        return pmmp1

    pll = np.zeros(x.shape)
    for ll in range(m + 2, l + 1):
        pll = ((2.0 * ll - 1.0) * x * pmmp1 - (ll + m - 1.0) * pmm) / (ll - m)
        pmm = pmmp1
        pmmp1 = pll

    return pll

# Recursive factorial; inputs here are small SH band indices.
def factorial(x):
    if x == 0:
        return 1.0
    return x * factorial(x - 1)

# SH normalization constant.
def K(l, m):
    return np.sqrt(((2 * l + 1) * factorial(l - m)) / (4 * np.pi * factorial(l + m)))

# Real spherical harmonic basis function of band l and order m.
def SH(l, m, theta, phi):
    sqrt2 = np.sqrt(2.0)
    if m == 0:
        if np.isscalar(phi):
            return K(l, m) * P(l, m, np.cos(theta))
        else:
            return K(l, m) * P(l, m, np.cos(theta)) * np.ones(phi.shape)
    elif m > 0:
        return sqrt2 * K(l, m) * np.cos(m * phi) * P(l, m, np.cos(theta))
    else:
        return sqrt2 * K(l, -m) * np.sin(-m * phi) * P(l, -m, np.cos(theta))

def shEvaluate(theta, phi, lmax):
    if np.isscalar(theta):
        coeffsMatrix = np.zeros((1, 1, shTerms(lmax)))
    else:
        coeffsMatrix = np.zeros((theta.shape[0], phi.shape[0], shTerms(lmax)))

    for l in range(0, lmax + 1):
        for m in range(-l, l + 1):
            index = shIndex(l, m)
            coeffsMatrix[:, :, index] = SH(l, m, theta, phi)
    return coeffsMatrix
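
A small shape sketch on a toy 4x8 latitude/longitude grid: for lmax = 2 the basis matrix stacks shTerms(2) = 9 channels.

theta = np.linspace(0, np.pi, 4).reshape(4, 1)  # latitudes as a column
phi = np.linspace(0, 2 * np.pi, 8)              # longitudes as a row
basis = shEvaluate(theta, phi, 2)
assert basis.shape == (4, 8, 9)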

def getCoefficientsMatrix(xres, lmax=2):
    yres = int(xres / 2)
    # Set up fast vectorisation
    x = np.arange(0, xres)
    y = np.arange(0, yres).reshape(yres, 1)

    # Set up polar coordinates
    latLon = xy2ll(x, y, xres, yres)

    # Compute spherical harmonics. Apply thetaOffset due to EXR spherical coordinates
    Ylm = shEvaluate(latLon[0], latLon[1], lmax)
    return Ylm

# Reconstruct a panoramic signal from its SH coefficients.
def shReconstructSignal(coeffs, sh_basis_matrix=None, width=512):
    if sh_basis_matrix is None:
        lmax = sh_lmax_from_terms(coeffs.shape[0])
        sh_basis_matrix = getCoefficientsMatrix(width, lmax)
    sh_basis_matrix_t = torch.from_numpy(sh_basis_matrix).to(coeffs).float()
    return torch.matmul(sh_basis_matrix_t, coeffs).to(coeffs).float()
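
A minimal reconstruction sketch, with random coefficients standing in for a network prediction: 9 terms imply lmax = 2, and the basis is built on the fly.

coeffs = torch.randn(9, 3)                        # one value per SH term and color channel
env_map = shReconstructSignal(coeffs, width=512)  # -> (256, 512, 3) torch tensor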

def shTerms(lmax):
    return (lmax + 1) * (lmax + 1)

def sh_lmax_from_terms(terms):
    return int(np.sqrt(terms) - 1)

def shIndex(l, m):
    return l * l + l + m
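
The flat layout these helpers imply (and which the 1:4 and 4:9 slices in deringing rely on), spelled out for lmax = 2:

assert shTerms(2) == 9                              # bands 0..2 -> 1 + 3 + 5 terms
assert shIndex(0, 0) == 0                           # band 0 occupies index 0
assert shIndex(1, -1) == 1 and shIndex(1, 1) == 3   # band 1: indices 1..3
assert shIndex(2, -2) == 4 and shIndex(2, 2) == 8   # band 2: indices 4..8
assert sh_lmax_from_terms(9) == 2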

# Map pixel coordinates of an equirectangular image to (latitude, longitude).
# A plain tuple is returned because the two grids have different shapes and
# np.asarray over ragged arrays is deprecated in recent numpy.
def xy2ll(x, y, width, height):
    def yLocToLat(yLoc, height):
        return yLoc / (float(height) / np.pi)
    def xLocToLon(xLoc, width):
        return xLoc / (float(width) / (np.pi * 2))
    return (yLocToLat(y, height), xLocToLon(x, width))
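
A quick check of the mapping: rows span latitude [0, pi) top to bottom and columns span longitude [0, 2*pi) left to right, so for a 512x256 panorama:

lat, lon = xy2ll(np.array([0, 256]), np.array([0, 128]), 512, 256)
# lat == [0, pi/2], lon == [0, pi]: the image centre maps to the equator
# at half the longitudinal range.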

Binary file added images/input.jpg
74 changes: 74 additions & 0 deletions inference.py
@@ -0,0 +1,74 @@
import argparse
import os
import cv2
import sys
import time
import numpy as np
import torch
from helpers.sh_functions import *
from loaders.Illum_loader import IlluminationModule, Inference_Data
from loaders.autoenc_ldr2hdr import LDR2HDR
from torch.utils.data import DataLoader

def parse_arguments(args):
    usage_text = (
        "Inference script for Deep Lighting Environment Map Estimation from Spherical Panoramas. "
        "Usage: python3 inference.py --input_path <input_image>"
    )
    parser = argparse.ArgumentParser(description=usage_text)
    parser.add_argument('--input_path', type=str, default='./images/input.jpg', help='Input panorama color image file')
    parser.add_argument('--out_path', type=str, default='./output/', help='Output folder for the predicted environment map panorama')
    parser.add_argument('-g', '--gpu', type=str, default='0', help='GPU id of the device to use. Use -1 for CPU.')
    parser.add_argument('--chkpnt_path', type=str, default='./models/model.pth', help='Pre-trained checkpoint file for the lighting regression module')
    parser.add_argument('--ldr2hdr_model', type=str, default='./models/ldr2hdr.pth', help='Pre-trained checkpoint file for the LDR2HDR image translation module')
    parser.add_argument('--width', type=int, default=512, help='Spherical panorama image width.')
    parser.add_argument('--deringing', type=int, default=0, help='Enable the low-pass deringing filter for the predicted SH coefficients')
    parser.add_argument('--dr_window', type=float, default=6.0, help='Window width of the deringing filter')
    return parser.parse_known_args(args)

def evaluate(
    illumination_module: torch.nn.Module,
    ldr2hdr_module: torch.nn.Module,
    args: argparse.Namespace,
    device: torch.device
):
    if not os.path.isdir(args.out_path):
        os.makedirs(args.out_path)

    in_filename, in_file_extension = os.path.splitext(args.input_path)
    assert in_file_extension in ['.png', '.jpg']
    inference_data = Inference_Data(args.input_path)
    # Write the prediction next to the input's basename, as an EXR file
    out_path = os.path.join(args.out_path, os.path.basename(args.input_path))
    out_filename, out_file_extension = os.path.splitext(out_path)
    out_path = out_filename + '.exr'
    dataloader = DataLoader(inference_data, batch_size=1, shuffle=False, num_workers=1)
    for i, data in enumerate(dataloader):
        input_img = data.to(device).float()
        with torch.no_grad():
            start_time = time.time()
            right_rgb = ldr2hdr_module(input_img)
            p_coeffs = illumination_module(right_rgb).view(1, 9, 3).to(device).float()
            if args.deringing:
                p_coeffs = deringing(p_coeffs, args.dr_window).to(device).float()
            elapsed_time = time.time() - start_time
            print("Elapsed inference time: %2.4fsec" % elapsed_time)
        pred_env_map = shReconstructSignal(p_coeffs.squeeze(0), width=args.width)
        cv2.imwrite(out_path, pred_env_map.cpu().detach().numpy())

def main(args):
    device = torch.device("cuda:" + str(args.gpu) if (torch.cuda.is_available() and int(args.gpu) >= 0) else "cpu")
    # Load the lighting regression module
    illumination_module = IlluminationModule(batch_size=1).to(device)
    illumination_module.load_state_dict(torch.load(args.chkpnt_path, map_location=device))
    print("Lighting module loaded")
    # Load the LDR2HDR module
    ldr2hdr_module = LDR2HDR()
    ldr2hdr_module.load_state_dict(torch.load(args.ldr2hdr_model, map_location=device)['state_dict_G'])
    ldr2hdr_module = ldr2hdr_module.to(device)
    print("LDR2HDR module loaded")
    evaluate(illumination_module, ldr2hdr_module, args, device)

if __name__ == '__main__':
    args, unknown = parse_arguments(sys.argv)
    main(args)
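
A typical invocation, assuming the two pre-trained checkpoints sit at their default ./models/ paths:

python3 inference.py --input_path ./images/input.jpg --out_path ./output/ --gpu 0 --deringing 1

This writes the predicted environment map to ./output/input.exr.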
61 changes: 61 additions & 0 deletions loaders/Illum_loader.py
@@ -0,0 +1,61 @@
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision import transforms
'''
Input: (256, 512, 3) panorama; output: 27 SH coefficients (9 per RGB channel)
'''
class IlluminationModule(nn.Module):
    def __init__(self, batch_size):
        super().__init__()
        self.cv_block1 = conv_bn_elu(3, 64, kernel_size=7, stride=2)
        self.cv_block2 = conv_bn_elu(64, 128, kernel_size=5, stride=2)
        self.cv_block3 = conv_bn_elu(128, 256, stride=2)
        self.cv_block4 = conv_bn_elu(256, 256)
        self.cv_block5 = conv_bn_elu(256, 256, stride=2)
        self.cv_block6 = conv_bn_elu(256, 256)
        self.cv_block7 = conv_bn_elu(256, 256, stride=2)
        self.fc = nn.Linear(256 * 16 * 8, 2048)
        '''One-head regression'''
        self.sh_fc = nn.Linear(2048, 27)

    def forward(self, x):
        x = self.cv_block1(x)
        x = self.cv_block2(x)
        x = self.cv_block3(x)
        x = self.cv_block4(x)
        x = self.cv_block5(x)
        x = self.cv_block6(x)
        x = self.cv_block7(x)
        x = x.view(-1, 256 * 8 * 16)
        x = F.elu(self.fc(x))
        return self.sh_fc(x)
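
A shape sketch (a sanity check, not part of the repository): seven conv blocks reduce a 256x512 panorama to an 8x16 grid of 256 channels, which the two linear layers regress to 27 values, i.e. 9 SH coefficients per RGB channel.

module = IlluminationModule(batch_size=1)
out = module(torch.randn(1, 3, 256, 512))
assert out.shape == (1, 27)  # reshaped to (1, 9, 3) in inference.py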

def conv_bn_elu(in_, out_, kernel_size=3, stride=1, padding=True):
    # Convolution followed by an ELU activation (despite the name, no batch
    # norm is applied); padding=True keeps the 'same' spatial size at stride 1
    pad = kernel_size // 2
    if padding is False:
        pad = 0
    return nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=pad),
        nn.ELU(),
    )

class Inference_Data(Dataset):
    def __init__(self, img_path):
        self.input_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        self.input_img = cv2.resize(self.input_img, (512, 256), interpolation=cv2.INTER_CUBIC)
        self.to_tensor = transforms.ToTensor()
        self.data_len = 1

    def __getitem__(self, index):
        self.tensor_img = self.to_tensor(self.input_img)
        return self.tensor_img

    def __len__(self):
        return self.data_len
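
A minimal usage sketch (DataLoader comes from torch.utils.data; the file itself only imports Dataset):

from torch.utils.data import DataLoader

data = Inference_Data('./images/input.jpg')
loader = DataLoader(data, batch_size=1, shuffle=False)
img = next(iter(loader))  # (1, 3, 256, 512) float tensor in [0, 1], BGR channel order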
80 changes: 80 additions & 0 deletions loaders/autoenc_ldr2hdr.py
@@ -0,0 +1,80 @@
# Autoencoder for LDR to HDR image mapping

import torch
from torch import nn

def weights_init(m):
    # DCGAN-style initialization for conv and batch-norm layers
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class LDR2HDR(nn.Module):
    def __init__(self,
                 n_filters: int = 64,
                 n_channel_input: int = 3,
                 n_channel_output: int = 3
                 ):
        super(LDR2HDR, self).__init__()
        # 8-level encoder: each 4x4, stride-2 convolution halves the resolution
        self.conv1 = nn.Conv2d(n_channel_input, n_filters, 4, 2, 1)
        self.conv2 = nn.Conv2d(n_filters, n_filters * 2, 4, 2, 1)
        self.conv3 = nn.Conv2d(n_filters * 2, n_filters * 4, 4, 2, 1)
        self.conv4 = nn.Conv2d(n_filters * 4, n_filters * 8, 4, 2, 1)
        self.conv5 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
        self.conv6 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
        self.conv7 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)
        self.conv8 = nn.Conv2d(n_filters * 8, n_filters * 8, 4, 2, 1)

        # 8-level decoder: transposed convolutions double the resolution; the
        # doubled input channels account for the U-Net skip connections
        self.deconv1 = nn.ConvTranspose2d(n_filters * 8, n_filters * 8, 4, 2, 1)
        self.deconv2 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 8, 4, 2, 1)
        self.deconv3 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 8, 4, 2, 1)
        self.deconv4 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 8, 4, 2, 1)
        self.deconv5 = nn.ConvTranspose2d(n_filters * 8 * 2, n_filters * 4, 4, 2, 1)
        self.deconv6 = nn.ConvTranspose2d(n_filters * 4 * 2, n_filters * 2, 4, 2, 1)
        self.deconv7 = nn.ConvTranspose2d(n_filters * 2 * 2, n_filters, 4, 2, 1)
        self.deconv8 = nn.ConvTranspose2d(n_filters * 2, n_channel_output, 4, 2, 1)

        self.batch_norm = nn.BatchNorm2d(n_filters)
        self.batch_norm2 = nn.BatchNorm2d(n_filters * 2)
        self.batch_norm4 = nn.BatchNorm2d(n_filters * 4)
        self.batch_norm8 = nn.BatchNorm2d(n_filters * 8)

        self.leaky_relu = nn.LeakyReLU(0.2, True)
        self.relu = nn.ReLU(True)

        self.dropout = nn.Dropout(0.5)

        self.tanh = nn.Tanh()

    def forward(self, input):
        encoder1 = self.conv1(input)
        encoder2 = self.batch_norm2(self.conv2(self.leaky_relu(encoder1)))
        encoder3 = self.batch_norm4(self.conv3(self.leaky_relu(encoder2)))
        encoder4 = self.batch_norm8(self.conv4(self.leaky_relu(encoder3)))
        encoder5 = self.batch_norm8(self.conv5(self.leaky_relu(encoder4)))
        encoder6 = self.batch_norm8(self.conv6(self.leaky_relu(encoder5)))
        encoder7 = self.batch_norm8(self.conv7(self.leaky_relu(encoder6)))
        encoder8 = self.conv8(self.leaky_relu(encoder7))

        # Decoder with skip connections: each level concatenates the mirrored
        # encoder features along the channel dimension
        decoder1 = self.dropout(self.batch_norm8(self.deconv1(self.relu(encoder8))))
        decoder1 = torch.cat((decoder1, encoder7), 1)
        decoder2 = self.dropout(self.batch_norm8(self.deconv2(self.relu(decoder1))))
        decoder2 = torch.cat((decoder2, encoder6), 1)
        decoder3 = self.dropout(self.batch_norm8(self.deconv3(self.relu(decoder2))))
        decoder3 = torch.cat((decoder3, encoder5), 1)
        decoder4 = self.batch_norm8(self.deconv4(self.relu(decoder3)))
        decoder4 = torch.cat((decoder4, encoder4), 1)
        decoder5 = self.batch_norm4(self.deconv5(self.relu(decoder4)))
        decoder5 = torch.cat((decoder5, encoder3), 1)
        decoder6 = self.batch_norm2(self.deconv6(self.relu(decoder5)))
        decoder6 = torch.cat((decoder6, encoder2), 1)
        decoder7 = self.batch_norm(self.deconv7(self.relu(decoder6)))
        decoder7 = torch.cat((decoder7, encoder1), 1)
        decoder8 = self.deconv8(self.relu(decoder7))
        output = self.tanh(decoder8)
        return output
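
A shape sketch: the eight stride-2 encoder levels bottleneck a 256x512 input at (n_filters*8, 1, 2), and the skip-connected decoder restores the full resolution, with tanh bounding the output to [-1, 1].

net = LDR2HDR()
out = net(torch.randn(1, 3, 256, 512))
assert out.shape == (1, 3, 256, 512)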
