Skip to content


Add editing models
Browse files Browse the repository at this point in the history
  • Loading branch information
jrrodri committed Oct 25, 2024
1 parent d4969a3 commit 9d08461
Show file tree
Hide file tree
Showing 8 changed files with 331 additions and 2 deletions.
Empty file added abraia/editing/
Empty file.
54 changes: 54 additions & 0 deletions abraia/editing/
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import cv2
import numpy as np
import onnxruntime as ort

from ..utils import download_file

def ceil_modulo(x, mod):
    """Round ``x`` up to the nearest multiple of ``mod``.

    Args:
        x: Value to round up.
        mod: Non-zero step the result must be a multiple of.

    Returns:
        ``x`` itself when it is already a multiple of ``mod``, otherwise the
        next multiple above it.
    """
    remainder = x % mod
    if remainder == 0:
        return x
    # x - remainder is the multiple just below x; one more step reaches the
    # multiple just above it.
    return x - remainder + mod

def pad_img_to_modulo(img, mod):
    """Pad an image so its height and width are multiples of ``mod``.

    Padding is added only at the bottom and right edges, mirroring the
    existing pixels (``mode="symmetric"``). The first axis is never padded.

    Args:
        img: 3-D array unpacked as (channels, height, width).
        mod: Multiple that the padded height and width must satisfy.

    Returns:
        The padded array; identical in shape to ``img`` when both spatial
        dimensions are already multiples of ``mod``.
    """
    _, height, width = img.shape
    out_height = ceil_modulo(height, mod)
    out_width = ceil_modulo(width, mod)
    pad_widths = ((0, 0), (0, out_height - height), (0, out_width - width))
    return np.pad(img, pad_widths, mode="symmetric")

class LAMA:
    """Image editing model (``lama_fp32.onnx``) run through ONNX Runtime.

    The model is fed an image and a mask and returns an edited image.
    NOTE(review): presumably non-zero mask pixels mark the region to be
    filled in (inpainting) - confirm against the model card.
    """

    def __init__(self):
        # Fixed resolution every input is resized to before inference.
        self.image_size = (512, 512)
        sess_options = ort.SessionOptions()
        model_src = download_file('multiple/models/editing/lama_fp32.onnx')
        self.session = ort.InferenceSession(model_src, sess_options=sess_options)

    def preprocess(self, image, mask, pad_out_to_modulo=8):
        """Convert an image/mask pair into the model's input tensors.

        Args:
            image: Array indexed (height, width, channels) with values in 0-255.
            mask: 2-D array; any value above zero is treated as "masked".
            pad_out_to_modulo: When greater than 1, both tensors are padded so
                their height and width are multiples of this value.

        Returns:
            ``(image, mask)`` as ``float32`` arrays with a leading batch axis;
            the image is scaled to 0-1 and the mask is strictly 0 or 1.
        """
        out_image = cv2.resize(image, self.image_size, interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps the mask free of interpolated in-between values.
        out_mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        out_image = out_image.transpose((2, 0, 1)) / 255
        out_mask = out_mask[np.newaxis, ...] / 255

        if pad_out_to_modulo is not None and pad_out_to_modulo > 1:
            out_image = pad_img_to_modulo(out_image, pad_out_to_modulo)
            out_mask = pad_img_to_modulo(out_mask, pad_out_to_modulo)

        out_mask = (out_mask > 0) * 1

        out_image = np.expand_dims(out_image, axis=0).astype(np.float32)
        out_mask = np.expand_dims(out_mask, axis=0).astype(np.float32)
        return out_image, out_mask

    def postprocess(self, output, size):
        """Turn the raw model output back into an 8-bit image.

        Args:
            output: Batched model output; only the first item is used.
            size: Target ``(width, height)`` passed to ``cv2.resize``.

        Returns:
            ``uint8`` image resized to ``size``.
        """
        output = output[0].transpose(1, 2, 0)
        # Clip before the cast so out-of-range values saturate instead of
        # wrapping around in uint8.
        output = np.clip(output, 0, 255).astype(np.uint8)
        return cv2.resize(output, size, interpolation=cv2.INTER_CUBIC)

    def predict(self, image, mask):
        """Run the model and return the result at the input image's size."""
        h, w = image.shape[:2]
        image, mask = self.preprocess(image, mask)
        outputs = self.session.run(None, {'image': image, 'mask': mask})
        return self.postprocess(outputs[0], (w, h))

62 changes: 62 additions & 0 deletions abraia/editing/
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import cv2
import numpy as np
import onnxruntime as ort

from ..utils import download_file


# 3x3 elliptical structuring element used by post_process for morphological opening.
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))

