Skip to content

Commit af07a6c

Browse files
committed
Make LogEncoder more robust
1 parent 3492f44 commit af07a6c

File tree

3 files changed

+192
-89
lines changed

3 files changed

+192
-89
lines changed

.pylintrc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ disable =
3434
no-else-return,
3535
# Discourages small interfaces
3636
too-few-public-methods,
37+
# Too much old code still uses %-formatting or str.format()
38+
consider-using-f-string,
3739

3840
[REPORTS]
3941
msg-template = {path}:{line:3d},{column}: {msg} ({symbol})

src/garage/experiment/experiment.py

Lines changed: 131 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,29 @@ def dump_json(filename, data):
455455
filename(str): Filename for the file.
456456
data(dict): Data to save to file.
457457
458+
Raises:
459+
KeyboardInterrupt: If the user issued a KeyboardInterrupt.
460+
458461
"""
459462
pathlib.Path(os.path.dirname(filename)).mkdir(parents=True, exist_ok=True)
460-
with open(filename, 'w') as f:
463+
with open(filename, 'w', encoding='utf-8') as f:
461464
# We do our own circular reference handling.
462465
# Sometimes sort_keys fails because the keys don't get made into
463466
# strings early enough.
464-
json.dump(data,
465-
f,
466-
indent=2,
467-
sort_keys=False,
468-
cls=LogEncoder,
469-
check_circular=False)
467+
# This feature is useful, but causes way too many weird errors.
468+
# For this reason we catch almost any exception.
469+
# pylint: disable=broad-except
470+
try:
471+
json.dump(data,
472+
f,
473+
indent=2,
474+
sort_keys=False,
475+
cls=LogEncoder,
476+
check_circular=False)
477+
except KeyboardInterrupt as e:
478+
raise e
479+
except Exception:
480+
pass
470481

471482

472483
def get_metadata():
@@ -579,44 +590,80 @@ def __init__(self, *args, **kwargs):
579590
'itertools',
580591
}
581592

582-
def default(self, o):
593+
def default(self, o, path=''):
583594
"""Perform JSON encoding.
584595
585596
Args:
586597
o (object): Object to encode.
587-
588-
Raises:
589-
TypeError: If `o` cannot be turned into JSON even using `repr(o)`.
598+
path (str): "Path" to o for describing backreferences.
590599
591600
Returns:
592601
dict or str or float or bool: Object encoded in JSON.
593602
594603
"""
595-
# Why is this method hidden? What does that mean?
596-
# pylint: disable=method-hidden
597-
# pylint: disable=too-many-branches
598-
# pylint: disable=too-many-return-statements
599604
# This circular reference checking code was copied from the standard
600-
# library json implementation, but it outputs a repr'd string instead
601-
# of ValueError on a circular reference.
602-
if isinstance(o, (int, bool, float, str)):
605+
# library json implementation, but it detects all backreferences
606+
# instead of just circular ones and outputs a message with the path to
607+
# the prior reference instead of raising a ValueError
608+
if isinstance(o, (int, bool, float, str, type(None))):
603609
return o
604610
else:
605611
markerid = id(o)
606612
if markerid in self._markers:
607-
return 'circular ' + repr(o)
613+
original_path = self._markers[markerid]
614+
return f'reference to {original_path}'
608615
else:
609-
self._markers[markerid] = o
610-
try:
611-
return self._default_inner(o)
612-
finally:
613-
del self._markers[markerid]
616+
self._markers[markerid] = path
617+
return self._default_general_cases(o, path)
614618

615-
def _default_inner(self, o):
616-
"""Perform JSON encoding.
619+
def _default_general_cases(self, o, path):
620+
"""Handle JSON encoding among the "general" cases.
621+
622+
First, tries to just use the default encoder, then various special
623+
cases, then by turning o into a list, then by calling repr, and finally
624+
by returning the string '$invalid'
617625
618626
Args:
619627
o (object): Object to encode.
628+
path (str): "Path" to o for describing backreferences.
629+
630+
Returns:
631+
dict or str or float or bool: Object encoded in JSON.
632+
633+
"""
634+
try:
635+
return json.JSONEncoder.default(self, o)
636+
except TypeError:
637+
pass
638+
try:
639+
return self._default_special_cases(o, path)
640+
except TypeError:
641+
pass
642+
try:
643+
# This case handles many built-in datatypes like deques
644+
return [
645+
self.default(v, f'{path}/{i}') for i, v in enumerate(list(o))
646+
]
647+
except TypeError:
648+
pass
649+
try:
650+
# This case handles most other weird objects.
651+
return repr(o)
652+
except TypeError:
653+
pass
654+
except ValueError:
655+
pass
656+
return '$invalid'
657+
658+
def _default_special_cases(self, o, path):
659+
"""Handle various special cases we frequently want to JSON encode.
660+
661+
Note that these cases aren't _that_ special, and include dicts, enums,
662+
np.numbers, etc.
663+
664+
Args:
665+
o (object): Object to encode.
666+
path (str): "Path" to o for describing backreferences.
620667
621668
Raises:
622669
TypeError: If `o` cannot be turned into JSON even using `repr(o)`.
@@ -626,70 +673,65 @@ def _default_inner(self, o):
626673
dict or str or float or bool: Object encoded in JSON.
627674
628675
"""
629-
# Why is this method hidden? What does that mean?
630-
# pylint: disable=method-hidden
631676
# pylint: disable=too-many-branches
632677
# pylint: disable=too-many-return-statements
633-
# This circular reference checking code was copied from the standard
634-
# library json implementation, but it outputs a repr'd string instead
635-
# of ValueError on a circular reference.
636-
try:
637-
return json.JSONEncoder.default(self, o)
638-
except TypeError as err:
639-
if isinstance(o, dict):
640-
data = {}
641-
for (k, v) in o.items():
642-
if isinstance(k, str):
643-
data[k] = self.default(v)
644-
else:
645-
data[repr(k)] = self.default(v)
646-
return data
647-
elif isinstance(o, weakref.ref):
648-
return repr(o)
649-
elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
650-
return repr(o)
651-
elif isinstance(o, type):
652-
return {'$typename': o.__module__ + '.' + o.__name__}
653-
elif isinstance(o, np.number):
654-
# For some reason these aren't natively considered
655-
# serializable.
656-
# JSON doesn't actually have ints, so always use a float.
657-
return float(o)
658-
elif isinstance(o, np.bool8):
659-
return bool(o)
660-
elif isinstance(o, enum.Enum):
661-
return {
662-
'$enum':
663-
o.__module__ + '.' + o.__class__.__name__ + '.' + o.name
664-
}
665-
elif isinstance(o, np.ndarray):
666-
return repr(o)
667-
elif hasattr(o, '__dict__') or hasattr(o, '__slots__'):
668-
obj_dict = getattr(o, '__dict__', None)
669-
if obj_dict is not None:
670-
data = {k: self.default(v) for (k, v) in obj_dict.items()}
671-
else:
672-
data = {
673-
s: self.default(getattr(o, s))
674-
for s in o.__slots__
675-
}
676-
t = type(o)
677-
data['$type'] = t.__module__ + '.' + t.__name__
678-
return data
679-
elif callable(o) and hasattr(o, '__name__'):
680-
if getattr(o, '__module__', None) is not None:
681-
return {'$function': o.__module__ + '.' + o.__name__}
678+
if isinstance(o, dict):
679+
data = {}
680+
for (k, v) in o.items():
681+
if isinstance(k, str):
682+
data[k] = self.default(v, f'{path}/{k}')
682683
else:
683-
return repr(o)
684+
data[repr(k)] = self.default(v, f'{path}/{k!r}')
685+
return data
686+
elif isinstance(o, weakref.ref):
687+
return repr(o)
688+
elif type(o).__module__.split('.')[0] in self.BLOCKED_MODULES:
689+
return repr(o)
690+
elif isinstance(o, type):
691+
return {'$typename': o.__module__ + '.' + o.__name__}
692+
elif isinstance(o, np.bool8):
693+
return bool(o)
694+
elif isinstance(o, np.number):
695+
# For some reason these aren't natively considered
696+
# serializable.
697+
# JSON doesn't actually have ints, so always use a float.
698+
# Some strange numpy "number" types can actually be None,
699+
# so this case can actually fail as well, which will then fall back
700+
# to one of the general cases.
701+
return float(o)
702+
elif isinstance(o, enum.Enum):
703+
return {
704+
'$enum':
705+
o.__module__ + '.' + o.__class__.__name__ + '.' + o.name
706+
}
707+
elif isinstance(o, np.ndarray):
708+
return repr(o)
709+
elif hasattr(o, '__dict__') or hasattr(o, '__slots__'):
710+
obj_dict = getattr(o, '__dict__', None)
711+
if obj_dict is not None:
712+
# Some objects will change their fields while being
713+
# iterated over, so make a copy of their dictionary.
714+
obj_dict = obj_dict.copy()
715+
data = {
716+
k: self.default(v, f'{path}/{k}')
717+
for (k, v) in obj_dict.items()
718+
# There's a lot of spam from empty dict / list fields
719+
# The output of this JSONEncoder is not intended to be
720+
# loaded back into the original objects anyways.
721+
if not isinstance(v, (list, dict, set, tuple)) or v
722+
}
684723
else:
685-
try:
686-
# This case handles many built-in datatypes like deques
687-
return [self.default(v) for v in list(o)]
688-
except TypeError:
689-
pass
690-
try:
691-
# This case handles most other weird objects.
692-
return repr(o)
693-
except TypeError:
694-
pass
695-
raise err
724+
data = {
725+
s: self.default(getattr(o, s), f'{path}/{s}')
726+
for s in o.__slots__
727+
}
728+
t = type(o)
729+
data['$type'] = t.__module__ + '.' + t.__name__
730+
return data
731+
elif callable(o) and hasattr(o, '__name__'):
732+
if getattr(o, '__module__', None) is not None:
733+
return {'$function': o.__module__ + '.' + o.__name__}
734+
else:
735+
return repr(o)
736+
else:
737+
raise TypeError('Could not JSON encode object')
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import json
2+
3+
import numpy as np
4+
import torch
5+
6+
from garage.envs import GymEnv, normalize
7+
from garage.experiment import deterministic
8+
from garage.experiment.experiment import LogEncoder
9+
from garage.plotter import Plotter
10+
from garage.sampler import LocalSampler
11+
from garage.torch.algos import PPO
12+
from garage.torch.policies import GaussianMLPPolicy
13+
from garage.torch.value_functions import GaussianMLPValueFunction
14+
from garage.trainer import Trainer
15+
16+
from tests.fixtures import snapshot_config
17+
18+
19+
def test_encode_none_timedelta64():
    """Encode a dict holding `np.timedelta64(None)` with LogEncoder.

    The test only requires that encoding completes and that the dict key
    shows up in the produced JSON text.
    """
    payload = {'test': np.timedelta64(None)}
    dump_options = {
        'indent': 2,
        'sort_keys': False,
        'cls': LogEncoder,
        'check_circular': False,
    }
    serialized = json.dumps(payload, **dump_options)
    assert 'test' in serialized
