diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 00000000..0ebfbb6f --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,22 @@ +name: Python package + +on: [push, pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.9] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install + run: pip install . + - name: Test imports work + run: python scripts/test_import.py diff --git a/main.py b/main.py index 3d83cb21..69443126 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -import argparse, os, sys, datetime, glob, importlib +import argparse, os, sys, datetime, glob from omegaconf import OmegaConf import numpy as np from PIL import Image @@ -12,15 +12,7 @@ from pytorch_lightning.utilities.distributed import rank_zero_only from taming.data.utils import custom_collate - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) - +from taming.util import instantiate_from_config def get_parser(**parser_kwargs): def str2bool(v): @@ -113,12 +105,6 @@ def nondefault_trainer_args(opt): return sorted(k for k in vars(args) if getattr(opt, k) != getattr(args, k)) -def instantiate_from_config(config): - if not "target" in config: - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - class WrappedDataset(Dataset): """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" def __init__(self, dataset): diff --git a/scripts/make_samples.py b/scripts/make_samples.py index 5e4d6995..ee3de384 100644 --- a/scripts/make_samples.py +++ b/scripts/make_samples.py @@ -3,7 +3,7 @@ import numpy as np from omegaconf import OmegaConf from PIL import Image -from main import instantiate_from_config, DataModuleFromConfig +from taming.util import instantiate_from_config from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate from tqdm import trange diff --git a/scripts/sample_conditional.py b/scripts/sample_conditional.py index 174cf2af..ac13c03c 100644 --- a/scripts/sample_conditional.py +++ b/scripts/sample_conditional.py @@ -5,7 +5,7 @@ import streamlit as st from streamlit import caching from PIL import Image -from main import instantiate_from_config, DataModuleFromConfig +from taming.util import instantiate_from_config from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate diff --git a/scripts/sample_fast.py b/scripts/sample_fast.py index ff546c7d..12ba50ec 100644 --- a/scripts/sample_fast.py +++ b/scripts/sample_fast.py @@ -7,8 +7,8 @@ from tqdm import tqdm, trange from einops import repeat -from main import instantiate_from_config from taming.modules.transformer.mingpt import sample_with_past +from taming.util import instantiate_from_config rescale = lambda x: (x + 1.) / 2. diff --git a/scripts/test_import.py b/scripts/test_import.py new file mode 100644 index 00000000..aef93bef --- /dev/null +++ b/scripts/test_import.py @@ -0,0 +1,3 @@ +# A simple test to ensure that packages import correctly +import taming.models.cond_transformer +import taming.models.vqgan diff --git a/setup.py b/setup.py index a220d12b..3b213ff3 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,15 @@ -from setuptools import setup, find_packages +from setuptools import setup, find_namespace_packages setup( name='taming-transformers', version='0.0.1', description='Taming Transformers for High-Resolution Image Synthesis', - packages=find_packages(), + packages=find_namespace_packages(include=["taming", "taming.*"]), install_requires=[ 'torch', 'numpy', 'tqdm', + 'pytorch-lightning', + 'einops' ], ) diff --git a/taming/models/cond_transformer.py b/taming/models/cond_transformer.py index e4c63730..806c1732 100644 --- a/taming/models/cond_transformer.py +++ b/taming/models/cond_transformer.py @@ -3,8 +3,8 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config from taming.modules.util import SOSProvider +from taming.util import instantiate_from_config def disabled_train(self, mode=True): diff --git a/taming/models/vqgan.py b/taming/models/vqgan.py index a6950baa..7d69105b 100644 --- a/taming/models/vqgan.py +++ b/taming/models/vqgan.py @@ -2,12 +2,11 @@ import torch.nn.functional as F import pytorch_lightning as pl -from main import instantiate_from_config - from taming.modules.diffusionmodules.model import Encoder, Decoder from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer from taming.modules.vqvae.quantize import GumbelQuantize -from taming.modules.vqvae.quantize import EMAVectorQuantizer +from taming.util import instantiate_from_config + class VQModel(pl.LightningModule): def __init__(self, @@ -401,4 +400,4 @@ def configure_optimizers(self): lr=lr, betas=(0.5, 0.9)) opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9)) - return [opt_ae, opt_disc], [] \ No newline at end of file + return [opt_ae, opt_disc], [] diff --git a/taming/util.py b/taming/util.py index 06053e5d..bad6251c 100644 --- a/taming/util.py +++ b/taming/util.py @@ -1,3 +1,4 @@ +import importlib import os, hashlib import requests from tqdm import tqdm @@ -142,6 +143,20 @@ def retrieve( return list_or_dict, success +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + if __name__ == "__main__": config = {"keya": "a", "keyb": "b", @@ -155,3 +170,4 @@ def retrieve( print(config) retrieve(config, "keya") +