def post_process(mask):
    """Post-process the mask for a smooth boundary using morphological operations.

    The mask is opened with the module-level ``kernel``, blurred with a 5x5
    Gaussian (sigma 2 on both axes) and then re-binarized at 127.

    NOTE(review): the original docstring cited a research paper, but the link
    is missing from this copy of the source - restore it from upstream.

    Args:
        mask: Binary NumPy mask.

    Returns:
        ``np.uint8`` mask containing only the values 0 and 255.
    """
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
    mask = cv2.GaussianBlur(mask, (5, 5), sigmaX=2, sigmaY=2, borderType=cv2.BORDER_DEFAULT)
    mask = np.where(mask < 127, 0, 255).astype(np.uint8)
    return mask

class RemoveBG:
    """Background removal using the ``isnet-general-use`` ONNX model.

    ``predict`` returns a single-channel mask; ``remove`` writes that mask
    into the alpha channel of the input image.
    """

    def __init__(self):
        # Fixed resolution every input is resized to before inference.
        self.image_size = (1024, 1024)
        # Per-channel offset subtracted from the normalized image.
        self.input_mean = (0.485, 0.456, 0.406)
        self.providers = ort.get_available_providers()
        model_src = download_file('multiple/models/editing/isnet-general-use.onnx')
        self.session = ort.InferenceSession(model_src, providers=self.providers)
        self.input_name = self.session.get_inputs()[0].name

    def preprocess(self, img):
        """Resize, normalize and reorder ``img`` into the model input tensor.

        Args:
            img: Array indexed (height, width, channels).

        Returns:
            ``float32`` array with a leading batch axis and channels first.
        """
        img = cv2.resize(img, self.image_size, interpolation=cv2.INTER_LINEAR)
        # Guard the peak so an all-black image does not divide by zero.
        peak = max(np.max(img), 1e-6)
        img = img / peak - np.array(self.input_mean)
        img = img.transpose((2, 0, 1)).astype(np.float32)
        return np.expand_dims(img, axis=0)

    def postprocess(self, out, size):
        """Convert the raw prediction into an 8-bit mask of the given size.

        The prediction is min-max normalized to 0-1 before scaling to 0-255.
        A constant prediction (no contrast) yields an all-zero mask.

        Args:
            out: Model output holding ``image_size`` worth of values.
            size: Target ``(width, height)`` passed to ``cv2.resize``.

        Returns:
            ``uint8`` mask resized to ``size``.
        """
        pred = out.reshape(self.image_size)
        ma, mi = np.max(pred), np.min(pred)
        span = ma - mi
        if span > 0:
            pred = (pred - mi) / span
        else:
            # Nothing to normalize; avoid 0/0 producing NaNs.
            pred = np.zeros_like(pred)
        mask = (pred * 255).astype(np.uint8)
        return cv2.resize(mask, size, interpolation=cv2.INTER_CUBIC)

    def predict(self, img):
        """Return the mask for ``img`` at the image's own resolution."""
        h, w = img.shape[:2]
        inputs = {self.input_name: self.preprocess(img)}
        outputs = self.session.run(None, inputs)
        return self.postprocess(outputs[0], (w, h))

    def remove(self, img, post_process_mask = False):
        """Return ``img`` as RGBA with the predicted mask as alpha channel.

        Args:
            img: RGB image (converted with ``cv2.COLOR_RGB2RGBA``).
            post_process_mask: When true, the mask is cleaned up with
                ``post_process`` before being applied.
        """
        mask = self.predict(img)
        if post_process_mask:
            mask = post_process(mask)
        img = cv2.cvtColor(img, cv2.COLOR_RGB2RGBA)
        img[:, :, 3] = mask
        return img
