Preprocessing ICC conversion node, robust ONNX tiling with cosine feather blending, and [-1, 1] normalization support (with tested code) #3158

@zelenooki87

Description

Summary

  • Add a preprocessing node to convert input images to a desired ICC profile before any model execution. If full ICC handling isn’t feasible right away, at least force-convert to the system sRGB ICC profile to avoid desaturation with wide-gamut sources.
  • Improve ONNX tiling to match or exceed PyTorch-quality tiling:
    • Reflect padding, cosine/Hann feather masks, weighted blending, and weight normalization.
    • Works with dynamic ONNX models (no requirement to hard-code static input sizes).
  • Add automatic model normalization detection (Diffusion [-1, 1] vs GAN [0, 1]), with auto-cast to the model’s required input dtype.

Why this matters

  • Without ICC handling, feeding AdobeRGB/ProPhoto/etc. into a pipeline assuming sRGB causes desaturation and color mismatch.
  • ONNX tiling quality is currently inferior to PyTorch (visible seams, edge artifacts). A cosine-feathered weighted compositor removes seams and improves quality.
  • Diffusion-based ONNX upscalers typically expect inputs normalized to [-1, 1], while GAN-style models use [0, 1]. Automatic detection makes nodes more robust and user-friendly.

Proposed implementation (code you can drop into a ChaiNNer node)

A) Tiling for dynamic ONNX models with cosine feather blending

  • Works for arbitrarily large images by processing overlapping tiles.
  • Uses reflect padding for edge tiles.
  • Uses cosine/Hann-like feather masks and weighted averaging to eliminate seams.
  • Returns float32 RGB in [0, 1].
  • This snippet assumes you already know the model scale (e.g., 2x, 3x, 4x). If needed, scale can be inferred by a single warm-up run comparing output vs input tile size, as in the probe sketched below.
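A minimal sketch of such a probe (hypothetical helper, not part of the files below; it assumes infer_batch_fn matches the callable used by tiling_dynamic.py):

# scale_probe.py — hypothetical warm-up probe to infer the model's integer scale
import numpy as np

def infer_scale(infer_batch_fn, probe_size: int = 64) -> int:
    """Run one small dummy tile through the model and compare output vs input size."""
    probe = np.zeros((1, 3, probe_size, probe_size), dtype=np.float32)
    out = infer_batch_fn(probe)  # expected shape [1, 3, Hout, Wout]
    scale, rem = divmod(out.shape[2], probe_size)
    if scale < 1 or rem != 0:
        raise RuntimeError("Output size is not an integer multiple of the input size")
    return scale
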
# tiling_dynamic.py
# High-quality tiling compositor for dynamic ONNX models (no assumption about fixed input shapes).
# Implements: reflect padding, cosine feather masks, weighted blending, weight normalization.
# Returns float32 RGB in [0,1]. Works with an injected "infer_batch_fn" that calls ORT session.run().

from typing import List, Tuple, Callable
import numpy as np

class OptimizedFeatherMaskCache:
    """
    Least-frequently-used cache of feather masks keyed by (h, w, feather_pixels), with an approximate memory budget.
    Masks are float32 arrays in [0,1]. Cosine/Hann-like ramps are applied on all four edges.
    """
    def __init__(self, max_memory_mb: float = 128):
        self.cache = {}
        self.access_count = {}
        self.max_memory_bytes = int(max_memory_mb * 1024 * 1024)
        self.current_memory = 0

    def get(self, h: int, w: int, feather: int) -> np.ndarray:
        key = (h, w, feather)
        if key in self.cache:
            self.access_count[key] += 1
            return self.cache[key]
        mask = self._create_mask(h, w, feather)
        sz = mask.nbytes
        # Evict least-used masks until we fit
        while self.current_memory + sz > self.max_memory_bytes and self.cache:
            self._evict_least_used()
        self.cache[key] = mask
        self.access_count[key] = 1
        self.current_memory += sz
        return mask

    def _evict_least_used(self):
        victim = min(self.access_count, key=self.access_count.get)
        arr = self.cache.pop(victim)
        self.access_count.pop(victim)
        self.current_memory -= arr.nbytes

    @staticmethod
    def _create_mask(h: int, w: int, feather: int) -> np.ndarray:
        """
        Build a 2D mask of ones with cosine fades at the borders.
        The fade width (feather) is clamped to h/2, w/2 to avoid degeneracy on tiny tiles.
        """
        if feather <= 0:
            return np.ones((h, w), dtype=np.float32)

        mask = np.ones((h, w), dtype=np.float32)
        fh = min(feather, h // 2)
        fw = min(feather, w // 2)

        # Vertical cosine ramps (top/bottom), sampled at half-steps so the mask
        # stays strictly positive: a weight of exactly zero at the image border
        # would survive normalization as a black pixel wherever only one tile contributes.
        if fh > 0:
            t = (np.arange(fh, dtype=np.float32) + 0.5) * (np.pi / fh)
            ramp_h = 0.5 * (1.0 - np.cos(t))  # smooth monotonic fade-in
            mask[:fh] *= ramp_h[:, None]
            mask[-fh:] *= ramp_h[::-1, None]

        # Horizontal cosine ramps (left/right)
        if fw > 0:
            t = (np.arange(fw, dtype=np.float32) + 0.5) * (np.pi / fw)
            ramp_w = 0.5 * (1.0 - np.cos(t))
            mask[:, :fw] *= ramp_w[None, :]
            mask[:, -fw:] *= ramp_w[None, ::-1]

        return mask

    def clear(self):
        self.cache.clear()
        self.access_count.clear()
        self.current_memory = 0


class AdvancedTileProcessor:
    """
    Generates a tiled cover for an input image and extracts tiles with reflect padding.
    scale: model's integer scale factor (e.g., 2/3/4).
    """
    def __init__(self, tile_size: int, overlap: int, scale: int):
        if tile_size <= 0:
            raise ValueError("tile_size must be > 0 for tiling")
        if overlap <= 0 or overlap >= tile_size:
            raise ValueError("overlap must be in (0, tile_size)")
        if scale < 1:
            raise ValueError("scale must be >= 1")

        self.tile_size = tile_size
        self.overlap = overlap
        self.scale = scale
        self.out_tile = self.tile_size * self.scale
        self.mask_cache = OptimizedFeatherMaskCache(128)

    def generate_tiles(self, h: int, w: int) -> List[Tuple[int, int]]:
        """
        Build a grid of top-left positions. The stride is (tile_size - overlap).
        The last tile along each axis is anchored at (dim - tile_size) to fully cover the image.
        """
        stride = max(1, self.tile_size - self.overlap)
        xs = [0] if w <= self.tile_size else list(range(0, w - self.tile_size + 1, stride))
        ys = [0] if h <= self.tile_size else list(range(0, h - self.tile_size + 1, stride))
        if xs and xs[-1] < w - self.tile_size:
            xs.append(w - self.tile_size)
        if ys and ys[-1] < h - self.tile_size:
            ys.append(h - self.tile_size)
        return [(x, y) for y in ys for x in xs]

    def extract_tiles(self, image: np.ndarray, positions: List[Tuple[int, int]]) -> np.ndarray:
        """
        Extract tiles and reflect-pad bottom/right edges to always produce tiles of shape (tile_size, tile_size).
        image: HxWxC (C=3). dtype can be uint8/uint16/float32.
        """
        tiles = []
        for x, y in positions:
            t = image[y:y + self.tile_size, x:x + self.tile_size]
            if t.shape[0] < self.tile_size or t.shape[1] < self.tile_size:
                pad_h = self.tile_size - t.shape[0]
                pad_w = self.tile_size - t.shape[1]
                # np.pad 'reflect' requires the pad to be smaller than the source
                # dimension; fall back to 'edge' for inputs far smaller than a tile.
                mode = 'reflect' if (pad_h < t.shape[0] and pad_w < t.shape[1]) else 'edge'
                t = np.pad(t, ((0, pad_h), (0, pad_w), (0, 0)), mode=mode)
            tiles.append(t)
        return np.stack(tiles, axis=0)


def process_rgb_tiled_dynamic(
    img: np.ndarray,                   # Input RGB image, HxWx3. uint8/uint16, or float32 in [0,1].
    infer_batch_fn: Callable[[np.ndarray], np.ndarray],  # Callable that runs the model on BCHW input and returns float32 [B,3,Hout,Wout]
    tile_size: int = 256,
    overlap: int = 64,
    scale: int = 4,                    # Integer upscale factor of the model (e.g., 4 for 4x).
    batch_size: int = 1,
    norm_type: str = 'gan',            # 'gan' for [0,1], 'diffusion' for [-1,1].
) -> np.ndarray:
    """
    Run tiled inference with cosine feather blending and weight normalization.
    Returns float32 RGB in [0,1].

    Notes:
    - overlap should be tuned per model. Typical values: tile_size/4 to tile_size/2.
    - feather width in the output domain is derived from overlap: about half of (overlap * scale).
    - For best stability, choose tile_size so that receptive-field margins are covered by overlap (this avoids residual seam artifacts).
    """
    if img.ndim != 3 or img.shape[2] != 3:
        raise ValueError("RGB image must have shape HxWx3")

    # Generate tile grid and allocate output/weights
    h, w = img.shape[:2]
    tiler = AdvancedTileProcessor(tile_size, overlap, scale)
    positions = tiler.generate_tiles(h, w)

    out_h, out_w = h * scale, w * scale
    output = np.zeros((out_h, out_w, 3), dtype=np.float32)
    weights = np.zeros((out_h, out_w), dtype=np.float32)

    # Feather width in upscaled domain. Half overlap works well; adjust if needed.
    feather = max(8, (tiler.overlap * tiler.scale) // 2)
    full_mask = tiler.mask_cache.get(tiler.out_tile, tiler.out_tile, feather)

    # Process tiles in batches
    for i in range(0, len(positions), batch_size):
        batch_pos = positions[i:i + batch_size]
        batch = tiler.extract_tiles(img, batch_pos)

        # Convert to float32 in [0,1]
        if batch.dtype == np.uint8:
            b = batch.astype(np.float32) / 255.0
        elif batch.dtype == np.uint16:
            b = batch.astype(np.float32) / 65535.0
        else:
            b = np.clip(batch.astype(np.float32), 0.0, 1.0)

        batch_bchw = np.transpose(b, (0, 3, 1, 2))  # NHWC -> BCHW

        # Apply normalization expected by the model
        if norm_type == 'diffusion':
            batch_bchw = batch_bchw * 2.0 - 1.0

        batch_bchw = np.ascontiguousarray(batch_bchw, dtype=np.float32)

        # Model inference: returns [B,3,Hout,Wout] float32
        out = infer_batch_fn(batch_bchw)

        # Reorder to NHWC for compositing
        if out.ndim == 4 and out.shape[1] == 3:
            out = np.transpose(out, (0, 2, 3, 1))
        else:
            raise RuntimeError("Unexpected model output shape; expected [B,3,H,W]")

        # Denormalize back to [0,1]
        if norm_type == 'diffusion':
            out = (out + 1.0) * 0.5
        out = np.clip(out, 0.0, 1.0)

        # Weighted accumulation with feather masks
        for j, (x, y) in enumerate(batch_pos):
            ox, oy = x * scale, y * scale
            th = min(tiler.out_tile, out_h - oy)  # crop for bottom edge
            tw = min(tiler.out_tile, out_w - ox)  # crop for right edge
            mask = full_mask if (th == tiler.out_tile and tw == tiler.out_tile) else \
                   tiler.mask_cache.get(th, tw, feather)

            tile = out[j][:th, :tw].astype(np.float32)
            output[oy:oy+th, ox:ox+tw] += tile * mask[:, :, None]
            weights[oy:oy+th, ox:ox+tw] += mask

        # Optional: progress print
        # print(f"Tiles: {i + len(batch_pos)}/{len(positions)}", end="\r")

    # Normalize by accumulated weights to avoid seam brightness shifts
    np.maximum(weights, 1e-8, out=weights)
    output /= weights[:, :, None]
    np.clip(output, 0.0, 1.0, out=output)
    return output
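
As a quick sanity check of the grid logic (illustrative numbers; assumes the classes above are importable):

# grid_sanity_check.py — illustrative check of the tiling grid
from tiling_dynamic import AdvancedTileProcessor

tiler = AdvancedTileProcessor(tile_size=256, overlap=64, scale=4)
positions = tiler.generate_tiles(600, 600)
# stride = 192; axis positions are 0, 192, then 344 (last anchored at 600 - 256)
assert positions[0] == (0, 0) and positions[-1] == (344, 344)
print(len(positions))  # 9 tiles cover the 600x600 input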

B) Automatic normalization detection (Diffusion [-1, 1] vs GAN [0, 1])

  • Runs one tiny BCHW sample (in [0,1]) through the model twice: once as-is and once remapped to [-1,1].
  • Based on output ranges, heuristically picks “diffusion” or “gan”.
  • If the model input isn’t floating-point, we assume GAN-style [0,1].
# normalization_detection.py
# Automatic detection of model input normalization: diffusion ([-1,1]) vs gan ([0,1]).
# Provide a small BCHW sample in [0,1] (e.g., a single tile or center crop).

import numpy as np

def _to_numpy_dtype(onnx_type: str):
    if not onnx_type:
        return None
    t = onnx_type.lower()
    return {
        'tensor(float16)': np.float16,
        'tensor(float)':  np.float32,
        'tensor(double)': np.float64,
        'tensor(uint8)':  np.uint8,
        'tensor(int8)':   np.int8,
        'tensor(int32)':  np.int32,
        'tensor(int64)':  np.int64,
    }.get(t, None)

def detect_normalization_diffusion_or_gan(
    session, input_name: str, output_name: str, test_bchw_f32_in_0_1: np.ndarray
) -> str:
    """
    session: onnxruntime.InferenceSession
    input_name / output_name: tensor names
    test_bchw_f32_in_0_1: float32 BCHW sample in [0,1], shape [1,3,h,w] recommended

    Returns: 'diffusion' or 'gan'
    """
    # Inspect input dtype
    try:
        inp = session.get_inputs()[0]
        input_dtype = _to_numpy_dtype(getattr(inp, 'type', None))
    except Exception:
        input_dtype = np.float32

    # Non-float models typically expect [0,1]
    if input_dtype not in (np.float16, np.float32, np.float64):
        # print("Non-float input dtype; assuming GAN normalization [0,1]")
        return 'gan'

    try:
        test01 = np.ascontiguousarray(test_bchw_f32_in_0_1.astype(np.float32))
        test11 = np.ascontiguousarray(test01 * 2.0 - 1.0)

        # Cast to model input dtype
        test01 = test01.astype(input_dtype, copy=False)
        test11 = test11.astype(input_dtype, copy=False)

        out_01 = session.run([output_name], {input_name: test01})[0]
        out_11 = session.run([output_name], {input_name: test11})[0]

        o01 = np.asarray(out_01, dtype=np.float32)
        o11 = np.asarray(out_11, dtype=np.float32)

        min_01, max_01 = float(o01.min()), float(o01.max())
        min_11, max_11 = float(o11.min()), float(o11.max())

        # Heuristic: fed a [-1,1] input, diffusion outputs stay within roughly [-1.5, 1.5]
        # and dip well below zero; fed a [0,1] input, GAN outputs stay near [0, 1].
        if min_11 >= -1.5 and max_11 <= 1.5 and min_11 < -0.5:
            return 'diffusion'
        if min_01 >= -0.1 and max_01 <= 1.2:
            return 'gan'
        return 'diffusion'
    except Exception:
        # Conservative default that works well for diffusion upscalers
        return 'diffusion'

Minimal example (for testing end-to-end)

  • This shows how to plug both snippets into an ONNXRuntime inference loop for a dynamic model.
# example_dynamic_tiled_usage.py
# pip install onnxruntime pillow numpy
import numpy as np
from PIL import Image
import onnxruntime as ort

from tiling_dynamic import process_rgb_tiled_dynamic
from normalization_detection import detect_normalization_diffusion_or_gan

def load_rgb_u8(path: str) -> np.ndarray:
    return np.array(Image.open(path).convert('RGB'), dtype=np.uint8)

def save_rgb_u8(path: str, rgb_0_1_f32: np.ndarray):
    arr = (np.clip(rgb_0_1_f32, 0.0, 1.0) * 255.0 + 0.5).astype(np.uint8)
    Image.fromarray(arr, 'RGB').save(path)

# 1) Create ORT session (CUDA/DML if available, else CPU)
sess = ort.InferenceSession(
    "path/to/your_dynamic_model.onnx",
    providers=[p for p in ["CUDAExecutionProvider", "DmlExecutionProvider", "CPUExecutionProvider"]
               if p in ort.get_available_providers()]
)
inp = sess.get_inputs()[0]
out_meta = sess.get_outputs()[0]
input_name, output_name = inp.name, out_meta.name

# 2) Detect normalization using a small center crop
img = load_rgb_u8("input.jpg")
h, w = img.shape[:2]
cy, cx = max(0, h//2 - 64), max(0, w//2 - 64)
crop = img[cy:cy+128, cx:cx+128]
if crop.size == 0:
    crop = img[:min(128, h), :min(128, w)]
sample = crop.astype(np.float32) / 255.0
sample_bchw = np.transpose(sample[None, ...], (0, 3, 1, 2))
norm_type = detect_normalization_diffusion_or_gan(sess, input_name, output_name, sample_bchw)

# 3) Define an infer function that feeds BCHW to ORT and returns float32 [B,3,Hout,Wout]
def infer_batch_fn(bchw: np.ndarray) -> np.ndarray:
    # ONNX Runtime requires the feed dtype to match the model input, so cast explicitly
    input_dtype = np.float32
    try:
        t = inp.type.lower()
        if "float16" in t:
            input_dtype = np.float16
        elif "float" in t:
            input_dtype = np.float32
    except Exception:
        pass
    bchw = bchw.astype(input_dtype, copy=False)
    out = sess.run([output_name], {input_name: bchw})[0]
    return np.asarray(out, dtype=np.float32)

# 4) Run tiled processing (set scale according to your model, e.g., 4 for 4x upscaler)
result = process_rgb_tiled_dynamic(
    img,
    infer_batch_fn=infer_batch_fn,
    tile_size=256,
    overlap=64,
    scale=4,
    batch_size=2,
    norm_type=norm_type,
)
save_rgb_u8("output.png", result)
print("Saved output.png")

C) sRGB ICC profile discovery and conversion

  • If full ICC management is out of scope, please at least detect and use a reliable system sRGB ICC profile for forced conversion. Otherwise, wide-gamut inputs will look desaturated after processing in an assumed-sRGB pipeline.
  • The function below enumerates common locations for sRGB profiles on Windows/macOS/Linux. It then attempts to open the profile via Pillow’s ImageCms. If nothing is found, we fall back to a built-in sRGB via ImageCms.createProfile("sRGB").

Explanatory notes

  • On Windows, color profiles are typically installed under %WINDIR%\System32\spool\drivers\color.
  • On macOS, system and user profiles live under ColorSync/Profiles.
  • On Linux, distributions often install profiles under /usr/share/color/icc and related subfolders (colord, argyllcms, ghostscript).
  • If you cannot find a system .icc/.icm file, Pillow’s createProfile("sRGB") returns a standard sRGB v2 profile, good enough for embedding and conversions.
  • When embedding ICC into saved files, pass icc_profile=bytes to Pillow’s save(). When converting, use ImageCms.buildTransform with perceptual intent and black-point compensation.
# srgb_profile_finder.py
# Robust discovery of a system sRGB ICC profile, with a built-in fallback.
# Requires: Pillow (PIL.ImageCms)

import os
import io
from typing import List, Optional, Tuple
from PIL import ImageCms, Image

def candidate_srgb_paths() -> List[str]:
    """
    Enumerate likely sRGB ICC/ICM locations across Windows/macOS/Linux.
    We prioritize well-known v2 profiles. v4 profiles (e.g., sRGB_v4_ICC_preference.icc) are listed as backups.
    """
    paths: List[str] = []

    # Windows
    windir = os.environ.get("WINDIR") or os.environ.get("SystemRoot")
    if windir:
        color_dir = os.path.join(windir, "System32", "spool", "drivers", "color")
        paths.extend([
            os.path.join(color_dir, "sRGB Color Space Profile.icm"),
            os.path.join(color_dir, "sRGB IEC61966-2.1.icm"),
            os.path.join(color_dir, "sRGB IEC61966-2.1.icc"),
            os.path.join(color_dir, "IEC61966-2.1.icm"),
            os.path.join(color_dir, "IEC61966-2.1.icc"),
            # v4 as backups
            os.path.join(color_dir, "sRGB_v4_ICC_preference.icc"),
            os.path.join(color_dir, "sRGB_v4_ICC_preference.icm"),
        ])

    # macOS (system, local, and user)
    paths.extend([
        "/System/Library/ColorSync/Profiles/sRGB Profile.icc",
        "/Library/ColorSync/Profiles/sRGB Profile.icc",
        os.path.expanduser("~/Library/ColorSync/Profiles/sRGB Profile.icc"),
    ])

    # Linux and Unix-like
    paths.extend([
        "/usr/share/color/icc/sRGB.icc",
        "/usr/share/color/icc/colord/sRGB.icc",
        "/usr/share/color/icc/argyllcms/sRGB.icc",
        "/usr/local/share/color/icc/sRGB.icc",
        "/usr/share/color/icc/ghostscript/srgb.icc",
    ])

    # De-duplicate while preserving order
    seen = set()
    uniq_paths = []
    for p in paths:
        if p not in seen:
            uniq_paths.append(p)
            seen.add(p)
    return uniq_paths

def load_system_srgb_profile() -> Tuple[Optional[ImageCms.ImageCmsProfile], Optional[bytes], Optional[str]]:
    """
    Try to open a system sRGB ICC/ICM profile from common paths.
    Returns (profile_object, profile_bytes, path_used).
    If not found, returns (fallback_profile, fallback_bytes, None) using ImageCms.createProfile('sRGB').
    """
    for p in candidate_srgb_paths():
        if os.path.exists(p):
            try:
                prof = ImageCms.getOpenProfile(p)
                # Prefer raw file bytes; they’re stable for embedding.
                with open(p, "rb") as f:
                    prof_bytes = f.read()
                return prof, prof_bytes, p
            except Exception:
                # Try next candidate
                pass

    # Fallback: Pillow’s built-in sRGB profile, wrapped in ImageCmsProfile so that
    # tobytes() is available for embedding (createProfile alone returns a core object).
    try:
        prof = ImageCms.ImageCmsProfile(ImageCms.createProfile("sRGB"))
        prof_bytes = None
        try:
            prof_bytes = prof.tobytes()
        except Exception:
            # Not critical; embedding can be skipped or replaced later with a file-based profile
            pass
        return prof, prof_bytes, None
    except Exception:
        # Absolute worst-case: no profile object. The caller should handle None.
        return None, None, None

def convert_image_to_srgb(img: Image.Image, intent: Optional[int] = None) -> Image.Image:
    """
    Convert a PIL image from its embedded ICC (if present) to sRGB using perceptual intent and BPC.
    If no embedded ICC is present, returns the image unchanged (assumed to be sRGB already).
    """
    icc_bytes = img.info.get("icc_profile")
    if not icc_bytes:
        return img  # No embedded profile, likely already sRGB

    try:
        src_prof = ImageCms.getOpenProfile(io.BytesIO(icc_bytes))

        # Load or create sRGB destination
        dst_prof, _, _ = load_system_srgb_profile()
        if dst_prof is None:
            # Fallback if profile discovery failed: generate sRGB
            dst_prof = ImageCms.createProfile("sRGB")

        # Resolve perceptual intent and black-point compensation
        if intent is None:
            try:
                intent = ImageCms.Intent.PERCEPTUAL
            except Exception:
                intent = getattr(ImageCms, "INTENT_PERCEPTUAL", 0)
        try:
            flags = ImageCms.Flags.BLACKPOINTCOMPENSATION
        except Exception:
            flags = getattr(ImageCms, "FLAGS_BLACKPOINTCOMPENSATION", 0)

        # Short-circuit if already sRGB by description
        try:
            desc = ImageCms.getProfileDescription(src_prof).lower()
            if "srgb" in desc:
                return img
        except Exception:
            pass

        transform = ImageCms.buildTransform(
            src_prof, dst_prof, img.mode, img.mode, intent, flags
        )
        out = ImageCms.applyTransform(img, transform, inPlace=False)
        # Remove the old embedded profile; the saver can embed sRGB again if desired
        if 'icc_profile' in out.info:
            del out.info['icc_profile']
        return out
    except Exception:
        # Fail-safe: return original image unmodified
        return img
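
A minimal end-to-end usage sketch for the helpers above (file and image names are illustrative):

# example_srgb_usage.py — convert to sRGB and embed the profile on save
from PIL import Image
from srgb_profile_finder import convert_image_to_srgb, load_system_srgb_profile

img = Image.open("adobergb_input.tif")   # wide-gamut input with embedded ICC
srgb_img = convert_image_to_srgb(img)

# Embed sRGB in the saved file when profile bytes are available
_, prof_bytes, path_used = load_system_srgb_profile()
save_kwargs = {"icc_profile": prof_bytes} if prof_bytes else {}
srgb_img.save("output_srgb.png", **save_kwargs)
print(f"sRGB profile source: {path_used or 'Pillow built-in'}")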

How to test that it works

  • ICC conversion:
    • Prepare AdobeRGB and ProPhoto images with embedded ICC.
    • Run them through a “Convert to sRGB” preprocessing node using convert_image_to_srgb. Verify colors match the original when displayed in color-managed viewers (no desaturation).
    • Save the output with icc_profile=prof_bytes from load_system_srgb_profile(); verify the saved file has sRGB ICC embedded (exiftool, or any ICC inspector).
  • Tiling quality:
    • Process a large image with a 4x model using tile_size=256, overlap=64, batch_size=2.
    • Inspect overlap boundaries at 200–400% zoom. With cosine feathering and weighted normalization, seams should be invisible.
  • Normalization detection:
    • Test at least one known diffusion ONNX upscaler and one GAN/ESRGAN-type model.
    • Ensure auto-detection selects 'diffusion' for the former and 'gan' for the latter by logging the chosen mode.
  • Equivalence check:
    • If your GPU memory allows, compare full-frame vs tiled output with PSNR/SSIM. They should be numerically very close (only minor floating-point differences); a measurement sketch follows below.
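
A sketch of that equivalence measurement (requires scikit-image, which is not a dependency of the snippets above; the array sources are illustrative):

# equivalence_check.py — optional full-frame vs tiled quality comparison
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

# full_frame and tiled: float32 HxWx3 arrays in [0,1] saved from the two code paths
full_frame = np.load("full_frame.npy")
tiled = np.load("tiled.npy")

psnr = peak_signal_noise_ratio(full_frame, tiled, data_range=1.0)
ssim = structural_similarity(full_frame, tiled, channel_axis=2, data_range=1.0)
print(f"PSNR: {psnr:.2f} dB  SSIM: {ssim:.4f}")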

Additional notes

  • The tiling logic is designed for dynamic models, but it also works when you know the model’s static input size (then tile_size should match it). For ChaiNNer, expose tiling controls at the node level, with a toggle to disable tiling if the model and VRAM allow full-frame inference.
  • For best results, ensure overlap is large enough to cover the model’s effective receptive field near tile borders. Typical starting points: overlap ≈ tile_size/4 to tile_size/2.
  • The normalization detection runs once per model/session and can be cached at the node level to avoid repeated warm-up calls; a minimal caching sketch follows.
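
A minimal caching sketch (hypothetical helper names, keyed by the session object):

# cached_detection.py — run normalization detection once per session
from normalization_detection import detect_normalization_diffusion_or_gan

_norm_cache = {}

def cached_norm_type(sess, input_name, output_name, sample_bchw) -> str:
    key = id(sess)
    if key not in _norm_cache:
        _norm_cache[key] = detect_normalization_diffusion_or_gan(
            sess, input_name, output_name, sample_bchw
        )
    return _norm_cache[key]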

Thank you for considering these improvements. This would bring color accuracy for wide-gamut inputs, seam-free ONNX tiling comparable to PyTorch-quality compositing, and seamless support for diffusion-based ONNX upscalers via automatic normalization detection.
