Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
dobrosketchkun authored Mar 20, 2021
1 parent 943f111 commit d017ac6
Showing 1 changed file with 398 additions and 0 deletions.
398 changes: 398 additions & 0 deletions files/Generate_images_(styleGAN2_ada_pythorch).ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,398 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "Generate images (styleGAN2-ada-pythorch)",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "aqIgy6kur1nB"
},
"source": [
"#@title Installs and useful functions {display-mode: \"form\"}\n",
"!git clone https://github.com/NVlabs/stylegan2-ada-pytorch/\n",
"!pip install ninja\n",
"\n",
"#@title # Useful utility functions...\n",
"%cd /content/stylegan2-ada-pytorch/\n",
"import os\n",
"import re\n",
"from typing import List, Optional\n",
"\n",
"import click\n",
"import dnnlib\n",
"import numpy as np\n",
"import PIL.Image\n",
"import torch\n",
"from io import BytesIO\n",
"import legacy\n",
"\n",
"import argparse\n",
"import numpy as np\n",
"import PIL.Image\n",
"import re\n",
"import sys\n",
"from io import BytesIO\n",
"import IPython.display\n",
"import numpy as np\n",
"from math import ceil\n",
"from PIL import Image, ImageDraw\n",
"import imageio\n",
"import os\n",
"import pickle\n",
"from google.colab import files\n",
"\n",
"\n",
"from IPython.display import clear_output \n",
"\n",
"def generate_images(zs, truncation_psi):\n",
" # Gs_kwargs = dnnlib.EasyDict()\n",
" # Gs_kwargs.output_transform = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)\n",
" # Gs_kwargs.randomize_noise = False\n",
" # if not isinstance(truncation_psi, list):\n",
" # truncation_psi = [truncation_psi] * len(zs)\n",
" \n",
" imgs = []\n",
" # for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = \"Generating images\"):\n",
" # Gs_kwargs.truncation_psi = truncation_psi[z_idx]\n",
" # noise_rnd = np.random.RandomState(1) # fix noise\n",
" # tflib.set_vars({var: noise_rnd.randn(*var.shape.as_list()) for var in noise_vars}) # [height, width]\n",
" # images = Gs.run(z, None, **Gs_kwargs) # [minibatch, height, width, channel]\n",
" # imgs.append(PIL.Image.fromarray(images[0], 'RGB'))\n",
"\n",
" # for seed_idx, seed in enumerate(seeds):\n",
" for z_idx, z in log_progress(enumerate(zs), size = len(zs), name = \"Generating images\"):\n",
" # print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))\n",
" # img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)\n",
" img = G(z, None, truncation_psi=truncation_psi, noise_mode=noise_mode)\n",
" img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
" imgs.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))\n",
" # PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png')\n",
" return imgs\n",
"\n",
"def generate_zs_from_seeds(seeds):\n",
" zs = []\n",
" for seed_idx, seed in enumerate(seeds):\n",
" rnd = np.random.RandomState(seed)\n",
" # z = rnd.randn(1, *Gs.input_shape[1:]) # [minibatch, component]\n",
" z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device)\n",
" zs.append(z)\n",
" return zs\n",
"\n",
"# Generates a list of images, based on a list of seed for latent vectors (Z), and a list (or a single constant) of truncation_psi's.\n",
"def generate_images_from_seeds(seeds, truncation_psi):\n",
" return generate_images(generate_zs_from_seeds(seeds), truncation_psi)\n",
"\n",
"def saveImgs(imgs, location):\n",
" for idx, img in log_progress(enumerate(imgs), size = len(imgs), name=\"Saving images\"):\n",
" file = location+ str(idx) + \".png\"\n",
" img.save(file)\n",
"\n",
"def imshow(a, format='png', jpeg_fallback=True, save=False, id=None):\n",
" a = np.asarray(a, dtype=np.uint8)\n",
" str_file = BytesIO()\n",
" PIL.Image.fromarray(a).save(str_file, format)\n",
" #####\n",
" if save:\n",
" PIL.Image.fromarray(a).save(str(id[0]) + '.png', format)\n",
" ####\n",
" im_data = str_file.getvalue()\n",
" try:\n",
" disp = IPython.display.display(IPython.display.Image(im_data))\n",
" except IOError:\n",
" if jpeg_fallback and format != 'jpeg':\n",
" print ('Warning: image was too large to display in format \"{}\"; '\n",
" 'trying jpeg instead.').format(format)\n",
" return imshow(a, format='jpeg')\n",
" else:\n",
" raise\n",
" return disp\n",
"\n",
"def showarray(a, fmt='png'):\n",
" a = np.uint8(a)\n",
" f = StringIO()\n",
" PIL.Image.fromarray(a).save(f, fmt)\n",
" IPython.display.display(IPython.display.Image(data=f.getvalue()))\n",
"\n",
" \n",
"def clamp(x, minimum, maximum):\n",
" return max(minimum, min(x, maximum))\n",
" \n",
"def drawLatent(image,latents,x,y,x2,y2, color=(255,0,0,100)):\n",
" buffer = PIL.Image.new('RGBA', image.size, (0,0,0,0))\n",
" \n",
" draw = ImageDraw.Draw(buffer)\n",
" cy = (y+y2)/2\n",
" draw.rectangle([x,y,x2,y2],fill=(255,255,255,180), outline=(0,0,0,180))\n",
" for i in range(len(latents)):\n",
" mx = x + (x2-x)*(float(i)/len(latents))\n",
" h = (y2-y)*latents[i]*0.1\n",
" h = clamp(h,cy-y2,y2-cy)\n",
" draw.line((mx,cy,mx,cy+h),fill=color)\n",
" return PIL.Image.alpha_composite(image,buffer)\n",
" \n",
" \n",
"def createImageGrid(images, scale=0.25, rows=1):\n",
" w,h = images[0].size\n",
" w = int(w*scale)\n",
" h = int(h*scale)\n",
" height = rows*h\n",
" cols = ceil(len(images) / rows)\n",
" width = cols*w\n",
" canvas = PIL.Image.new('RGBA', (width,height), 'white')\n",
" for i,img in enumerate(images):\n",
" img = img.resize((w,h), PIL.Image.ANTIALIAS)\n",
" canvas.paste(img, (w*(i % cols), h*(i // cols))) \n",
" return canvas\n",
"\n",
"def convertZtoW(latent, truncation_psi=0.7, truncation_cutoff=9):\n",
" dlatent = Gs.components.mapping.run(latent, None) # [seed, layer, component]\n",
" dlatent_avg = Gs.get_var('dlatent_avg') # [component]\n",
" for i in range(truncation_cutoff):\n",
" dlatent[0][i] = (dlatent[0][i]-dlatent_avg)*truncation_psi + dlatent_avg\n",
" \n",
" return dlatent\n",
"\n",
"def interpolate(zs, steps):\n",
" out = []\n",
" for i in range(len(zs)-1):\n",
" for index in range(steps):\n",
" fraction = index/float(steps) \n",
" out.append(zs[i+1]*fraction + zs[i]*(1-fraction))\n",
" return out\n",
"\n",
"# Taken from https://github.com/alexanderkuk/log-progress\n",
"def log_progress(sequence, every=1, size=None, name='Items'):\n",
" from ipywidgets import IntProgress, HTML, VBox\n",
" from IPython.display import display\n",
"\n",
" is_iterator = False\n",
" if size is None:\n",
" try:\n",
" size = len(sequence)\n",
" except TypeError:\n",
" is_iterator = True\n",
" if size is not None:\n",
" if every is None:\n",
" if size <= 200:\n",
" every = 1\n",
" else:\n",
" every = int(size / 200) # every 0.5%\n",
" else:\n",
" assert every is not None, 'sequence is iterator, set every'\n",
"\n",
" if is_iterator:\n",
" progress = IntProgress(min=0, max=1, value=1)\n",
" progress.bar_style = 'info'\n",
" else:\n",
" progress = IntProgress(min=0, max=size, value=0)\n",
" label = HTML()\n",
" box = VBox(children=[label, progress])\n",
" display(box)\n",
"\n",
" index = 0\n",
" try:\n",
" for index, record in enumerate(sequence, 1):\n",
" if index == 1 or index % every == 0:\n",
" if is_iterator:\n",
" label.value = '{name}: {index} / ?'.format(\n",
" name=name,\n",
" index=index\n",
" )\n",
" else:\n",
" progress.value = index\n",
" label.value = u'{name}: {index} / {size}'.format(\n",
" name=name,\n",
" index=index,\n",
" size=size\n",
" )\n",
" yield record\n",
" except:\n",
" progress.bar_style = 'danger'\n",
" raise\n",
" else:\n",
" progress.bar_style = 'success'\n",
" progress.value = index\n",
" label.value = \"{name}: {index}\".format(\n",
" name=name,\n",
" index=str(index or '?')\n",
" )\n",
"\n",
"\n",
"\n",
"\n",
"clear_output()\n",
"\n",
"print('Done!')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LRteqFuhtINC"
},
"source": [
"#@title Download and load neural network {display-mode: \"form\"}\n",
"%cd /content/\n",
"import os\n",
"\n",
"#@markdown Paste neural netwok checkpoint google drive url:\n",
"url = '' #@param {type: \"string\"}\n",
"\n",
"\n",
"os.environ['ID']= url.split('/')[-2]\n",
"\n",
"!gdown --id $ID -O network.pkl\n",
"\n",
"%cd /content/stylegan2-ada-pytorch/\n",
"\n",
"outdir = '/content/out/'\n",
"seeds = [1,2,3]\n",
"truncation_psi = 0.7\n",
"noise_mode = 'const' # ['const', 'random', 'none']\n",
"network_pkl = '/content/network.pkl'\n",
"\n",
"print('Loading networks from \"%s\"...' % network_pkl)\n",
"device = torch.device('cuda')\n",
"with dnnlib.util.open_url(network_pkl) as f:\n",
" G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore\n",
"clear_output()\n",
"print('Done!')"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PfUS1eRQCA5c"
},
"source": [
"#@title Generate a single image{display-mode: \"form\"}\n",
"#@markdown Use -1 for random seed or type a specific seed (from 0 up to 2^32-1):\n",
"\n",
"rand = -1 #@param {type: \"number\"}\n",
"\n",
"if rand == -1:\n",
" rand = np.random.randint(4294967295, size=1)\n",
"else:\n",
" rand = [rand]\n",
"print(rand)\n",
"\n",
"imshow(generate_images_from_seeds(rand, truncation_psi=0.7)[0])"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "KALc4g9xDu0G"
},
"source": [
"#@title Generate grid of images{display-mode: \"form\"}\n",
"\n",
"size = 12 #@param {type: \"number\"}\n",
"\n",
"seeds = np.random.randint((2**32 - 1), size=size)\n",
"\n",
"rows = 5 #@param {type: \"number\"}\n",
"image_size = 0.3 #@param {type: \"number\"}\n",
"\n",
"\n",
"# np.random.shuffle(seeds)\n",
"\n",
"\n",
"def chunks(lst, n):\n",
" \"\"\"Yield successive n-sized chunks from lst.\"\"\"\n",
" for i in range(0, len(lst), n):\n",
" yield lst[i:i + n]\n",
"\n",
"for _ in chunks(list(seeds), ceil(size/rows)):\n",
" print(_)\n",
"print()\n",
"imshow(createImageGrid(generate_images_from_seeds(seeds, 0.7), image_size , rows))"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "iLQu1ZKWFUMA"
},
"source": [
"#@title Generate interpolation video {display-mode: \"form\"}\n",
"\n",
"video_output_path = '/content/out/'\n",
"video_name = 'interpolation_movie.mp4'\n",
"movie_name = video_output_path + video_name\n",
"\n",
"\n",
"\n",
"try:\n",
" os.mkdir(video_output_path)\n",
"except:\n",
" pass\n",
"\n",
"size = 10 #@param {type: \"number\"}\n",
"seeds = list(np.random.randint((2**32) - 1, size=size))\n",
"\n",
"\n",
"print(seeds)\n",
"seeds = seeds + [seeds[0]]\n",
"zs = generate_zs_from_seeds(seeds)\n",
"\n",
"number_of_steps = 30 #@param {type: \"number\"}\n",
"trunc_psi = 0.7\n",
"imgs = generate_images(interpolate(zs,number_of_steps), trunc_psi)\n",
"\n",
"with imageio.get_writer(movie_name, mode='I') as writer:\n",
" for image in log_progress(list(imgs), name = \"Creating animation\"):\n",
" writer.append_data(np.array(image))\n",
"\n",
"download_video = True #@param {type:\"boolean\"}\n",
"if download_video:\n",
" files.download(movie_name) \n",
"\n"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "bJBUfTbDFgD9"
},
"source": [
"#@title View video in colab {display-mode: \"form\"}\n",
"\n",
"from IPython.display import HTML\n",
"from base64 import b64encode\n",
"mp4 = open(movie_name,'rb').read()\n",
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"HTML(\"\"\"\n",
"<video width=700 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
"</video>\n",
"\"\"\" % data_url)"
],
"execution_count": null,
"outputs": []
}
]
}

0 comments on commit d017ac6

Please sign in to comment.