89 changes: 89 additions & 0 deletions abraia/editing/
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import cv2
import json
import numpy as np
import onnxruntime as ort

from ..utils import download_file

def get_input_points(prompt):
    """Parse a JSON prompt into point coordinates and labels.

    ``prompt`` is a JSON array of marks, each a dict with a ``"type"`` of
    ``"point"`` or ``"rectangle"`` and a ``"data"`` list of coordinates.
    Marks of any other type are ignored.

    NOTE(review): the reviewed source had an empty ``point`` branch and never
    filled ``labels``. The ``"label"`` key and the 2/3 corner labels below
    follow the SAM ONNX prompt convention (2 = box top-left, 3 = box
    bottom-right) - confirm against upstream.

    Args:
        prompt: JSON string describing the marks.

    Returns:
        ``(points, labels)``: ``points`` has shape ``(n, 2)`` (``(0, 2)`` for
        an empty prompt, so it can always be concatenated with more points),
        ``labels`` has shape ``(n,)``.
    """
    marks = json.loads(prompt)
    points, labels = [], []
    for mark in marks:
        if mark["type"] == "point":
            points.append([mark["data"][0], mark["data"][1]])
            labels.append(mark["label"])
        elif mark["type"] == "rectangle":
            # A rectangle contributes its two opposite corners.
            points.append([mark["data"][0], mark["data"][1]])
            points.append([mark["data"][2], mark["data"][3]])
            labels.extend((2, 3))
    # reshape keeps the array 2-D even when no marks were given.
    return np.array(points, dtype=np.float64).reshape(-1, 2), np.array(labels)

