Skip to content
This repository was archived by the owner on Mar 22, 2025. It is now read-only.

Commit 4e4c026

Browse files
committed
Implement --include-subfolders argument in train.py
1 parent 2c60838 commit 4e4c026

File tree

4 files changed

+27
-6
lines changed

4 files changed

+27
-6
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ The frame filenames should have zero-padded frame numbers, for example like this
2424
If you have multiple sequences of frames (i.e. from different videos/scenes/shots), you can have different prefixes in the frame filenames, like this:
2525
* firstvideo00001.png, firstvideo00002.png, firstvideo00003.png, ..., secondvideo00001.png, secondvideo00002.png, secondvideo00003.png, ...
2626

27+
Alternatively, the different frame sequences can reside in different subfolders. For that to work, you have to use the `--include-subfolders` argument.
28+
2729
## Apply video colorization to a folder of PNG frames
2830

2931
`python -m tcvc.apply --input-path /path/to/images/ --input-style line_art`

tcvc/data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from tcvc.dataset import DatasetFromFolder
44

55

6-
def get_dataset(root_dir, use_line_art=True):
7-
return DatasetFromFolder(root_dir, use_line_art)
6+
def get_dataset(root_dir, use_line_art=True, include_subfolders=False):
7+
return DatasetFromFolder(
8+
root_dir, use_line_art, include_subfolders=include_subfolders
9+
)
810

911

1012
def create_iterator(sample_size, sample_dataset):

tcvc/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ def __init__(self, image_dir, use_line_art=True, include_subfolders=False):
1717
self.image_file_paths = get_image_file_paths(
1818
image_dir, include_subfolders=include_subfolders
1919
)
20+
assert len(self.image_file_paths) > 0
2021
transform_list = [ToTensor()]
2122
self.transform = Compose(transform_list)
2223

tcvc/train.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
from tcvc.util import stitch_images, postprocess
1919

2020
if __name__ == "__main__":
21-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
21+
os.environ[
22+
"CUDA_VISIBLE_DEVICES"
23+
] = "0" # Ensure that we only use one GPU, not multiple
2224

2325
# Training settings
2426
parser = argparse.ArgumentParser(
@@ -30,13 +32,19 @@
3032
required=True,
3133
help="Path to a folder that contains the training set (image frames)",
3234
)
35+
parser.add_argument(
36+
"--include-subfolders",
37+
dest="include_subfolders",
38+
action="store_true",
39+
help="Include images from subfolders in the specified dataset path.",
40+
)
3341
parser.add_argument(
3442
"--input-style",
3543
dest="input_style",
3644
type=str,
3745
choices=["line_art", "greyscale"],
3846
help="line_art (canny edge detection) or greyscale",
39-
default="line_art",
47+
default="greyscale",
4048
)
4149
parser.add_argument("--logfile", required=False, default="training_logs.dat")
4250
parser.add_argument("--checkpoint", required=False, help="load pre-trained?")
@@ -104,9 +112,17 @@
104112
torch.cuda.manual_seed(opt.seed)
105113

106114
print("===> Loading datasets")
107-
train_set = get_dataset(opt.dataset, use_line_art=opt.input_style == "line_art")
115+
train_set = get_dataset(
116+
opt.dataset,
117+
use_line_art=opt.input_style == "line_art",
118+
include_subfolders=opt.include_subfolders,
119+
)
108120
# TODO: Add a separate argument for test set path. Do not use the same paths for training and testing
109-
test_set = get_dataset(opt.dataset, use_line_art=opt.input_style == "line_art")
121+
test_set = get_dataset(
122+
opt.dataset,
123+
use_line_art=opt.input_style == "line_art",
124+
include_subfolders=opt.include_subfolders,
125+
)
110126

111127
training_data_loader = DataLoader(
112128
dataset=train_set,

0 commit comments

Comments
 (0)