@@ -100,6 +100,14 @@ def dataset_splits(self):
100100 "shards" : 1 ,
101101 }]
102102
103+ @property
104+ def only_keep_videos_from_0th_frame (self ):
105+ return True
106+
107+ @property
108+ def use_not_breaking_batching (self ):
109+ return False
110+
103111 def preprocess_example (self , example , mode , hparams ):
104112 """Runtime preprocessing, e.g., resize example["frame"]."""
105113 return example
@@ -192,15 +200,64 @@ def features_from_batch(batched_prefeatures):
192200 # Batch and construct features.
193201 def _preprocess (example ):
194202 return self .preprocess_example (example , mode , hparams )
203+
204+ def avoid_break_batching (dataset ):
205+ """Smart preprocessing to avoid break between videos!
206+
207+ Simple batching of images into videos may result into broken videos
208+ with two parts from two different videos. This preprocessing avoids
209+ this using the frame number.
210+
211+ Args:
212+ dataset: raw not-batched dataset.
213+
214+ Returns:
215+ batched not-broken videos.
216+
217+ """
218+ def check_integrity_and_batch (* datasets ):
219+ """Checks whether a sequence of frames are from the same video.
220+
221+ Args:
222+ *datasets: datasets each skipping 1 frame from the previous one.
223+
224+ Returns:
225+ batched data and the integrity flag.
226+ """
227+ frame_numbers = [dataset ["frame_number" ][0 ] for dataset in datasets ]
228+
229+ not_broken = tf .equal (
230+ frame_numbers [- 1 ] - frame_numbers [0 ], num_frames - 1 )
231+ if self .only_keep_videos_from_0th_frame :
232+ not_broken = tf .logical_and (not_broken , tf .equal (frame_numbers [0 ], 0 ))
233+
234+ features = {}
235+ for key in datasets [0 ].keys ():
236+ values = [dataset [key ] for dataset in datasets ]
237+ batch = tf .stack (values )
238+ features [key ] = batch
239+ return features , not_broken
240+
241+ ds = [dataset .skip (i ) for i in range (num_frames )]
242+ dataset = tf .data .Dataset .zip (tuple (ds ))
243+ dataset = dataset .map (check_integrity_and_batch )
244+ dataset = dataset .filter (lambda _ , not_broken : not_broken )
245+ dataset = dataset .map (lambda features , _ : features )
246+
247+ return dataset
248+
195249 preprocessed_dataset = dataset .map (_preprocess )
196250 num_frames = (hparams .video_num_input_frames +
197251 hparams .video_num_target_frames )
198252 # We jump by a random position at the beginning to add variety.
199253 if self .random_skip :
200254 random_skip = tf .random_uniform ([], maxval = num_frames , dtype = tf .int64 )
201255 preprocessed_dataset = preprocessed_dataset .skip (random_skip )
202- batch_dataset = preprocessed_dataset .apply (
203- tf .contrib .data .batch_and_drop_remainder (num_frames ))
256+ if self .use_not_breaking_batching :
257+ batch_dataset = avoid_break_batching (preprocessed_dataset )
258+ else :
259+ batch_dataset = preprocessed_dataset .apply (
260+ tf .contrib .data .batch_and_drop_remainder (num_frames ))
204261 dataset = batch_dataset .map (features_from_batch ).shuffle (8 )
205262 return dataset
206263
0 commit comments