Skip to content

Commit

Permalink
Merge pull request #20 from Schrodinger-Hat/feat/ai-conversion
Browse files Browse the repository at this point in the history
feat: add AI conversion with a new method + minore fixies
  • Loading branch information
TheJoin95 authored Apr 17, 2024
2 parents f010982 + 053c1ec commit 8f96059
Show file tree
Hide file tree
Showing 19 changed files with 969 additions and 85 deletions.
113 changes: 97 additions & 16 deletions ImageGoNord/GoNord.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,25 @@
import uuid
import shutil

import torch
import skimage.io as io
import skimage.color as convertor
import torchvision.transforms as transforms


try:
import importlib.resources as pkg_resources
except ImportError:
# Try backported to PY<37 `importlib_resources`.
import importlib_resources as pkg_resources

from .palettes import Nord as nord_palette
from .models import PaletteNet as palette_net

from ImageGoNord.utility.quantize import quantize_to_palette
import ImageGoNord.utility.palette_loader as pl
from ImageGoNord.utility.ConvertUtility import ConvertUtility
from ImageGoNord.utility.model import FeatureEncoder,RecoloringDecoder


class NordPaletteFile:
Expand Down Expand Up @@ -417,7 +425,70 @@ def converted_loop(self, is_rgba, pixels, original_pixels, maxRow, maxCol, minRo
pixels[row, col] = tuple(colors_list)
return pixels

def convert_image(self, image, save_path='', parallel_threading=False):
def convert_image_by_model(self, image, use_model_cpu=False):
    """
    Process a Pillow image by using a PyTorch model "PaletteNet" for recoloring the image

    The image is converted to Lab, normalised, resized so both sides are
    multiples of 16, and recoloured towards the first six colours of
    ``self.PALETTE_DATA``. The original lightness channel is kept and only
    the two chroma channels are predicted by the network.

    Parameters
    ----------
    image : pillow image
        The source pillow image. Non-RGB images (e.g. RGBA) are converted
        to RGB first, so any alpha channel is dropped.
    use_model_cpu : bool, optional
        true if using cpu power

    Returns
    -------
    pillow image
        processed RGB image; its sides are the source sides rounded down
        to the nearest multiple of 16

    Raises
    ------
    ValueError
        if the palette holds fewer than 6 colours, or if the image is
        smaller than 16 pixels on either side
    """
    FE = FeatureEncoder()  # torch.Size([64, 3, 3, 3])
    RD = RecoloringDecoder()  # torch.Size([530, 256, 3, 3])

    # Load the packaged weights. The resource handles are closed as soon as
    # they are read, and map_location="cpu" lets checkpoints that were saved
    # from a GPU load on hosts without CUDA.
    for net, resource in ((FE, "FE.state_dict.pt"), (RD, "RD.state_dict.pt")):
        with pkg_resources.open_binary(palette_net, resource) as weights:
            net.load_state_dict(torch.load(weights, map_location="cpu"))

    if use_model_cpu:
        FE.to("cpu")
        RD.to("cpu")

    # rgb2lab needs exactly three channels; RGBA / L / P images would fail.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # NOTE(review): the image chroma is scaled by 127 while the palette and
    # the output use 128 - kept as-is, presumably matching how the weights
    # were trained; confirm before changing.
    lab_image = (convertor.rgb2lab(np.array(image)) - [50, 0, 0]) / [50, 127, 127]

    # HWC -> CHW
    img = torch.Tensor(lab_image).permute(2, 0, 1)

    # Both sides are rounded down to a multiple of 16 before the image is
    # fed to the encoder/decoder pair.
    h = 16 * (img.shape[1] // 16)
    w = 16 * (img.shape[2] // 16)
    if h == 0 or w == 0:
        raise ValueError(
            "The image is too small for the model: both sides must be at least 16 pixels"
        )

    img = transforms.Resize((h, w))(img).unsqueeze(0)

    # Palette keys are 6-digit hex strings; split each into its R, G, B bytes.
    palette = [
        [int(hex_code[i:i + 2], 16) for i in (0, 2, 4)]
        for hex_code in self.PALETTE_DATA
    ]

    if len(palette) < 6:
        raise ValueError(
            "The model needs a palette of 6 colours, but only %d were provided" % len(palette)
        )
    if len(palette) > 6:
        print("You have too many colors in your palette for the model, this feature is limited to 6 colours, now you have: ", len(palette), "! I'll take the first 6!")
        palette = palette[:6]

    pal_np = np.array(palette).reshape(1, 6, 3) / 255
    pal = torch.Tensor((convertor.rgb2lab(pal_np) - [50, 0, 0]) / [50, 128, 128]).unsqueeze(0)

    # The normalised lightness channel is handed to the decoder and reused
    # unchanged in the output.
    illu = img[:, 0:1, :, :]

    with torch.no_grad():
        c1, c2, c3, c4 = FE(img)
        out = RD(c1, c2, c3, c4, pal, illu)
        # Undo the normalisation and go back to HWC for skimage.
        final_image = torch.cat([(illu + 1) * 50, out * 128], dim=1).permute(0, 2, 3, 1)[0]

    # need to convert float value returning in skimage to 0-255 range values for pillow (computer vision / training lib vs pixel operation lib)
    rgb = convertor.lab2rgb(final_image.cpu().numpy())
    return Image.fromarray((rgb * 255).astype(np.uint8))

def convert_image(self, image, save_path='', use_model=False, use_model_cpu=False, parallel_threading=False):
"""
Process a Pillow image by replacing pixel or by avg algorithm
Expand All @@ -427,6 +498,12 @@ def convert_image(self, image, save_path='', parallel_threading=False):
The source pillow image
save_path : str, optional
the path and the filename where to save the image
use_model : bool, optional
true if using ai model
use_model_cpu : bool, optional
true if using cpu power
parallel_threading : bool, optional
true to enable multi-thread conversion loop
Returns
-------
Expand All @@ -439,22 +516,26 @@ def convert_image(self, image, save_path='', parallel_threading=False):
original_image.close()
pixels = self.load_pixel_image(image)
is_rgba = (image.mode == 'RGBA')
if (parallel_threading == False):
self.converted_loop(is_rgba, pixels, original_pixels, image.size[0], image.size[1])

if use_model:
image = self.convert_image_by_model(image, use_model_cpu)
else:
step = ceil(image.size[0] / self.MAX_THREADS)
threads = []
for row in range(step, image.size[0] + step, step):
args = (is_rgba, pixels, original_pixels, row, image.size[1], row - step, 0)
t = threading.Thread(target=self.converted_loop, args=args)
t.daemon = True
t.start()
threads.append(t)

for t in threads:
t.join(timeout=30)

if (self.USE_GAUSSIAN_BLUR == True):
if not parallel_threading:
self.converted_loop(is_rgba, pixels, original_pixels, image.size[0], image.size[1])
else:
step = ceil(image.size[0] / self.MAX_THREADS)
threads = []
for row in range(step, image.size[0] + step, step):
args = (is_rgba, pixels, original_pixels, row, image.size[1], row - step, 0)
t = threading.Thread(target=self.converted_loop, args=args)
t.daemon = True
t.start()
threads.append(t)

for t in threads:
t.join(timeout=30)

if self.USE_GAUSSIAN_BLUR:
image = image.filter(ImageFilter.GaussianBlur(1))

if (save_path != ''):
Expand Down
2 changes: 1 addition & 1 deletion ImageGoNord/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# gonord version
__version__ = "0.1.7"
__version__ = "1.0.0"

from ImageGoNord.GoNord import *
Binary file added ImageGoNord/models/PaletteNet/FE.state_dict.pt
Binary file not shown.
Binary file added ImageGoNord/models/PaletteNet/RD.state_dict.pt
Binary file not shown.
Empty file.
Empty file added ImageGoNord/models/__init__.py
Empty file.
141 changes: 141 additions & 0 deletions ImageGoNord/utility/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

# Preferred torch device: CUDA when available, otherwise CPU.
# NOTE(review): nothing in the visible part of this module reads `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

class Conv2dAuto(nn.Conv2d):
    """``nn.Conv2d`` that derives its padding from its kernel size.

    Any ``padding`` given to the constructor is overwritten with half the
    kernel extent (integer division) along each axis.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Padding follows the kernel: k // 2 on each spatial axis.
        kernel_h, kernel_w = self.kernel_size
        self.padding = (kernel_h // 2, kernel_w // 2)


# Factory for 3x3, bias-free, auto-padded convolutions.
conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)

def activation_func(activation):
    """Return a new activation module selected by name.

    Supported names are ``'relu'``, ``'leaky_relu'`` (negative slope 0.01)
    and ``'none'`` (identity). ReLU and LeakyReLU are created with
    ``inplace=True``. An unknown name raises ``KeyError``.

    The original author notes Leaky ReLU is the activation mentioned in
    the paper.
    """
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'none': lambda: nn.Identity(),
    }
    build = factories[activation]
    return build()


class ResidualBlock(nn.Module):
    """Generic residual wrapper: ``activate(blocks(x) + shortcut(x))``.

    ``blocks`` and ``shortcut`` default to ``nn.Identity`` and are meant to
    be replaced by subclasses. The shortcut is applied to the residual
    branch only when ``should_apply_shortcut`` is true.

    Parameters
    ----------
    in_channels : int
        number of channels of the input
    out_channels : int
        number of channels of the output
    activation : str, optional
        name understood by ``activation_func`` (default ``'relu'``)
    """

    def __init__(self, in_channels, out_channels, activation='relu'):
        super().__init__()
        self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
        self.blocks = nn.Identity()
        self.shortcut = nn.Identity()
        self.activate = activation_func(activation)

    def forward(self, x):
        """Return the activated sum of the main branch and the residual."""
        residual = self.shortcut(x) if self.should_apply_shortcut else x
        # Out-of-place add: when `blocks` is an identity, `x` is the caller's
        # own tensor and an in-place `+=` would silently overwrite it.
        x = self.blocks(x) + residual
        return self.activate(x)

    @property
    def should_apply_shortcut(self):
        """True when input and output channel counts differ."""
        return self.in_channels != self.out_channels

class ResNetResidualBlock(ResidualBlock):
    """Residual block with a 1x1 strided projection shortcut.

    The shortcut is a bias-free 1x1 convolution with stride
    ``downsampling`` followed by instance normalisation (stored under the
    key ``'bn'``). It is built whenever the residual branch would
    otherwise not match the main branch: the channel count changes or the
    block downsamples. When neither applies the shortcut is ``None`` and
    the input is added back unchanged.

    Parameters
    ----------
    in_channels : int
        number of channels of the input
    out_channels : int
        base number of output channels
    expansion : int, optional
        multiplier applied to ``out_channels`` for the block output
    downsampling : int, optional
        stride of the shortcut convolution (default 2)
    conv : callable, optional
        convolution factory stored on the block for subclasses to use
    *args, **kwargs
        accepted for call compatibility and ignored
    """

    def __init__(self, in_channels, out_channels, expansion=1, downsampling=2, conv=conv3x3, *args, **kwargs):
        super().__init__(in_channels, out_channels)
        # These must be set before `should_apply_shortcut` is evaluated below.
        self.expansion, self.downsampling, self.conv = expansion, downsampling, conv
        self.shortcut = nn.Sequential(OrderedDict(
            {
                'conv': nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1,
                                  stride=self.downsampling, bias=False, padding=0),
                'bn': nn.InstanceNorm2d(self.expanded_channels)
            })) if self.should_apply_shortcut else None

    @property
    def expanded_channels(self):
        """Number of channels produced by the block."""
        return self.out_channels * self.expansion

    @property
    def should_apply_shortcut(self):
        """True when the residual needs projecting.

        That is the case when the channel count changes, or when the block
        downsamples: a strided main branch shrinks the feature map, so an
        unprojected residual could not be added to it.
        """
        return self.in_channels != self.expanded_channels or self.downsampling != 1

def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
    """Build a ``conv`` layer followed by instance normalisation.

    ``conv`` is called as ``conv(in_channels, out_channels, *args, **kwargs)``.
    The result is an ``nn.Sequential`` with the sub-modules named ``'conv'``
    and ``'bn'`` (the latter is an ``nn.InstanceNorm2d`` despite its name).
    """
    layers = OrderedDict()
    layers['conv'] = conv(in_channels, out_channels, *args, **kwargs)
    layers['bn'] = nn.InstanceNorm2d(out_channels)
    return nn.Sequential(layers)

class ResNetBasicBlock(ResNetResidualBlock):
    """Two-convolution residual block.

    The main branch is conv + instance norm (strided by ``downsampling``),
    an activation built as ``activation(negative_slope=0.02)``, then a
    second conv + instance norm. Remaining arguments are forwarded to
    ``ResNetResidualBlock``.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, activation=nn.LeakyReLU, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # The first stage carries the stride; the second keeps the resolution.
        strided_stage = conv_bn(self.in_channels, self.out_channels,
                                conv=self.conv, bias=False, stride=self.downsampling)
        output_stage = conv_bn(self.out_channels, self.expanded_channels,
                               conv=self.conv, bias=False)
        self.blocks = nn.Sequential(
            strided_stage,
            activation(negative_slope=0.02),
            output_stage,
        )

class FeatureEncoder(nn.Module):
    """Encoder producing four feature maps from a 3-channel image.

    A 3x3 convolution (3 -> 64 channels), instance norm, ReLU and a 2x2 max
    pool form the stem; three ``ResNetBasicBlock`` stages then widen the
    features to 128, 256 and 512 channels.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used.
        super().__init__()

        # Stem, applied to a 3xHxW input.
        self.conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.norm = nn.InstanceNorm2d(64)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.res1 = ResNetBasicBlock(64, 128)
        self.res2 = ResNetBasicBlock(128, 256)
        self.res3 = ResNetBasicBlock(256, 512)

    def forward(self, x):
        """Return ``(c1, c2, c3, c4)``: deepest (512-channel) map first, stem output last."""
        stem = F.relu(self.norm(self.conv(x)))
        c4 = self.pool(stem)
        c3 = self.res1(c4)
        c2 = self.res2(c3)
        c1 = self.res3(c2)
        return c1, c2, c3, c4

def de_conv(in_channels, out_channels, kernel_size=3):
    """Build an upsampling stage: transposed conv, instance norm, LeakyReLU.

    The transposed convolution uses stride 2, ``padding = kernel_size // 2``
    and ``output_padding = 1``, which doubles the spatial size for odd
    kernel sizes. With the default ``kernel_size=3`` the layer is identical
    to the previous hard-coded configuration (padding 1).

    Parameters
    ----------
    in_channels : int
        number of input channels
    out_channels : int
        number of output channels
    kernel_size : int, optional
        side of the square transposed-convolution kernel (default 3)

    Returns
    -------
    nn.Sequential
        ConvTranspose2d -> InstanceNorm2d -> LeakyReLU(0.02)
    """
    return nn.Sequential(
        # kernel_size used to be ignored in favour of a literal 3.
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=2,
                           output_padding=1, padding=kernel_size // 2, bias=True),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(negative_slope=0.02, inplace=True),
    )

class RecoloringDecoder(nn.Module):
    """Decoder that turns encoder features plus a target palette into chroma.

    Four ``de_conv`` stages upsample the features. The palette (reshaped to
    18 values per sample) is tiled over the feature map and concatenated
    before the first, third and fourth stages. The illumination channel is
    appended before the final 3x3 convolution, whose 2-channel output is
    squashed with ``tanh``.
    """

    def __init__(self):
        super().__init__()
        self.dconv_up_4 = de_conv(18 + 512, 256)  # palette, c1
        self.dconv_up_3 = de_conv(256 + 256, 128)  # c2, d1
        self.dconv_up_2 = de_conv(18 + 128 + 128, 64)  # palette, c3, d2
        self.dconv_up_1 = de_conv(18 + 64 + 64, 64)  # palette, c4, d3
        self.conv_last = nn.Conv2d(1 + 64, 2, kernel_size=3, padding=1)  # illumination, d4

    def forward(self, c1, c2, c3, c4, target_palettes_1d, illu):
        """Decode the encoder features ``c1..c4`` towards ``target_palettes_1d``.

        Returns a tensor with 2 channels and values in (-1, 1).
        """
        batch = c1.shape[0]
        palette = target_palettes_1d.reshape(batch, 18, 1, 1)

        def tiled_like(features):
            # Broadcast the palette over the spatial grid of `features`.
            return palette.repeat(1, 1, features.shape[2], features.shape[3])

        d1 = self.dconv_up_4(torch.cat((c1, tiled_like(c1)), 1))
        d2 = self.dconv_up_3(torch.cat([c2, d1], dim=1))
        d3 = self.dconv_up_2(torch.cat([tiled_like(d2), c3, d2], dim=1))
        d4 = self.dconv_up_1(torch.cat([tiled_like(d3), c4, d3], dim=1))

        illu = illu.view(illu.size(0), 1, illu.size(2), illu.size(3))
        logits = self.conv_last(torch.cat((d4, illu), dim=1))
        return torch.tanh(logits)
Loading

0 comments on commit 8f96059

Please sign in to comment.