Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 64dc8cf

Browse files
urvashik
authored and
Copybara-Service
committed
Fix for dataset mixing.
PiperOrigin-RevId: 208908276
1 parent 8ff6ec4 commit 64dc8cf

File tree

1 file changed

+18
-14
lines changed

1 file changed

+18
-14
lines changed

tensor2tensor/data_generators/multi_problem.py

+18-14
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ def get_const_sched_prob():
193193
def mix_data(example):
194194
"""Function to mix the different datasets according to a schedule."""
195195
del example
196+
# This block computes the probability of mixing the primary task with
197+
# the secondary tasks. 0 = only the primary task, 1 = only the secondary
198+
# tasks.
196199
if hparams.multiproblem_mixing_schedule == MixingSchedule.EXPONENTIAL:
197200
prob = get_exp_sched_prob()
198201
elif hparams.multiproblem_mixing_schedule == MixingSchedule.CONSTANT:
@@ -203,8 +206,10 @@ def mix_data(example):
203206
tf.logging.info("Using the %s schedule to "
204207
"train the MultiProblem." % str(
205208
hparams.multiproblem_mixing_schedule))
209+
tf.logging.info("Schedule mixing threshold "
210+
"%.2f" % hparams.multiproblem_schedule_threshold)
206211

207-
def sample_task(curr_task, num_tasks_left):
212+
def sample_task(curr_task, num_tasks_left, randnum):
208213
"""A recursive function to sample a task.
209214
210215
This function treats the probability as the threshold for the primary
@@ -214,6 +219,7 @@ def sample_task(curr_task, num_tasks_left):
214219
Args:
215220
curr_task: The index of the task being considered for sampling.
216221
num_tasks_left: Number of tasks remaining to possibly sample from.
222+
randnum: The random number used to select the dataset.
217223
218224
Returns:
219225
A Tensor representing an example from the task that was sampled
@@ -222,23 +228,21 @@ def sample_task(curr_task, num_tasks_left):
222228

223229
if num_tasks_left == 0:
224230
return get_next_from_dataset(dataset_iterators[curr_task])
225-
elif curr_task == 0:
226-
# primary task
227-
return tf.cond(
228-
tf.greater(tf.random_uniform([]), prob),
229-
lambda d=dataset_iterators[0]: get_next_from_dataset(d),
230-
lambda c=curr_task+1, n=num_tasks_left-1: sample_task(c, n)
231-
)
232-
# divide the probability mass across all the secondary tasks equally.
233-
new_prob = prob - curr_task * prob / (len(self.task_list)-1)
231+
232+
# When curr_task is 0, the primary task, the new prob is the same as
233+
# the original probability. `tf.greater` indicates that the primary
234+
# task receives (1-prob) of the probability mass.
235+
# Otherwise, `prob` is divided equally amongst all the secondary
236+
# tasks.
237+
new_prob = prob - (curr_task * prob / (len(self.task_list)-1))
234238
return tf.cond(
235-
tf.greater(tf.random_uniform([]), new_prob),
236-
lambda d=dataset_iterators[curr_task]: get_next_from_dataset(d),
237-
lambda c=curr_task+1, n=num_tasks_left-1: sample_task(c, n)
239+
tf.greater(randnum, new_prob),
240+
lambda: get_next_from_dataset(dataset_iterators[curr_task]),
241+
lambda: sample_task(curr_task+1, num_tasks_left-1, randnum)
238242
)
239243

240244
return tf.data.Dataset.from_tensors(
241-
sample_task(0, len(self.task_list)-1))
245+
sample_task(0, len(self.task_list)-1, tf.random_uniform([])))
242246

243247
single_mtl_dataset = tf.data.Dataset.from_tensors(tf.zeros([1])).repeat()
244248
single_mtl_dataset = single_mtl_dataset.flat_map(mix_data)

0 commit comments

Comments (0)