Skip to content

Commit

Permalink
Add flax_field and tree_map_with_path
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Oct 17, 2023
1 parent a29eaa4 commit 0ee0ae8
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 7 deletions.
8 changes: 5 additions & 3 deletions tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
replace_cotangent, scale_cotangent)
from ._src.display import display_generic, print_generic, tapped_print_generic
from ._src.dtypes import default_atol, default_rtol, default_tols
from ._src.flax_tools import flax_field
from ._src.leaky_integral import (diffused_leaky_integrate, leaky_covariance, leaky_data_weight,
leaky_integrate, leaky_integrate_time_series)
from ._src.math_tools import (abs_square, divide_nonnegative, divide_where, inverse_softplus,
is_scalar, zero_tangent_like)
from ._src.partial import Partial
from ._src.shims import custom_jvp, custom_jvp_method, custom_vjp, custom_vjp_method, jit
from ._src.testing import (assert_tree_allclose, get_relative_test_string, get_test_string,
tree_allclose)
from ._src.tools import (abs_square, divide_nonnegative, divide_where, inverse_softplus, is_scalar,
zero_tangent_like)
from ._src.tree_tools import tree_map_with_path

__all__ = ['BooleanArray', 'BooleanNumeric', 'Complex', 'ComplexArray', 'ComplexNumeric', 'Array',
'KeyArray', 'Integral', 'IntegralArray', 'IntegralNumeric', 'NumpyArray',
Expand All @@ -40,7 +42,7 @@
'cotangent_combinator', 'print_cotangent', 'print_generic', 'replace_cotangent',
'JaxArray', 'JaxBooleanArray', 'JaxIntegralArray', 'tree_allclose', 'zero_tangent_like',
'JaxAbstractClass', 'JaxComplexArray', 'JaxRealArray', 'abstract_jit',
'abstract_custom_jvp']
'abstract_custom_jvp', 'tree_map_with_path', 'flax_field']
#
# __pdoc__ = {}
# __pdoc__['PyTreeLike'] = False
Expand Down
2 changes: 1 addition & 1 deletion tjax/_src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .dtypes import *
from .graph import *
from .leaky_integral import *
from .math_tools import *
from .partial import *
from .shims import *
from .testing import *
from .tools import *
2 changes: 1 addition & 1 deletion tjax/_src/fixed_point/comparing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..annotations import JaxBooleanArray, JaxRealArray
from ..dataclasses import dataclass
from ..tools import divide_nonnegative
from ..math_tools import divide_nonnegative
from .augmented import AugmentedState, State
from .iterated_function import Comparand, IteratedFunction, Parameters, Trajectory

Expand Down
2 changes: 1 addition & 1 deletion tjax/_src/fixed_point/stochastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..annotations import ComplexNumeric, JaxBooleanArray, JaxComplexArray, JaxRealArray
from ..dataclasses import dataclass
from ..leaky_integral import leaky_data_weight, leaky_integrate
from ..tools import abs_square, divide_nonnegative
from ..math_tools import abs_square, divide_nonnegative
from .augmented import AugmentedState, State
from .combinator import Differentiand, IteratedFunctionWithCombinator
from .iterated_function import Comparand, IteratedFunction, Parameters, Trajectory
Expand Down
10 changes: 10 additions & 0 deletions tjax/_src/flax_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from __future__ import annotations

from dataclasses import field
from typing import Any

__all__ = ['flax_field']


def flax_field() -> Any:
return field(init=False, default=None, kw_only=True)
2 changes: 1 addition & 1 deletion tjax/_src/gradient/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ..annotations import PyTree
from ..dataclasses import dataclass
from ..tools import abs_square
from ..math_tools import abs_square

__all__ = ['GradientState', 'GenericGradientState', 'GradientTransformation',
'SecondOrderGradientTransformation', 'ThirdOrderGradientTransformation']
Expand Down
File renamed without changes.
23 changes: 23 additions & 0 deletions tjax/_src/tree_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from collections.abc import Callable
from typing import Any, TypeVar

__all__ = ['tree_map_with_path']


T = TypeVar('T')
def tree_map_with_path(structure: Any,
transform: Callable[[T, tuple[str, ...]], T]
) -> Any:
def f(structure: Any,
transform: Callable[[T, tuple[str, ...]], T],
path: tuple[str, ...]
) -> Any:
if isinstance(structure, dict):
out_structure = {}
for key, value in structure.items():
out_structure[key] = f(value, transform, (*path, key))
return out_structure
return transform(structure, path)
return f(structure, transform, ())

0 comments on commit 0ee0ae8

Please sign in to comment.