Skip to content

Commit 214a3cc

Browse files
T2T Team authored and Copybara-Service committed
preventing datasets from batching between videos.
PiperOrigin-RevId: 200934714
1 parent eb11883 commit 214a3cc

File tree

3 files changed

+66
-17
lines changed

3 files changed

+66
-17
lines changed

tensor2tensor/data_generators/video_generated.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
import numpy as np
2424

25-
from tensor2tensor.data_generators import problem
2625
from tensor2tensor.data_generators import video_utils
2726
from tensor2tensor.utils import metrics
2827
from tensor2tensor.utils import registry
@@ -54,7 +53,8 @@ def frame_width(self):
5453

5554
@property
5655
def total_number_of_frames(self):
57-
return 10000
56+
# 10k videos
57+
return 10000 * self.video_length
5858

5959
@property
6060
def video_length(self):
@@ -69,17 +69,6 @@ def eval_metrics(self):
6969
metrics.Metrics.IMAGE_RMSE]
7070
return eval_metrics
7171

72-
@property
73-
def dataset_splits(self):
74-
"""Splits of data to produce and number of output shards for each."""
75-
return [{
76-
"split": problem.DatasetSplit.TRAIN,
77-
"shards": 1,
78-
}, {
79-
"split": problem.DatasetSplit.EVAL,
80-
"shards": 1,
81-
}]
82-
8372
@property
8473
def extra_reading_spec(self):
8574
"""Additional data fields to store on disk and their decoders."""

tensor2tensor/data_generators/video_utils.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@ def dataset_splits(self):
100100
"shards": 1,
101101
}]
102102

103+
@property
104+
def only_keep_videos_from_0th_frame(self):
105+
return True
106+
107+
@property
108+
def use_not_breaking_batching(self):
109+
return False
110+
103111
def preprocess_example(self, example, mode, hparams):
104112
"""Runtime preprocessing, e.g., resize example["frame"]."""
105113
return example
@@ -192,15 +200,64 @@ def features_from_batch(batched_prefeatures):
192200
# Batch and construct features.
193201
def _preprocess(example):
194202
return self.preprocess_example(example, mode, hparams)
203+
204+
def avoid_break_batching(dataset):
205+
"""Smart preprocessing to avoid break between videos!
206+
207+
Simple batching of images into videos may result into broken videos
208+
with two parts from two different videos. This preprocessing avoids
209+
this using the frame number.
210+
211+
Args:
212+
dataset: raw not-batched dataset.
213+
214+
Returns:
215+
batched not-broken videos.
216+
217+
"""
218+
def check_integrity_and_batch(*datasets):
219+
"""Checks whether a sequence of frames are from the same video.
220+
221+
Args:
222+
*datasets: datasets each skipping 1 frame from the previous one.
223+
224+
Returns:
225+
batched data and the integrity flag.
226+
"""
227+
frame_numbers = [dataset["frame_number"][0] for dataset in datasets]
228+
229+
not_broken = tf.equal(
230+
frame_numbers[-1] - frame_numbers[0], num_frames-1)
231+
if self.only_keep_videos_from_0th_frame:
232+
not_broken = tf.logical_and(not_broken, tf.equal(frame_numbers[0], 0))
233+
234+
features = {}
235+
for key in datasets[0].keys():
236+
values = [dataset[key] for dataset in datasets]
237+
batch = tf.stack(values)
238+
features[key] = batch
239+
return features, not_broken
240+
241+
ds = [dataset.skip(i) for i in range(num_frames)]
242+
dataset = tf.data.Dataset.zip(tuple(ds))
243+
dataset = dataset.map(check_integrity_and_batch)
244+
dataset = dataset.filter(lambda _, not_broken: not_broken)
245+
dataset = dataset.map(lambda features, _: features)
246+
247+
return dataset
248+
195249
preprocessed_dataset = dataset.map(_preprocess)
196250
num_frames = (hparams.video_num_input_frames +
197251
hparams.video_num_target_frames)
198252
# We jump by a random position at the beginning to add variety.
199253
if self.random_skip:
200254
random_skip = tf.random_uniform([], maxval=num_frames, dtype=tf.int64)
201255
preprocessed_dataset = preprocessed_dataset.skip(random_skip)
202-
batch_dataset = preprocessed_dataset.apply(
203-
tf.contrib.data.batch_and_drop_remainder(num_frames))
256+
if self.use_not_breaking_batching:
257+
batch_dataset = avoid_break_batching(preprocessed_dataset)
258+
else:
259+
batch_dataset = preprocessed_dataset.apply(
260+
tf.contrib.data.batch_and_drop_remainder(num_frames))
204261
dataset = batch_dataset.map(features_from_batch).shuffle(8)
205262
return dataset
206263

tensor2tensor/models/research/next_frame.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,10 +658,13 @@ def body(self, features):
658658

659659
all_actions = input_actions + target_actions
660660
all_rewards = input_rewards + target_rewards
661+
all_frames = input_frames + target_frames
662+
663+
tf.summary.image("full_video", tf.concat(all_frames, axis=1))
661664

662665
is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
663666
gen_images, gen_rewards, latent_mean, latent_std = self.construct_model(
664-
images=input_frames + target_frames,
667+
images=all_frames,
665668
actions=all_actions,
666669
rewards=all_rewards,
667670
k=900.0 if is_training else -1.0,
@@ -730,7 +733,7 @@ def next_frame():
730733
def next_frame_stochastic():
731734
"""SV2P model."""
732735
hparams = next_frame()
733-
hparams.video_num_input_frames = 4
736+
hparams.video_num_input_frames = 2
734737
hparams.video_num_target_frames = 1
735738
hparams.batch_size = 8
736739
hparams.target_modality = "video:l2raw"

0 commit comments

Comments (0)