27+
28+
29+
def test_encode_trainer():
    """Encode a fully set-up Trainer with LogEncoder.

    Builds a PPO algorithm (Gaussian MLP policy and value function with a
    LocalSampler) on a normalized InvertedDoublePendulum-v2 environment,
    attaches it to a Trainer, and checks that the JSON produced for the
    trainer mentions the value function.
    """
    env = normalize(GymEnv('InvertedDoublePendulum-v2'))
    policy = GaussianMLPPolicy(
        env_spec=env.spec,
        hidden_sizes=(64, 64),
        hidden_nonlinearity=torch.tanh,
        output_nonlinearity=None,
    )
    value_function = GaussianMLPValueFunction(env_spec=env.spec)
    sampler = LocalSampler(agents=policy,
                           envs=env,
                           max_episode_length=env.spec.max_episode_length,
                           is_tf_worker=False)

    trainer = Trainer(snapshot_config)
    algo = PPO(env_spec=env.spec,
               policy=policy,
               value_function=value_function,
               sampler=sampler,
               discount=0.99,
               gae_lambda=0.97,
               lr_clip_range=2e-1)

    trainer.setup(algo, env)
    # The encoded trainer is deliberately not printed: it is very large and
    # would only add noise to the test output.
    encoded = json.dumps(trainer,
                         indent=2,
                         sort_keys=False,
                         cls=LogEncoder,
                         check_circular=False)
    assert 'value_function' in encoded

0 commit comments

Comments
 (0)