Skip to content

Commit cb56101

Browse files
daiyip and langfun authors
authored and committed
lf.eval.v2.EvaluationState to release processed examples right after evaluating them.
This allows memory-intensive benchmarks to free each processed example once it has been evaluated. PiperOrigin-RevId: 736348489
1 parent 07fd8fc commit cb56101

File tree

5 files changed

+111
-38
lines changed

5 files changed

+111
-38
lines changed

langfun/core/eval/v2/checkpointing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ def on_experiment_start(
5353
self._load_experiment(runner, experiment)
5454

5555
example_ids_to_evaluate = current_run.examples_to_evaluate(experiment)
56-
if experiment.state.evaluated_examples:
56+
if experiment.state.ckpt_examples:
5757
loaded_example_ids = list(
58-
sorted(experiment.state.evaluated_examples.keys())
58+
sorted(experiment.state.ckpt_examples.keys())
5959
)
6060
example_ids_to_evaluate -= set(loaded_example_ids)
6161
example_ids_to_evaluate = list(sorted(example_ids_to_evaluate))
6262
experiment.info(
63-
f'{len(experiment.state.evaluated_examples)} examples '
63+
f'{len(experiment.state.ckpt_examples)} examples '
6464
'loaded from checkpoint files. Their outputs will be used '
6565
f'for recomputing metrics. Example IDs: {loaded_example_ids}.'
6666
)
@@ -316,7 +316,7 @@ def on_experiment_complete(
316316
writer = self._sequence_writer.pop(experiment.id)
317317
writer.close()
318318
experiment.info(
319-
f'{len(experiment.state.evaluated_examples)} examples are '
319+
f'{len(experiment.state.evaluation_status)} examples are '
320320
f'checkpointed to {writer.path}.'
321321
)
322322

langfun/core/eval/v2/checkpointing_test.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from langfun.core.eval.v2 import checkpointing
1919
from langfun.core.eval.v2 import eval_test_helper
2020
from langfun.core.eval.v2 import example as example_lib
21+
from langfun.core.eval.v2 import experiment as experiment_lib
2122
from langfun.core.eval.v2 import runners as runners_lib # pylint: disable=unused-import
2223
import pyglove as pg
2324

@@ -52,6 +53,26 @@ def f():
5253
self.assertEqual(len(list(iter(f))), 1)
5354

5455

56+
class ExampleCollector(experiment_lib.Plugin):
57+
"""Collects all examples."""
58+
59+
def _on_bound(self):
60+
super()._on_bound()
61+
self._examples = {}
62+
63+
@property
64+
def examples(self) -> dict[int, example_lib.Example]:
65+
return self._examples
66+
67+
def on_example_complete(
68+
self, runner: runners_lib.Runner,
69+
experiment: experiment_lib.Experiment,
70+
example: example_lib.Example,
71+
):
72+
assert experiment.is_leaf, None
73+
self._examples[example.id] = example
74+
75+
5576
class CheckpointerTest(unittest.TestCase):
5677

5778
def assert_found_in_log(self, experiment, message):
@@ -70,13 +91,15 @@ def test_checkpointing(self):
7091
experiment = eval_test_helper.test_experiment()
7192
checkpoint_filename = 'checkpoint.jsonl'
7293
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
94+
collector = ExampleCollector()
7395
run = experiment.run(
74-
root_dir, 'new', runner='sequential', plugins=[checkpointer]
96+
root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
7597
)
7698
num_processed = {}
7799
for leaf in experiment.leaf_nodes:
78100
for i in range(leaf.num_examples):
79-
example = leaf.state.get(i + 1)
101+
self.assertIn(i + 1, collector.examples)
102+
example = collector.examples[i + 1]
80103
ckpt = run.output_path_for(leaf, f'checkpoint_{example.id}.jsonl')
81104
if example.has_error:
82105
self.assertFalse(pg.io.path_exists(ckpt))
@@ -134,12 +157,15 @@ def test_loading_corrupted_checkpoint(self):
134157
experiment = eval_test_helper.TestEvaluation()
135158
checkpoint_filename = 'checkpoint.jsonl'
136159
checkpointer = checkpointing.PerExampleCheckpointer(checkpoint_filename)
160+
collector = ExampleCollector()
161+
137162
run = experiment.run(
138-
root_dir, 'new', runner='sequential', plugins=[checkpointer]
163+
root_dir, 'new', runner='sequential', plugins=[checkpointer, collector]
139164
)
140165
num_processed = {}
141166
for i in range(experiment.num_examples):
142-
example = experiment.state.get(i + 1)
167+
self.assertIn(i + 1, collector.examples)
168+
example = collector.examples[i + 1]
143169
ckpt = run.output_path_for(experiment, f'checkpoint_{example.id}.jsonl')
144170
if not example.has_error:
145171
self.assertTrue(pg.io.path_exists(ckpt))

langfun/core/eval/v2/evaluation.py

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -166,25 +166,24 @@ def evaluate(
166166
if pg.MISSING_VALUE == example.input:
167167
example.input = self.example_input_by_id(example.id)
168168

169-
cached = self._state.get(example.id)
170-
169+
checkpointed = self._state.ckpt_example(example.id)
171170
with pg.timeit('evaluate') as timeit, lf.track_usages() as usage_summary:
172-
if cached is None or cached.has_error:
171+
if checkpointed is None or checkpointed.has_error:
173172
example.start_time = time.time()
174173
self._process(example, raise_if_has_error=raise_if_has_error)
175174
else:
176-
example.start_time = cached.start_time
175+
example.start_time = checkpointed.start_time
177176

178-
# Use cached output and metadata obtained from the previous processing.
179-
example.output = cached.output
180-
example.metadata = cached.metadata
177+
# Use the output and metadata obtained from the previous processing.
178+
example.output = checkpointed.output
179+
example.metadata = checkpointed.metadata
181180
example.newly_processed = False
182181

183182
# For previously processed examples, we merge previous usages as
184183
# cached, so the usage summary will account previous usages, but as
185184
# cached.
186-
assert cached.usage_summary is not None
187-
usage_summary.merge(cached.usage_summary, as_cached=True)
185+
assert checkpointed.usage_summary is not None
186+
usage_summary.merge(checkpointed.usage_summary, as_cached=True)
188187

189188
# Recompute the metrics and metadata for the example even its processed
190189
# output and metadata were from the cache.
@@ -691,9 +690,29 @@ def _html_tree_view_css_styles(self) -> list[str]:
691690
class EvaluationState:
692691
"""Evaluation state."""
693692

693+
class ExampleStatus(pg.Object):
694+
"""Example state."""
695+
evaluated: Annotated[
696+
bool,
697+
'Whether the example is evaluated.'
698+
] = False
699+
700+
newly_processed: Annotated[
701+
bool,
702+
'Whether the example is newly processed.'
703+
] = False
704+
705+
has_error: Annotated[
706+
bool,
707+
'Whether the example has error.'
708+
] = False
709+
694710
def __init__(self):
695711
super().__init__()
696-
self._evaluated_examples: dict[int, example_lib.Example] = {}
712+
self._ckpt_examples: dict[int, example_lib.Example] = {}
713+
self._evaluation_status: dict[
714+
int, EvaluationState.ExampleStatus
715+
] = {}
697716

698717
def load(
699718
self,
@@ -715,17 +734,41 @@ def load(
715734
assert isinstance(example, example_lib.Example), example
716735
if filter is not None and not filter(example):
717736
continue
718-
self._evaluated_examples[example.id] = example
737+
example.newly_processed = False
738+
self._ckpt_examples[example.id] = example
719739

720740
@property
721-
def evaluated_examples(self) -> dict[int, example_lib.Example]:
722-
"""Returns the examples in the state."""
723-
return self._evaluated_examples
741+
def evaluation_status(self) -> dict[int, ExampleStatus]:
742+
"""Returns the evaluation status of the examples."""
743+
return self._evaluation_status
724744

725-
def get(self, example_id: int) -> example_lib.Example | None:
726-
"""Returns the example with the given ID."""
727-
return self._evaluated_examples.get(example_id)
745+
@property
746+
def ckpt_examples(self) -> dict[int, example_lib.Example]:
747+
"""Returns the unevaluated examples from checkpoints."""
748+
return self._ckpt_examples
749+
750+
def ckpt_example(self, example_id: int) -> example_lib.Example | None:
751+
"""Returns the unevaluated example from checkpoints for a given ID."""
752+
return self._ckpt_examples.get(example_id)
753+
754+
def get_status(self, example_id: int) -> ExampleStatus:
755+
"""Returns the evaluation status of the example."""
756+
return self._evaluation_status.get(
757+
example_id, EvaluationState.ExampleStatus()
758+
)
728759

729760
def update(self, example: example_lib.Example) -> None:
730761
"""Updates the state with the given example."""
731-
self._evaluated_examples[example.id] = example
762+
self._update_status(example)
763+
# Processed examples will be removed once it's done.
764+
self._ckpt_examples.pop(example.id, None)
765+
766+
def _update_status(self, example: example_lib.Example) -> None:
767+
"""Updates the evaluation status of the example."""
768+
self._evaluation_status[example.id] = (
769+
EvaluationState.ExampleStatus(
770+
evaluated=example.output != pg.MISSING_VALUE,
771+
newly_processed=example.newly_processed,
772+
has_error=example.has_error,
773+
)
774+
)

langfun/core/eval/v2/evaluation_test.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def my_inputs():
7878
def test_evaluate(self):
7979
exp = eval_test_helper.TestEvaluation()
8080
example = exp.evaluate(Example(id=3))
81-
self.assertIs(exp.state.get(3), example)
81+
self.assertTrue(exp.state.get_status(3).evaluated)
82+
self.assertTrue(exp.state.get_status(3).newly_processed)
83+
self.assertFalse(exp.state.get_status(3).has_error)
8284
self.assertTrue(example.newly_processed)
8385
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
8486
self.assertEqual(example.output, 6)
@@ -111,7 +113,7 @@ def test_evaluate(self):
111113
self.assertEqual(example.metadata, {})
112114
self.assertEqual(example.metric_metadata, dict(error='ValueError'))
113115

114-
def test_evaluate_with_state(self):
116+
def test_evaluate_withstate(self):
115117
eval_dir = os.path.join(tempfile.gettempdir(), 'test_eval')
116118
pg.io.mkdirs(eval_dir, exist_ok=True)
117119
state_file = os.path.join(eval_dir, 'state.jsonl')
@@ -121,13 +123,14 @@ def test_evaluate_with_state(self):
121123
self.assertTrue(example.newly_processed)
122124
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
123125
self.assertEqual(example.output, 6)
124-
self.assertEqual(len(exp._state.evaluated_examples), 1)
126+
self.assertEqual(len(exp.state.evaluation_status), 1)
125127
f.add(pg.to_json_str(example))
126128

127129
exp.reset()
128-
self.assertEqual(len(exp._state.evaluated_examples), 0)
130+
self.assertEqual(len(exp.state.ckpt_examples), 0)
129131
exp.load_state(state_file)
130-
self.assertEqual(len(exp._state.evaluated_examples), 1)
132+
self.assertEqual(len(exp.state.ckpt_examples), 1)
133+
self.assertEqual(len(exp.state.evaluation_status), 0)
131134
example = exp.evaluate(3)
132135
self.assertFalse(example.newly_processed)
133136
self.assertEqual(example.input, pg.Dict(x=2, y=4, groundtruth=6))
@@ -140,14 +143,14 @@ def test_evaluate_with_state(self):
140143

141144
# Test load_state with filter.
142145
exp.reset()
143-
self.assertEqual(len(exp._state.evaluated_examples), 0)
146+
self.assertEqual(len(exp.state.ckpt_examples), 0)
144147
exp.load_state(state_file, filter=lambda x: x.id == 3)
145-
self.assertEqual(len(exp._state.evaluated_examples), 1)
148+
self.assertEqual(len(exp.state.ckpt_examples), 1)
146149

147150
exp.reset()
148-
self.assertEqual(len(exp._state.evaluated_examples), 0)
151+
self.assertEqual(len(exp.state.ckpt_examples), 0)
149152
exp.load_state(state_file, filter=lambda x: x.id == 1)
150-
self.assertEqual(len(exp._state.evaluated_examples), 0)
153+
self.assertEqual(len(exp.state.ckpt_examples), 0)
151154

152155
def test_html_view(self):
153156
exp = eval_test_helper.TestEvaluation()

langfun/core/eval/v2/runners.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,8 @@ def _log_experiment_completion(self, experiment: Experiment):
181181
)
182182
num_from_checkpoint, num_processed = 0, 0
183183
for example_id in example_ids:
184-
example = experiment.state.get(example_id)
185-
if example.newly_processed:
184+
status = experiment.state.get_status(example_id)
185+
if status.newly_processed:
186186
num_processed += 1
187187
else:
188188
num_from_checkpoint += 1
@@ -358,7 +358,8 @@ def evaluate_item(
358358
"""Runs the evaluation example."""
359359
self.on_example_start(evaluation, item)
360360
item = evaluation.evaluate(
361-
item, raise_if_has_error=self.current_run.raise_if_has_error
361+
item,
362+
raise_if_has_error=self.current_run.raise_if_has_error,
362363
)
363364
self.on_example_complete(evaluation, item)
364365
return item

0 commit comments

Comments
 (0)