@@ -193,6 +193,9 @@ def get_const_sched_prob():
193
193
def mix_data (example ):
194
194
"""Function to mix the different datasets according to a schedule."""
195
195
del example
196
+ # This block computes the probability of mixing the primary task with
197
+ # the secondary tasks. 0 = only the primary task, 1 = only the secondary
198
+ # tasks.
196
199
if hparams .multiproblem_mixing_schedule == MixingSchedule .EXPONENTIAL :
197
200
prob = get_exp_sched_prob ()
198
201
elif hparams .multiproblem_mixing_schedule == MixingSchedule .CONSTANT :
@@ -203,8 +206,10 @@ def mix_data(example):
203
206
tf .logging .info ("Using the %s schedule to "
204
207
"train the MultiProblem." % str (
205
208
hparams .multiproblem_mixing_schedule ))
209
+ tf .logging .info ("Schedule mixing threshold "
210
+ "%.2f" % hparams .multiproblem_schedule_threshold )
206
211
207
- def sample_task (curr_task , num_tasks_left ):
212
+ def sample_task (curr_task , num_tasks_left , randnum ):
208
213
"""A recursive function to sample a task.
209
214
210
215
This function treats the probability as the threshold for the primary
@@ -214,6 +219,7 @@ def sample_task(curr_task, num_tasks_left):
214
219
Args:
215
220
curr_task: The index of the task being considered for sampling.
216
221
num_tasks_left: Number of tasks remaining to possibly sample from.
222
+ randnum: The random number used to select the dataset.
217
223
218
224
Returns:
219
225
A Tensor representing an example from the task that was sampled
@@ -222,23 +228,21 @@ def sample_task(curr_task, num_tasks_left):
222
228
223
229
if num_tasks_left == 0 :
224
230
return get_next_from_dataset (dataset_iterators [curr_task ])
225
- elif curr_task == 0 :
226
- # primary task
227
- return tf .cond (
228
- tf .greater (tf .random_uniform ([]), prob ),
229
- lambda d = dataset_iterators [0 ]: get_next_from_dataset (d ),
230
- lambda c = curr_task + 1 , n = num_tasks_left - 1 : sample_task (c , n )
231
- )
232
- # divide the probability mass across all the secondary tasks equally.
233
- new_prob = prob - curr_task * prob / (len (self .task_list )- 1 )
231
+
232
+ # When curr_task is 0, the primary task, the new prob is the same as
233
+ # the original probability. `tf.greater` indicates that the primary
234
+ # task receives (1-prob) of the probability mass.
235
+ # Otherwise, `prob` is divided equally amongst all the secondary
236
+ # tasks.
237
+ new_prob = prob - (curr_task * prob / (len (self .task_list )- 1 ))
234
238
return tf .cond (
235
- tf .greater (tf . random_uniform ([]) , new_prob ),
236
- lambda d = dataset_iterators [curr_task ]: get_next_from_dataset ( d ),
237
- lambda c = curr_task + 1 , n = num_tasks_left - 1 : sample_task ( c , n )
239
+ tf .greater (randnum , new_prob ),
240
+ lambda : get_next_from_dataset ( dataset_iterators [curr_task ]),
241
+ lambda : sample_task ( curr_task + 1 , num_tasks_left - 1 , randnum )
238
242
)
239
243
240
244
return tf .data .Dataset .from_tensors (
241
- sample_task (0 , len (self .task_list )- 1 ))
245
+ sample_task (0 , len (self .task_list )- 1 , tf . random_uniform ([]) ))
242
246
243
247
single_mtl_dataset = tf .data .Dataset .from_tensors (tf .zeros ([1 ])).repeat ()
244
248
single_mtl_dataset = single_mtl_dataset .flat_map (mix_data )
0 commit comments