Skip to content

Commit ee6a437

Browse files
Ryan SepassiCopybara-Service
authored andcommitted
Move problems import to problems.py (from all_problems.py)
PiperOrigin-RevId: 201026560
1 parent 214a3cc commit ee6a437

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

tensor2tensor/data_generators/all_problems.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import importlib
2121
import six
2222

23-
modules = [
23+
MODULES = [
2424
"tensor2tensor.data_generators.algorithmic",
2525
"tensor2tensor.data_generators.algorithmic_math",
2626
"tensor2tensor.data_generators.audio",
@@ -67,6 +67,7 @@
6767
"tensor2tensor.data_generators.wikitext103",
6868
"tensor2tensor.data_generators.wsj_parsing",
6969
]
70+
ALL_MODULES = list(MODULES)
7071

7172

7273

@@ -94,10 +95,11 @@ def _handle_errors(errors):
9495
print("Did not import module: %s; Cause: %s" % (module, err_str))
9596

9697

97-
_errors = []
98-
for _module in modules:
99-
try:
100-
importlib.import_module(_module)
101-
except ImportError as error:
102-
_errors.append((_module, error))
103-
_handle_errors(_errors)
98+
def import_modules(modules):
99+
errors = []
100+
for module in modules:
101+
try:
102+
importlib.import_module(module)
103+
except ImportError as error:
104+
errors.append((module, error))
105+
_handle_errors(errors)

tensor2tensor/problems.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from __future__ import division
1818
from __future__ import print_function
1919

20-
from tensor2tensor.data_generators import all_problems # pylint: disable=unused-import
20+
from tensor2tensor.data_generators import all_problems
2121
from tensor2tensor.utils import registry
2222

2323

@@ -27,3 +27,6 @@ def problem(name):
2727

2828
def available():
2929
return sorted(registry.list_problems())
30+
31+
32+
all_problems.import_modules(all_problems.ALL_MODULES)

tensor2tensor/data_generators/all_problems_test.py renamed to tensor2tensor/problems_test.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,20 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
"""Tests for Tensor2Tensor's all_problems.py."""
16-
15+
"""tensor2tensor.problems test."""
1716
from __future__ import absolute_import
1817
from __future__ import division
1918
from __future__ import print_function
2019

21-
from tensor2tensor.data_generators import all_problems
20+
from tensor2tensor import problems
2221

2322
import tensorflow as tf
2423

2524

26-
class AllProblemsTest(tf.test.TestCase):
25+
class ProblemsTest(tf.test.TestCase):
2726

2827
def testImport(self):
29-
"""Make sure that importing all_problems doesn't break."""
30-
self.assertIsNotNone(all_problems)
31-
28+
self.assertIsNotNone(problems)
3229

33-
if __name__ == '__main__':
30+
if __name__ == "__main__":
3431
tf.test.main()

0 commit comments

Comments
 (0)