Skip to content

Commit 258cc6c

Browse files
Merge pull request #288 from ThomasGjerde/fix-restored-reporters
Set new reporters on restored species object
2 parents aff359e + e0afd3a commit 258cc6c

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

neat/population.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(self, config, initial_state=None):
4848
self.species.speciate(config, self.population, self.generation)
4949
else:
5050
self.population, self.species, self.generation = initial_state
51+
self.species.reporters = self.reporters
5152
# If the reproduction object has a genome indexer,
5253
# set it to continue from the last genome ID.
5354
if hasattr(self.reproduction, "genome_indexer"):

tests/test_population.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,52 @@ def eval_genomes(genomes, config):
6969
last_genome_key + 1
7070
)
7171

72+
def test_reporter_consistency_after_checkpoint_restore(self):
73+
"""
74+
Test that ReportSets in the different objects in population are the same
75+
after restoring from a checkpoint.
76+
"""
77+
# Load configuration.
78+
local_dir = os.path.dirname(__file__)
79+
config_path = os.path.join(local_dir, 'test_configuration')
80+
config = neat.Config(neat.DefaultGenome, neat.DefaultReproduction,
81+
neat.DefaultSpeciesSet, neat.DefaultStagnation,
82+
config_path)
83+
84+
p = neat.Population(config)
85+
filename_prefix = 'neat-checkpoint-test_population'
86+
checkpointer = neat.Checkpointer(1, 5, filename_prefix=filename_prefix)
87+
p.add_reporter(checkpointer)
88+
89+
reporter_set = p.reporters
90+
self.assertEqual(reporter_set, p.reproduction.reporters)
91+
self.assertEqual(reporter_set, p.species.reporters)
92+
93+
def eval_genomes(genomes, config):
94+
for genome_id, genome in genomes:
95+
genome.fitness = 0.5
96+
97+
p.run(eval_genomes, 5)
98+
99+
filename = '{0}{1}'.format(
100+
filename_prefix, checkpointer.last_generation_checkpoint
101+
)
102+
restored_population = neat.Checkpointer.restore_checkpoint(filename)
103+
104+
# Check that the reporters are consistent
105+
restored_reporter_set = restored_population.reporters
106+
self.assertEqual(
107+
restored_reporter_set,
108+
restored_population.reproduction.reporters,
109+
msg="Reproduction reporters do not match after restore"
110+
)
111+
self.assertEqual(
112+
restored_reporter_set,
113+
restored_population.species.reporters,
114+
msg="Species reporters do not match after restore"
115+
)
116+
117+
72118
# def test_minimal():
73119
# # sample fitness function
74120
# def eval_fitness(population):

0 commit comments

Comments
 (0)