Commit d017ac6
1 parent 943f111
Showing 1 changed file with 398 additions and 0 deletions.
@@ -0,0 +1,398 @@ | ||
{ | ||
"nbformat": 4, | ||
"nbformat_minor": 0, | ||
"metadata": { | ||
"accelerator": "GPU", | ||
"colab": { | ||
"name": "Generate images (styleGAN2-ada-pythorch)", | ||
"provenance": [], | ||
"collapsed_sections": [], | ||
"toc_visible": true | ||
}, | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"name": "python3" | ||
} | ||
}, | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "aqIgy6kur1nB" | ||
}, | ||
"source": [ | ||
"#@title Installs and useful functions {display-mode: \"form\"}\n", | ||
"!git clone https://github.com/NVlabs/stylegan2-ada-pytorch/\n", | ||
"!pip install ninja\n", | ||
"\n", | ||
"#@title # Useful utility functions...\n", | ||
"%cd /content/stylegan2-ada-pytorch/\n", | ||
"import os\n", | ||
"import re\n", | ||
"from typing import List, Optional\n", | ||
"\n", | ||
"import click\n", | ||
"import dnnlib\n", | ||
"import numpy as np\n", | ||
"import PIL.Image\n", | ||
"import torch\n", | ||
"from io import BytesIO\n", | ||
"import legacy\n", | ||
"\n", | ||
"import argparse\n", | ||
"import numpy as np\n", | ||
"import PIL.Image\n", | ||
"import re\n", | ||
"import sys\n", | ||
"from io import BytesIO\n", | ||
"import IPython.display\n", | ||
"import numpy as np\n", | ||
"from math import ceil\n", | ||
"from PIL import Image, ImageDraw\n", | ||
"import imageio\n", | ||
"import os\n", | ||
"import pickle\n", | ||
"from google.colab import files\n", | ||
"\n", | ||
"\n", | ||
"from IPython.display import clear_output \n", | ||
"\n", | ||
"def generate_images(zs, truncation_psi):\n", | ||
" # Gs_kwargs = dnnlib.EasyDict()\n", | ||
" # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n", | ||
" # Gs_kwargs.randomize_noise = False\n", | ||
" # if not isinstance(truncation_psi, list):\n", | ||
" # truncation_psi = [truncation_psi] * len(zs)\n", | ||
" \n", | ||
" imgs = []\n", | ||
" # for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = \"Generating images\"):\n", | ||
" # Gs_kwargs.truncation_psi = truncation_psi[z_idx]\n", | ||
" # noise_rnd = np.random.RandomState(1) # fix noise\n", | ||
" # tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]\n", | ||
" # images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]\n", | ||
" # imgs.append(PIL.Image.fromarray(images[0], 'RGB'))\n", | ||
"\n", | ||
" # for seed_idx, seed in enumerate(seeds):\n", | ||
" for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = \"Generating images\"):\n", | ||
" # print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))\n", | ||
" # img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)\n", | ||
" img = G(z, None, truncation_psi=truncation_psi, noise_mode=noise_mode)\n", | ||
" img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n", | ||
" imgs.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))\n", | ||
" # PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')\n", | ||
" return imgs\n", | ||
"\n", | ||
"def generate_zs_from_seeds(seeds):\n", | ||
" zs = []\n", | ||
" for seed_idx, seed in enumerate(seeds):\n", | ||
" rnd = np.random.RandomState(seed)\n", | ||
" # z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]\n", | ||
" z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)\n", | ||
" zs.append(z)\n", | ||
" return zs\n", | ||
"\n", | ||
"# Generates a list of images, based on a list of seed for latent vectors (Z), and a list (or a single constant) of truncation_psi's.\n", | ||
"def generate_images_from_seeds(seeds, truncation_psi):\n", | ||
" return generate_images(generate_zs_from_seeds(seeds), truncation_psi)\n", | ||
"\n", | ||
"def saveImgs(imgs, location):\n", | ||
" for idx, img in log_progress(enumerate(imgs), size = len(imgs), name=\"Saving images\"):\n", | ||
" file = location+ str(idx) + \".png\"\n", | ||
" img.save(file)\n", | ||
"\n", | ||
"def imshow(a, format='png', jpeg_fallback=True, save=False, id=None):\n", | ||
" a = np.asarray(a, dtype=np.uint8)\n", | ||
" str_file = BytesIO()\n", | ||
" PIL.Image.fromarray(a).save(str_file, format)\n", | ||
" #####\n", | ||
" if save:\n", | ||
" PIL.Image.fromarray(a).save(str(id[0]) + '.png', format)\n", | ||
" ####\n", | ||
" im_data = str_file.getvalue()\n", | ||
" try:\n", | ||
" disp = IPython.display.display(IPython.display.Image(im_data))\n", | ||
" except IOError:\n", | ||
" if jpeg_fallback and format != 'jpeg':\n", | ||
" print ('Warning: image was too large to display in format \"{}\"; '\n", | ||
" 'trying jpeg instead.').format(format)\n", | ||
" return imshow(a, format='jpeg')\n", | ||
" else:\n", | ||
" raise\n", | ||
" return disp\n", | ||
"\n", | ||
"def showarray(a, fmt='png'):\n", | ||
" a = np.uint8(a)\n", | ||
" f = StringIO()\n", | ||
" PIL.Image.fromarray(a).save(f, fmt)\n", | ||
" IPython.display.display(IPython.display.Image(data=f.getvalue()))\n", | ||
"\n", | ||
" \n", | ||
"def clamp(x, minimum, maximum):\n", | ||
" return max(minimum, min(x, maximum))\n", | ||
" \n", | ||
"def drawLatent(image,latents,x,y,x2,y2, color=(255,0,0,100)):\n", | ||
" buffer = PIL.Image.new('RGBA', image.size, (0,0,0,0))\n", | ||
" \n", | ||
" draw = ImageDraw.Draw(buffer)\n", | ||
" cy = (y+y2)/2\n", | ||
" draw.rectangle([x,y,x2,y2],fill=(255,255,255,180), outline=(0,0,0,180))\n", | ||
" for i in range(len(latents)):\n", | ||
" mx = x + (x2-x)*(float(i)/len(latents))\n", | ||
" h = (y2-y)*latents[i]*0.1\n", | ||
" h = clamp(h,cy-y2,y2-cy)\n", | ||
" draw.line((mx,cy,mx,cy+h),fill=color)\n", | ||
" return PIL.Image.alpha_composite(image,buffer)\n", | ||
" \n", | ||
" \n", | ||
"def createImageGrid(images, scale=0.25, rows=1):\n", | ||
" w,h = images[0].size\n", | ||
" w = int(w*scale)\n", | ||
" h = int(h*scale)\n", | ||
" height = rows*h\n", | ||
" cols = ceil(len(images) / rows)\n", | ||
" width = cols*w\n", | ||
" canvas = PIL.Image.new('RGBA', (width,height), 'white')\n", | ||
" for i,img in enumerate(images):\n", | ||
" img = img.resize((w,h), PIL.Image.ANTIALIAS)\n", | ||
" canvas.paste(img, (w*(i % cols), h*(i // cols))) \n", | ||
" return canvas\n", | ||
"\n", | ||
"def convertZtoW(latent, truncation_psi=0.7, truncation_cutoff=9):\n", | ||
" dlatent = Gs.components.mapping.run(latent, None) # [seed, layer, component]\n", | ||
" dlatent_avg = Gs.get_var('dlatent_avg') # [component]\n", | ||
" for i in range(truncation_cutoff):\n", | ||
" dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg\n", | ||
" \n", | ||
" return dlatent\n", | ||
"\n", | ||
"def interpolate(zs, steps):\n", | ||
" out = []\n", | ||
" for i in range(len(zs)-1):\n", | ||
" for index in range(steps):\n", | ||
" fraction = index/float(steps) \n", | ||
" out.append(zs[i+1]*fraction + zs[i]*(1-fraction))\n", | ||
" return out\n", | ||
"\n", | ||
"# Taken from https://github.com/alexanderkuk/log-progress\n", | ||
"def log_progress(sequence, every=1, size=None, name='Items'):\n", | ||
" from ipywidgets import IntProgress, HTML, VBox\n", | ||
" from IPython.display import display\n", | ||
"\n", | ||
" is_iterator = False\n", | ||
" if size is None:\n", | ||
" try:\n", | ||
" size = len(sequence)\n", | ||
" except TypeError:\n", | ||
" is_iterator = True\n", | ||
" if size is not None:\n", | ||
" if every is None:\n", | ||
" if size <= 200:\n", | ||
" every = 1\n", | ||
" else:\n", | ||
" every = int(size / 200) # every 0.5%\n", | ||
" else:\n", | ||
" assert every is not None, 'sequence is iterator, set every'\n", | ||
"\n", | ||
" if is_iterator:\n", | ||
" progress = IntProgress(min=0, max=1, value=1)\n", | ||
" progress.bar_style = 'info'\n", | ||
" else:\n", | ||
" progress = IntProgress(min=0, max=size, value=0)\n", | ||
" label = HTML()\n", | ||
" box = VBox(children=[label, progress])\n", | ||
" display(box)\n", | ||
"\n", | ||
" index = 0\n", | ||
" try:\n", | ||
" for index, record in enumerate(sequence, 1):\n", | ||
" if index == 1 or index % every == 0:\n", | ||
" if is_iterator:\n", | ||
" label.value = '{name}: {index} / ?'.format(\n", | ||
" name=name,\n", | ||
" index=index\n", | ||
" )\n", | ||
" else:\n", | ||
" progress.value = index\n", | ||
" label.value = u'{name}: {index} / {size}'.format(\n", | ||
" name=name,\n", | ||
" index=index,\n", | ||
" size=size\n", | ||
" )\n", | ||
" yield record\n", | ||
" except:\n", | ||
" progress.bar_style = 'danger'\n", | ||
" raise\n", | ||
" else:\n", | ||
" progress.bar_style = 'success'\n", | ||
" progress.value = index\n", | ||
" label.value = \"{name}: {index}\".format(\n", | ||
" name=name,\n", | ||
" index=str(index or '?')\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"clear_output()\n", | ||
"\n", | ||
"print('Done!')" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "LRteqFuhtINC" | ||
}, | ||
"source": [ | ||
"#@title Download and load neural network {display-mode: \"form\"}\n", | ||
"%cd /content/\n", | ||
"import os\n", | ||
"\n", | ||
"#@markdown Paste neural netwok checkpoint google drive url:\n", | ||
"url = '' #@param {type: \"string\"}\n", | ||
"\n", | ||
"\n", | ||
"os.environ['ID']= url.split('/')[-2]\n", | ||
"\n", | ||
"!gdown --id $ID -O network.pkl\n", | ||
"\n", | ||
"%cd /content/stylegan2-ada-pytorch/\n", | ||
"\n", | ||
"outdir = '/content/out/'\n", | ||
"seeds = [1,2,3]\n", | ||
"truncation_psi = 0.7\n", | ||
"noise_mode = 'const' # ['const', 'random', 'none']\n", | ||
"network_pkl = '/content/network.pkl'\n", | ||
"\n", | ||
"print('Loading networks from \"%s\"...' % network_pkl)\n", | ||
"device = torch.device('cuda')\n", | ||
"with dnnlib.util.open_url(network_pkl) as f:\n", | ||
" G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore\n", | ||
"clear_output()\n", | ||
"print('Done!')" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "PfUS1eRQCA5c" | ||
}, | ||
"source": [ | ||
"#@title Generate a single image{display-mode: \"form\"}\n", | ||
"#@markdown Use -1 for random seed or type a specific seed (from 0 up to 2^32-1):\n", | ||
"\n", | ||
"rand = -1 #@param {type: \"number\"}\n", | ||
"\n", | ||
"if rand == -1:\n", | ||
" rand = np.random.randint(4294967295, size=1)\n", | ||
"else:\n", | ||
" rand = [rand]\n", | ||
"print(rand)\n", | ||
"\n", | ||
"imshow(generate_images_from_seeds(rand, truncation_psi=0.7)[0])" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "KALc4g9xDu0G" | ||
}, | ||
"source": [ | ||
"#@title Generate grid of images{display-mode: \"form\"}\n", | ||
"\n", | ||
"size = 12 #@param {type: \"number\"}\n", | ||
"\n", | ||
"seeds = np.random.randint((2**32 - 1), size=size)\n", | ||
"\n", | ||
"rows = 5 #@param {type: \"number\"}\n", | ||
"image_size = 0.3 #@param {type: \"number\"}\n", | ||
"\n", | ||
"\n", | ||
"# np.random.shuffle(seeds)\n", | ||
"\n", | ||
"\n", | ||
"def chunks(lst, n):\n", | ||
" \"\"\"Yield successive n-sized chunks from lst.\"\"\"\n", | ||
" for i in range(0, len(lst), n):\n", | ||
" yield lst[i:i + n]\n", | ||
"\n", | ||
"for _ in chunks(list(seeds), ceil(size/rows)):\n", | ||
" print(_)\n", | ||
"print()\n", | ||
"imshow(createImageGrid(generate_images_from_seeds(seeds, 0.7), image_size , rows))" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "iLQu1ZKWFUMA" | ||
}, | ||
"source": [ | ||
"#@title Generate interpolation video {display-mode: \"form\"}\n", | ||
"\n", | ||
"video_output_path = '/content/out/'\n", | ||
"video_name = 'interpolation_movie.mp4'\n", | ||
"movie_name = video_output_path + video_name\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"try:\n", | ||
" os.mkdir(video_output_path)\n", | ||
"except:\n", | ||
" pass\n", | ||
"\n", | ||
"size = 10 #@param {type: \"number\"}\n", | ||
"seeds = list(np.random.randint((2**32) - 1, size=size))\n", | ||
"\n", | ||
"\n", | ||
"print(seeds)\n", | ||
"seeds = seeds + [seeds[0]]\n", | ||
"zs = generate_zs_from_seeds(seeds)\n", | ||
"\n", | ||
"number_of_steps = 30 #@param {type: \"number\"}\n", | ||
"trunc_psi = 0.7\n", | ||
"imgs = generate_images(interpolate(zs,number_of_steps), trunc_psi)\n", | ||
"\n", | ||
"with imageio.get_writer(movie_name, mode='I') as writer:\n", | ||
" for image in log_progress(list(imgs), name = \"Creating animation\"):\n", | ||
" writer.append_data(np.array(image))\n", | ||
"\n", | ||
"download_video = True #@param {type:\"boolean\"}\n", | ||
"if download_video:\n", | ||
" files.download(movie_name) \n", | ||
"\n" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"metadata": { | ||
"id": "bJBUfTbDFgD9" | ||
}, | ||
"source": [ | ||
"#@title View video in colab {display-mode: \"form\"}\n", | ||
"\n", | ||
"from IPython.display import HTML\n", | ||
"from base64 import b64encode\n", | ||
"mp4 = open(movie_name,'rb').read()\n", | ||
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", | ||
"HTML(\"\"\"\n", | ||
"<video width=700 controls>\n", | ||
" <source src=\"%s\" type=\"video/mp4\">\n", | ||
"</video>\n", | ||
"\"\"\" % data_url)" | ||
], | ||
"execution_count": null, | ||
"outputs": [] | ||
} | ||
] | ||
} |