diff --git a/src/zennit/composites.py b/src/zennit/composites.py
index 8fd030b..9a913d0 100644
--- a/src/zennit/composites.py
+++ b/src/zennit/composites.py
@@ -141,8 +141,8 @@ def mapping(self, ctx, name, module):
         '''
         return next((hook for names, hook in self.name_map if name in names), None)
 
 
-class MixedComposite(Composite):
+class MultiComposite(Composite):
     '''A Composite for which hooks are specified by a list of composites.
 
     Each composite defines a mapping from layer property to a specific Hook.
@@ -190,9 +189,47 @@ def mapping(self, ctx, name, module):
         '''
         hooks = [composite.module_map(ctx[composite], name, module) for composite in self.composites]
-        # return first hook that is not None, if there isn't any, return None
+        return self.handle_hooks(hooks)
+
+    def handle_hooks(self, hooks):
+        '''Combine the hooks found by this composite's child composites into a single result.'''
+        raise NotImplementedError()
+
+
+class MixedComposite(MultiComposite):
+    '''A Composite for which hooks are specified by a list of composites.
+
+    Each composite defines a mapping from layer property to a specific Hook.
+    The list order of composites defines their matching order.
+
+    Parameters
+    ----------
+    composites: `list[Composite]`
+        A list of Composites. The list order of composites defines their matching order.
+    canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
+        List of canonizer instances to be applied before applying hooks.
+    '''
+    def handle_hooks(self, hooks):
+        # return first hook that is not None, if there isn't any, return None
         return next((hook for hook in hooks if hook is not None), None)
+
+
+class MergedComposite(MultiComposite):
+    '''A Composite for which hooks are specified by a list of composites.
+
+    Each composite defines a mapping from layer property to a specific Hook.
+    The hooks of all matching composites are merged into a tuple.
+
+    Parameters
+    ----------
+    composites: `list[Composite]`
+        A list of Composites. The list order of composites defines their matching order.
+    canonizers: list[:py:class:`zennit.canonizers.Canonizer`], optional
+        List of canonizer instances to be applied before applying hooks.
+    '''
+    def handle_hooks(self, hooks):
+        return tuple(hook for hook in hooks if hook is not None) or None
 
 
 class NameLayerMapComposite(MixedComposite):
     '''A Composite for which hooks are specified by both a mapping
     from module names and module types to hooks.
diff --git a/src/zennit/core.py b/src/zennit/core.py
index 0bf2e6f..721d9aa 100644
--- a/src/zennit/core.py
+++ b/src/zennit/core.py
@@ -369,7 +369,80 @@ def backward(ctx, *grad_outputs):
         return grad_outputs
 
 
-class Hook:
+class HookBase:
+    '''Base for Hook functionality. Every hook must implement this interface.'''
+    def register(self, module):
+        '''Attach this hook to a module. This modifies forward/backward computations. Returns a handle which can be
+        used to call this Hook's ``.remove``.
+        '''
+        return RemovableHandle(self)
+
+    def remove(self):
+        '''Remove this hook. Removes all references and modifications it introduced.'''
+
+    def copy(self):
+        '''Return a copy of this hook.
+        This is used to describe hooks of different modules by a single hook instance.
+        '''
+        return self.__class__()
+
+
+class GradOutHook(HookBase):
+    '''Hook to only modify the output gradient of a module. This leaves the gradient computation of the module intact.
+    '''
+    def __init__(self):
+        # flag checked by the registered tensor hook; clearing it disables the hook without unregistering it
+        self.active = True
+        # handles of tensor gradient hooks registered during post_forward, removed in remove()
+        self.tensor_handles = RemovableHandleList()
+
+    def post_forward(self, module, input, output):
+        '''Register a backward-hook to the resulting tensor right after the forward.'''
+        hook_ref = weakref.ref(self)
+
+        @functools.wraps(self.backward)
+        def wrapper(grad_input, grad_output):
+            hook = hook_ref()
+            if hook is not None and hook.active:
+                return hook.backward(module, grad_output)
+            return None
+
+        if not isinstance(output, tuple):
+            output = (output,)
+
+        # only if gradient required
+        if output[0].grad_fn is not None:
+            # add identity to ensure .grad_fn exists
+            post_output = Identity.apply(*output)
+            # register the input tensor gradient hook
+            self.tensor_handles.append(
+                post_output[0].grad_fn.register_hook(wrapper)
+            )
+            # work around to support in-place operations
+            post_output = tuple(elem.clone() for elem in post_output)
+        else:
+            # no gradient required
+            post_output = output
+        return post_output[0] if len(post_output) == 1 else post_output
+
+    def backward(self, module, grad_output):
+        '''Hook applied during backward-pass. Modifies the output gradient of module before its gradient
+        computation.
+        '''
+
+    def register(self, module):
+        '''Register this instance by registering the necessary forward hook to the supplied module.'''
+        return RemovableHandleList([
+            RemovableHandle(self),
+            module.register_forward_hook(self.post_forward),
+        ])
+
+    def remove(self):
+        '''Remove the tensor gradient hooks registered during the forward pass.'''
+        self.tensor_handles.remove()
+
+
+class Hook(HookBase):
     '''Base class for hooks to be used to compute layer-wise attributions.'''
     def __init__(self):
         self.stored_tensors = {}
@@ -659,11 +722,19 @@ def register(self, module):
         ctx = {}
 
         for name, child in module.named_modules():
-            template = self.module_map(ctx, name, child)
-            if template is not None:
-                hook = template.copy()
-                self.hook_refs.add(hook)
-                self.handles.append(hook.register(child))
+            template = self.module_map(ctx, name, child)
+            try:
+                # a composite may map a module to multiple hook templates (e.g. MergedComposite)
+                templates = iter(template)
+            except TypeError:
+                # a single template (or None); treat it as a one-element collection
+                templates = (template,)
+
+            for template in templates:
+                if template is not None:
+                    hook = template.copy()
+                    self.hook_refs.add(hook)
+                    self.handles.append(hook.register(child))
 
     def remove(self):
         '''Remove all handles for hooks and canonizers.