class SAM:
    """Prompt-driven segmentation using the ``mobile_sam`` ONNX encoder/decoder.

    ``predict`` scales the image into the encoder's input canvas, decodes the
    prompt points against the image embedding and maps the resulting masks
    back to the original image size.
    """

    def __init__(self):
        self.target_size = 1024
        # (height, width) of the canvas fed to the encoder.
        self.input_size = (684, 1024)
        sess_options = ort.SessionOptions()
        providers = ort.get_available_providers()
        encoder_src = download_file('multiple/models/mobile_sam.encoder.onnx')
        decoder_src = download_file('multiple/models/mobile_sam.decoder.onnx')
        self.encoder = ort.InferenceSession(encoder_src, providers=providers, sess_options=sess_options)
        self.decoder = ort.InferenceSession(decoder_src, providers=providers, sess_options=sess_options)

    def encode(self, img):
        """Run the encoder on ``img`` and return the image embedding."""
        encoder_input_name = self.encoder.get_inputs()[0].name
        encoder_inputs = {encoder_input_name: img.astype(np.float32)}
        encoder_output = self.encoder.run(None, encoder_inputs)
        return encoder_output[0]

    def decode(self, image_embedding, input_points, input_labels):
        """Run the decoder for the given prompt points.

        Args:
            image_embedding: Output of ``encode``.
            input_points: ``(n, 2)`` point coordinates in encoder-canvas space.
            input_labels: ``n`` labels, one per point.

        Returns:
            The masks of the first (only) batch item.
        """
        # Normalizing the shapes lets an empty prompt through the concatenation.
        points = np.asarray(input_points, dtype=np.float32).reshape(-1, 2)
        labels = np.asarray(input_labels, dtype=np.float32).reshape(-1)

        # One extra (0, 0) point with label -1 is always appended to the prompt.
        onnx_coord = np.concatenate([points, np.zeros((1, 2), dtype=np.float32)], axis=0)[None, :, :]
        onnx_label = np.concatenate([labels, np.array([-1], dtype=np.float32)], axis=0)[None, :]

        # No previous mask is supplied.
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {"image_embeddings": image_embedding,
                          "point_coords": onnx_coord,
                          "point_labels": onnx_label,
                          "mask_input": onnx_mask_input,
                          "has_mask_input": onnx_has_mask_input,
                          "orig_im_size": np.array(self.input_size, dtype=np.float32)}
        masks, _, _ = self.decoder.run(None, decoder_inputs)
        return masks[0]

    def predict(self, img, prompt="[]"):
        """Segment ``img`` according to ``prompt``.

        Args:
            img: Array indexed (height, width, channels).
            prompt: JSON string parsed by ``get_input_points``.

        Returns:
            ``uint8`` array of shape (height, width, 3) that is 255 wherever
            any decoded mask is positive and 0 elsewhere.
        """
        height, width = img.shape[:2]

        # Uniform scale that fits the image inside the encoder canvas.
        scale_x = self.input_size[1] / width
        scale_y = self.input_size[0] / height
        scale = min(scale_x, scale_y)

        transform_matrix = np.array([[scale, 0, 0],
                                     [0, scale, 0],
                                     [0, 0, 1]])

        size = (self.input_size[1], self.input_size[0])
        img = cv2.warpAffine(img, transform_matrix[:2], size, flags=cv2.INTER_LINEAR)
        image_embedding = self.encode(img)

        # Prompt coordinates are given in original-image space.
        input_points, input_labels = get_input_points(prompt)
        input_points = input_points * scale
        masks = self.decode(image_embedding, input_points, input_labels)

        # Map each mask from canvas space back to the original image size.
        inv_transform_matrix = np.linalg.inv(transform_matrix)
        mask = np.zeros((height, width, 3), dtype=np.uint8)
        for m in masks:
            m = cv2.warpAffine(m, inv_transform_matrix[:2], (width, height), flags=cv2.INTER_LINEAR)
            mask[m > 0.0] = [255, 255, 255]
        return mask
124 changes: 124 additions & 0 deletions abraia/editing/
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import cv2
import numpy as np
import onnxruntime as ort

from ..utils import download_file

class GPEN:
    """Image enhancer backed by the ``GPEN-BFR-256`` ONNX model.

    NOTE(review): presumably a blind-face-restoration model, per the model
    file name - confirm. The enhanced image is returned at the model's own
    resolution; it is not resized back to the input size.
    """

    def __init__(self):
        session_options = ort.SessionOptions()
        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        model_src = download_file('multiple/models/editing/GPEN-BFR-256.onnx')
        self.session = ort.InferenceSession(model_src, sess_options=session_options)
        self.input_name = self.session.get_inputs()[0].name
        # Last two dims of the model input, i.e. (height, width).
        self.image_size = self.session.get_inputs()[0].shape[-2:]

    def preprocess(self, img):
        """Resize ``img`` to the model input and scale it from 0-255 to -1..1.

        Returns:
            ``float32`` array with a leading batch axis and channels first.
        """
        height, width = self.image_size
        # cv2.resize expects (width, height), the reverse of the model's shape.
        img = cv2.resize(img, (width, height), interpolation=cv2.INTER_LINEAR)
        img = ((img / 255 - 0.5) / 0.5).astype(np.float32)
        return np.expand_dims(img.transpose((2, 0, 1)), axis=0)

    def postprocess(self, img):
        """Map a channels-first -1..1 output to a channels-last ``uint8`` image."""
        img = (255 * ((img.clip(-1, 1) + 1) * 0.5)).astype(np.uint8)
        return img.transpose((1, 2, 0))

    def enhance(self, img):
        """Run the model on ``img`` and return the enhanced image."""
        inputs = {self.input_name: self.preprocess(img)}
        outputs = self.session.run(None, inputs)
        return self.postprocess(outputs[0][0])

