Skip to content

Commit cb681a4

Browse files
Support reproducible test generation both in single process mode and in parallel (#97)
Fixes #95.
1 parent 350ce45 commit cb681a4

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

grammarinator/generate.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ def __exit__(self, exc_type, exc_val, exc_tb):
9292
if self.cleanup and self.out_format:
9393
rmtree(dirname(self.out_format))
9494

95-
def __call__(self, index, *args, weights=None, lock=None, **kwargs):
95+
def __call__(self, index, *args, seed=None, weights=None, lock=None, **kwargs):
96+
if seed:
97+
random.seed(seed + index)
9698
weights = weights if weights is not None else {}
9799
lock = lock or nullcontext()
98100
return self.create_new_test(index, weights, lock)[0]
@@ -265,17 +267,14 @@ def restricted_float(value):
265267
parser.add_argument('-n', default=1, type=int, metavar='NUM',
266268
help='number of tests to generate, \'inf\' for continuous generation (default: %(default)s).')
267269
parser.add_argument('--random-seed', type=int, metavar='NUM',
268-
help='initialize random number generator with fixed seed (not set by default; noneffective if parallelization is enabled).')
270+
help='initialize random number generator with fixed seed (not set by default).')
269271
add_jobs_argument(parser)
270272
add_sys_path_argument(parser)
271273
add_sys_recursion_limit_argument(parser)
272274
add_log_level_argument(parser, short_alias=())
273275
add_version_argument(parser, version=__version__)
274276
args = parser.parse_args()
275277

276-
if args.jobs == 1 and args.random_seed:
277-
random.seed(args.random_seed)
278-
279278
init_logging()
280279
process_log_level_argument(args, logger)
281280
process_sys_path_argument(args)
@@ -293,14 +292,14 @@ def restricted_float(value):
293292
cleanup=False, encoding=args.encoding) as generator:
294293
if args.jobs > 1:
295294
with Manager() as manager:
296-
generator = partial(generator, weights=manager.dict(), lock=manager.Lock()) # pylint: disable=no-member
295+
generator = partial(generator, seed=args.random_seed, weights=manager.dict(), lock=manager.Lock()) # pylint: disable=no-member
297296
with Pool(args.jobs) as pool:
298297
for _ in pool.imap_unordered(generator, count(0) if args.n == inf else range(args.n)):
299298
pass
300299
else:
301300
weights = {}
302301
for i in count(0) if args.n == inf else range(args.n):
303-
generator(i, weights=weights)
302+
generator(i, seed=args.random_seed, weights=weights)
304303

305304

306305
if __name__ == '__main__':

0 commit comments

Comments
 (0)