File tree Expand file tree Collapse file tree 3 files changed +19
-17
lines changed
Expand file tree Collapse file tree 3 files changed +19
-17
lines changed Original file line number Diff line number Diff line change 2020import importlib
2121import six
2222
23- modules = [
23+ MODULES = [
2424 "tensor2tensor.data_generators.algorithmic" ,
2525 "tensor2tensor.data_generators.algorithmic_math" ,
2626 "tensor2tensor.data_generators.audio" ,
6767 "tensor2tensor.data_generators.wikitext103" ,
6868 "tensor2tensor.data_generators.wsj_parsing" ,
6969]
70+ ALL_MODULES = list (MODULES )
7071
7172
7273
@@ -94,10 +95,11 @@ def _handle_errors(errors):
9495 print ("Did not import module: %s; Cause: %s" % (module , err_str ))
9596
9697
97- _errors = []
98- for _module in modules :
99- try :
100- importlib .import_module (_module )
101- except ImportError as error :
102- _errors .append ((_module , error ))
103- _handle_errors (_errors )
98+ def import_modules (modules ):
99+ errors = []
100+ for module in modules :
101+ try :
102+ importlib .import_module (module )
103+ except ImportError as error :
104+ errors .append ((module , error ))
105+ _handle_errors (errors )
Original file line number Diff line number Diff line change 1717from __future__ import division
1818from __future__ import print_function
1919
20- from tensor2tensor .data_generators import all_problems # pylint: disable=unused-import
20+ from tensor2tensor .data_generators import all_problems
2121from tensor2tensor .utils import registry
2222
2323
@@ -27,3 +27,6 @@ def problem(name):
2727
2828def available ():
2929 return sorted (registry .list_problems ())
30+
31+
32+ all_problems .import_modules (all_problems .ALL_MODULES )
Original file line number Diff line number Diff line change 1212# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313# See the License for the specific language governing permissions and
1414# limitations under the License.
15- """Tests for Tensor2Tensor's all_problems.py."""
16-
15+ """tensor2tensor.problems test."""
1716from __future__ import absolute_import
1817from __future__ import division
1918from __future__ import print_function
2019
21- from tensor2tensor . data_generators import all_problems
20+ from tensor2tensor import problems
2221
2322import tensorflow as tf
2423
2524
26- class AllProblemsTest (tf .test .TestCase ):
25+ class ProblemsTest (tf .test .TestCase ):
2726
2827 def testImport (self ):
29- """Make sure that importing all_problems doesn't break."""
30- self .assertIsNotNone (all_problems )
31-
28+ self .assertIsNotNone (problems )
3229
33- if __name__ == ' __main__' :
30+ if __name__ == " __main__" :
3431 tf .test .main ()
You can’t perform that action at this time.
0 commit comments