Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make it possible to install with pip #81

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# CI workflow: verify the package installs cleanly via `pip install .`
# and that the packaged `taming` modules import without error.
# NOTE(review): indentation below is restored — the scraped diff had
# stripped all leading whitespace, which is structurally significant in YAML.
name: Python package

on: [push, pull_request]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Single interpreter for now; extend this list to test more versions.
        python-version: [3.9]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install
        run: pip install .
      - name: Test imports work
        run: python scripts/test_import.py
18 changes: 2 additions & 16 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import argparse, os, sys, datetime, glob, importlib
import argparse, os, sys, datetime, glob
from omegaconf import OmegaConf
import numpy as np
from PIL import Image
Expand All @@ -12,15 +12,7 @@
from pytorch_lightning.utilities.distributed import rank_zero_only

from taming.data.utils import custom_collate


def get_obj_from_str(string, reload=False):
    """Resolve a dotted import path such as ``"pkg.mod.Name"`` to the object it names.

    The text before the last dot is treated as the module path; the final
    component is looked up as an attribute of that module. When *reload* is
    true, the module is re-imported first so that source changes made after
    the initial import are picked up.
    """
    module_path, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_path))
    return getattr(importlib.import_module(module_path, package=None), attr_name)

from taming.util import instantiate_from_config

def get_parser(**parser_kwargs):
def str2bool(v):
Expand Down Expand Up @@ -113,12 +105,6 @@ def nondefault_trainer_args(opt):
return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k))


def instantiate_from_config(config):
    """Build the object described by *config*.

    Args:
        config: Mapping with a ``target`` key holding the dotted import path
            of a class/callable; the optional ``params`` mapping is expanded
            as keyword arguments to the call.

    Returns:
        The instantiated object.

    Raises:
        KeyError: if ``target`` is missing from *config*.
    """
    # Idiom fix: `x not in y` instead of `not x in y`.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


class WrappedDataset(Dataset):
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
def __init__(self, dataset):
Expand Down
2 changes: 1 addition & 1 deletion scripts/make_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from omegaconf import OmegaConf
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DataModuleFromConfig is unused in the scripts.

from taming.util import instantiate_from_config
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate
from tqdm import trange
Expand Down
2 changes: 1 addition & 1 deletion scripts/sample_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import streamlit as st
from streamlit import caching
from PIL import Image
from main import instantiate_from_config, DataModuleFromConfig
from taming.util import instantiate_from_config
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

Expand Down
2 changes: 1 addition & 1 deletion scripts/sample_fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from tqdm import tqdm, trange
from einops import repeat

from main import instantiate_from_config
from taming.modules.transformer.mingpt import sample_with_past
from taming.util import instantiate_from_config


rescale = lambda x: (x + 1.) / 2.
Expand Down
3 changes: 3 additions & 0 deletions scripts/test_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# A simple test to ensure that packages import correctly
# Run by CI after `pip install .`; a failure here means the built package
# did not include the `taming` model modules (or a dependency is missing).
import taming.models.cond_transformer
import taming.models.vqgan
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# NOTE(review): this span is a merged diff rendering — it shows both the
# removed lines (find_packages) and the added lines (find_namespace_packages),
# so it is not valid Python as displayed. The resulting setup.py keeps only
# the added lines. Indentation was also stripped by the page scrape.
from setuptools import setup, find_packages
from setuptools import setup, find_namespace_packages

setup(
name='taming-transformers',
version='0.0.1',
description='Taming Transformers for High-Resolution Image Synthesis',
packages=find_packages(),
# Namespace-package discovery, restricted to the `taming` tree, so that
# `pip install .` ships all taming.* submodules.
packages=find_namespace_packages(include=["taming", "taming.*"]),
# NOTE(review): omegaconf and PIL appear in the repo's imports but are not
# listed here — presumably intentional for this PR; verify before release.
install_requires=[
'torch',
'numpy',
'tqdm',
'pytorch-lightning',
'einops'
],
)
2 changes: 1 addition & 1 deletion taming/models/cond_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config
from taming.modules.util import SOSProvider
from taming.util import instantiate_from_config


def disabled_train(self, mode=True):
Expand Down
7 changes: 3 additions & 4 deletions taming/models/vqgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import torch.nn.functional as F
import pytorch_lightning as pl

from main import instantiate_from_config

from taming.modules.diffusionmodules.model import Encoder, Decoder
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from taming.modules.vqvae.quantize import GumbelQuantize
from taming.modules.vqvae.quantize import EMAVectorQuantizer
from taming.util import instantiate_from_config


class VQModel(pl.LightningModule):
def __init__(self,
Expand Down Expand Up @@ -401,4 +400,4 @@ def configure_optimizers(self):
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
return [opt_ae, opt_disc], []
16 changes: 16 additions & 0 deletions taming/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import os, hashlib
import requests
from tqdm import tqdm
Expand Down Expand Up @@ -142,6 +143,20 @@ def retrieve(
return list_or_dict, success


def get_obj_from_str(string, reload=False):
    """Resolve a dotted import path like ``"package.module.Class"`` to the object.

    Args:
        string: Dotted path; everything before the last dot is the module,
            the final component is the attribute to fetch from it.
        reload: When True, re-import the module first so that source changes
            made after the initial import are picked up.

    Returns:
        The attribute (class, function, ...) named by *string*.

    Raises:
        ValueError: if *string* contains no dot.
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no such attribute.
    """
    module, cls = string.rsplit(".", 1)
    # Import once and reuse the module object; the original called
    # import_module a second time after reload, which is redundant.
    module_imp = importlib.import_module(module)
    if reload:
        # reload() returns the (re-executed) module object.
        module_imp = importlib.reload(module_imp)
    return getattr(module_imp, cls)


def instantiate_from_config(config):
    """Build the object described by *config*.

    Args:
        config: Mapping with a ``target`` key holding the dotted import path
            of a class/callable; the optional ``params`` mapping is expanded
            as keyword arguments to the call.

    Returns:
        The instantiated object.

    Raises:
        KeyError: if ``target`` is missing from *config*.
    """
    # Idiom fix: `x not in y` instead of `not x in y`.
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


if __name__ == "__main__":
config = {"keya": "a",
"keyb": "b",
Expand All @@ -155,3 +170,4 @@ def retrieve(
print(config)
retrieve(config, "keya")