class SwinIR:
    """Image upscaler backed by the SwinIR real-SR x4 ONNX model.

    The whole image is passed through the model in one call (no tiling).
    """

    def __init__(self):
        model_src = download_file('multiple/models/editing/003_realSR_BSRGAN_DFO_s64w8_SwinIR-M_x4_GAN.onnx')
        self.session = ort.InferenceSession(model_src)
        self.input_name = self.session.get_inputs()[0].name

    def preprocess(self, img):
        """Scale a 0-255 channels-last image to a 0-1 ``float32`` batch tensor."""
        scaled = (img / 255).astype(np.float32)
        return np.expand_dims(scaled.transpose((2, 0, 1)), axis=0)

    def postprocess(self, img):
        """Map a channels-first 0-1 output to a channels-last ``uint8`` image.

        Values outside 0-1 are clipped before scaling.
        """
        pixels = (255 * img.clip(0, 1)).astype(np.uint8)
        return pixels.transpose((1, 2, 0))

    def upscale(self, img):
        """Run the model on ``img`` and return the upscaled image."""
        inputs = {self.input_name: self.preprocess(img)}
        outputs = self.session.run(None, inputs)
        return self.postprocess(outputs[0][0])

def create_gradient_mask(shape, feather):
    """Create a gradient mask for smooth blending of tiles.

    The mask is 1 in the interior and ramps down linearly towards every edge
    over ``feather`` pixels; corners receive the product of both ramps. All
    values are strictly positive.

    Args:
        shape: 4-tuple unpacked as (batch, channels, height, width).
        feather: Width of the ramp in pixels. Non-integer values are rounded;
            0 yields an all-ones mask.

    Returns:
        Array of the given shape holding the blend weights.
    """
    mask = np.ones(shape)
    _, _, h, w = shape
    # range() needs an int; callers may pass overlap * scale with a float scale.
    steps = int(round(feather))
    for step in range(steps):
        factor = (step + 1) / steps
        # Skip ramp steps that fall outside a tile smaller than the feather
        # instead of indexing out of bounds.
        if step < h:
            mask[:, :, step, :] *= factor
            mask[:, :, h - 1 - step, :] *= factor
        if step < w:
            mask[:, :, :, step] *= factor
            mask[:, :, :, w - 1 - step] *= factor
    return mask

def _tile_origins(length, tile, step):
    """Return the distinct, clamped start offsets of the tiles along one axis."""
    origins = []
    for start in range(0, length, step):
        # Shift the last tile back so it ends exactly at the border.
        origin = max(0, min(start + tile, length) - tile)
        # Clamping can map several trailing starts onto the same tile; keep one.
        if not origins or origin != origins[-1]:
            origins.append(origin)
    return origins


