[Orbit] Handle iterator exhaustion in Controller.py #13595
Description
This PR addresses the TODO in orbit/controller.py to support steps=-1 in Controller.train(), allowing training to run until the underlying dataset is exhausted.
Motivation: Previously, Controller.train required a fixed number of steps. This change allows users to train for a full epoch (or until the dataset runs out) without needing to know the exact dataset size beforehand, which is common when using tf.data.Dataset.
Changes:
- Modified the Controller.train loop condition to accept steps=-1.
- Added a try-except block to catch tf.errors.OutOfRangeError and StopIteration during _train_n_steps. This ensures the loop exits gracefully when the iterator is exhausted instead of crashing.
- Added logic to break the loop if the global_step increment is less than expected (another indicator of exhaustion).
- Added a new test case, test_train_until_exhaustion, in orbit/controller_test.py to verify this behavior using a finite dataset.
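For reviewers who want the control flow at a glance, here is a minimal pure-Python sketch of the idea, not the actual Orbit code. The function name `train_until_exhausted`, its parameters, and the plain Python iterator are all hypothetical stand-ins. It covers only the steps=-1 loop condition and the exception handling; the real change also catches tf.errors.OutOfRangeError and checks the global_step increment.

```python
from typing import Callable, Iterator, TypeVar

T = TypeVar("T")


def train_until_exhausted(
    iterator: Iterator[T],
    train_step: Callable[[T], None],
    steps: int = -1,
    steps_per_loop: int = 2,
) -> int:
    """Runs train_step per batch until `steps` is reached or data runs out.

    Hypothetical stand-in for the loop described above, not Orbit's code.
    Returns the number of steps actually completed.
    """
    if steps < -1:
        raise ValueError("steps must be -1 or non-negative")
    global_step = 0
    # steps == -1 means "no step limit": only exhaustion ends the loop.
    while steps == -1 or global_step < steps:
        if steps == -1:
            num_steps = steps_per_loop
        else:
            num_steps = min(steps_per_loop, steps - global_step)
        try:
            for _ in range(num_steps):
                train_step(next(iterator))
                global_step += 1
        except StopIteration:
            # Iterator exhausted: exit gracefully instead of crashing.
            # (With tf.data, tf.errors.OutOfRangeError plays this role.)
            break
    return global_step
```

With a five-element iterator and the default steps=-1, this sketch completes five steps and returns; with steps=3 it stops after three, as a fixed-step loop would.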
Tests
I verified the changes by running the new test case and existing tests.
Test Configuration:
OS: Windows 11
Python Version: 3.10
Command: python -m orbit.controller_test
Result: Passed. Specifically, test_train_until_exhaustion passed with the expected behavior.
Checklist