def tiled_upscale(samples, function, scale, tile_size, overlap = 8):
    """Apply a scaling function to image samples in a tiled manner.

    The input is cut into overlapping tiles, each tile is passed through
    ``function``, and the results are blended with feathered weights so the
    seams between tiles are not visible. Each distinct tile is processed
    exactly once.

    Args:
        samples: 4-D array unpacked as (batch, channels, height, width).
        function: Callable mapping a tile to its scaled version (also 4-D).
        scale: Spatial scale factor applied by ``function``.
        tile_size: ``(tile_width, tile_height)`` in input pixels.
        overlap: Overlap between neighbouring tiles, in input pixels.

    Returns:
        Array with the spatial size of ``samples`` multiplied by ``scale``;
        batch and channel sizes are those returned by ``function``.

    Raises:
        ValueError: If a tile dimension is not larger than ``overlap``.
    """
    tile_width, tile_height = tile_size
    if tile_width <= overlap or tile_height <= overlap:
        raise ValueError("tile_size must be larger than overlap in both dimensions")
    batch, channels, height, width = samples.shape
    out_height, out_width = round(height * scale), round(width * scale)
    # The feather spans the overlap as measured in output pixels.
    feather = int(round(overlap * scale))

    out = out_div = None
    x_origins = _tile_origins(width, tile_width, tile_width - overlap)
    for y in _tile_origins(height, tile_height, tile_height - overlap):
        y_end = min(y + tile_height, height)
        for x in x_origins:
            x_end = min(x + tile_width, width)
            processed_tile = function(samples[:, :, y:y_end, x:x_end])
            if out is None:
                # Allocate from the first result so batch/channel counts follow
                # whatever the scaling function returns.
                out_shape = tuple(processed_tile.shape[:2]) + (out_height, out_width)
                out = np.zeros(out_shape)
                out_div = np.zeros(out_shape)
            out_y, out_x = round(y * scale), round(x * scale)
            out_h, out_w = processed_tile.shape[2:]
            mask = create_gradient_mask(processed_tile.shape, feather)
            # Accumulate weighted pixels and the weights themselves.
            out[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += processed_tile * mask
            out_div[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += mask

    if out is None:
        # Empty input: there were no tiles to process.
        return np.zeros((batch, channels, out_height, out_width))
    # Dividing by the summed weights turns the accumulation into a weighted mean.
    return out / out_div

class Upscaler:
    """Tiled image upscaler backed by an ONNX super-resolution model.

    Args:
        model_path: Model to download; when empty, the default
            ``4xNomosWebPhoto_RealPLKSR`` model is used.
        scale: Scale factor the model applies.
        tile_size: ``(tile_width, tile_height)`` of the tiles fed to the model.
        overlap: Overlap between neighbouring tiles, in input pixels.
    """

    def __init__(self, model_path = '', scale = 4, tile_size = (1024, 1024), overlap = 8):
        self.scale = scale
        self.overlap = overlap
        self.tile_size = tile_size
        model_src = download_file(model_path or 'multiple/models/editing/4xNomosWebPhoto_RealPLKSR_fp32_opset17.onnx')
        self.session = ort.InferenceSession(model_src)
        self.input_name = self.session.get_inputs()[0].name

    def preprocess(self, img):
        """Scale a 0-255 channels-last image to a 0-1 ``float32`` batch tensor."""
        scaled = (img / 255).astype(np.float32)
        return np.expand_dims(scaled.transpose((2, 0, 1)), axis=0)

    def postprocess(self, img):
        """Map a channels-first 0-1 output to a channels-last ``uint8`` image.

        Values outside 0-1 are clipped before scaling.
        """
        pixels = (255 * img.clip(0, 1)).astype(np.uint8)
        return pixels.transpose((1, 2, 0))

    def predict(self, img):
        """Run the model on one preprocessed tile and return its first output."""
        return self.session.run(None, {self.input_name: img})[0]

    def upscale(self, img):
        """Upscale ``img`` tile by tile and return the blended result."""
        img = self.preprocess(img)
        outputs = tiled_upscale(img, self.predict, self.scale, self.tile_size, self.overlap)
        return self.postprocess(outputs[0])
Binary file added images/jacket.jpeg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion notebooks/hyperspectral-analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"source": [
"!python -m pip install abraia==0.12.1\n",
"!python -m pip install abraia\n",
"import os\n",
"if not os.getenv('ABRAIA_ID') or not os.getenv('ABRAIA_KEY'):\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/hyperspectral-classification.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"source": [
"!python -m pip install abraia==0.12.1\n",
"!python -m pip install abraia\n",
"import os\n",
"if not os.getenv('ABRAIA_ID') or not os.getenv('ABRAIA_KEY'):\n",
Expand Down

0 comments on commit 9d08461

Please sign in to comment.