diff --git a/docs_nnx/api_reference/flax.nnx/nn/index.rst b/docs_nnx/api_reference/flax.nnx/nn/index.rst
index e42d58428..cf38ae073 100644
--- a/docs_nnx/api_reference/flax.nnx/nn/index.rst
+++ b/docs_nnx/api_reference/flax.nnx/nn/index.rst
@@ -14,6 +14,7 @@ See the `NNX page `__ for
   linear
   lora
   normalization
+  pooling
   recurrent
   stochastic
 
diff --git a/docs_nnx/api_reference/flax.nnx/nn/pooling.rst b/docs_nnx/api_reference/flax.nnx/nn/pooling.rst
new file mode 100644
index 000000000..5d93b8be1
--- /dev/null
+++ b/docs_nnx/api_reference/flax.nnx/nn/pooling.rst
@@ -0,0 +1,10 @@
+Pooling
+------------------------
+
+.. automodule:: flax.nnx
+.. currentmodule:: flax.nnx
+
+.. autofunction:: avg_pool
+.. autofunction:: max_pool
+.. autofunction:: min_pool
+.. autofunction:: pool
\ No newline at end of file
diff --git a/flax/core/nn/__init__.py b/flax/core/nn/__init__.py
index b26ae531c..9581864b7 100644
--- a/flax/core/nn/__init__.py
+++ b/flax/core/nn/__init__.py
@@ -35,7 +35,7 @@
   swish as swish,
   tanh as tanh,
 )
-from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
+from flax.pooling import avg_pool as avg_pool, max_pool as max_pool
 from .attention import (
   dot_product_attention as dot_product_attention,
   multi_head_dot_product_attention as multi_head_dot_product_attention,
diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py
index c15bb8424..33caf0ac3 100644
--- a/flax/linen/__init__.py
+++ b/flax/linen/__init__.py
@@ -123,7 +123,7 @@
   SpectralNorm as SpectralNorm,
   WeightNorm as WeightNorm,
 )
-from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool)
+from ..pooling import avg_pool as avg_pool, max_pool as max_pool, pool as pool
 from .recurrent import (
   Bidirectional as Bidirectional,
   ConvLSTMCell as ConvLSTMCell,
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index 39e7c94f0..2bcf2bbc0 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -13,10 +13,10 @@
 # limitations under the License.
 
 from flax.core.spmd import logical_axis_rules as logical_axis_rules
-from flax.linen.pooling import avg_pool as avg_pool
-from flax.linen.pooling import max_pool as max_pool
-from flax.linen.pooling import min_pool as min_pool
-from flax.linen.pooling import pool as pool
+from flax.pooling import avg_pool as avg_pool
+from flax.pooling import max_pool as max_pool
+from flax.pooling import min_pool as min_pool
+from flax.pooling import pool as pool
 from flax.typing import Initializer as Initializer
 
 from .bridge import wrappers as wrappers
diff --git a/flax/linen/pooling.py b/flax/pooling.py
similarity index 100%
rename from flax/linen/pooling.py
rename to flax/pooling.py
diff --git a/tests/linen/linen_test.py b/tests/linen/linen_test.py
index 8e08d5a02..4e943cc31 100644
--- a/tests/linen/linen_test.py
+++ b/tests/linen/linen_test.py
@@ -40,102 +40,6 @@ def check_eq(xs, ys):
   )
 
 
-class PoolTest(parameterized.TestCase):
-  def test_pool_custom_reduce(self):
-    x = jnp.full((1, 3, 3, 1), 2.0)
-    mul_reduce = lambda x, y: x * y
-    y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
-    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool(self, count_include_pad):
-    x = jnp.full((1, 3, 3, 1), 2.0)
-    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
-    y = pool(x)
-    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.25, 0.5, 0.25],
-            [0.5, 1.0, 0.5],
-            [0.25, 0.5, 0.25],
-        ]
-    ).reshape((1, 3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool_no_batch(self, count_include_pad):
-    x = jnp.full((3, 3, 1), 2.0)
-    pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
-    y = pool(x)
-    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.25, 0.5, 0.25],
-            [0.5, 1.0, 0.5],
-            [0.25, 0.5, 0.25],
-        ]
-    ).reshape((3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  def test_max_pool(self):
-    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
-    pool = lambda x: nn.max_pool(x, (2, 2))
-    expected_y = jnp.array(
-        [
-            [4.0, 5.0],
-            [7.0, 8.0],
-        ]
-    ).reshape((1, 2, 2, 1))
-    y = pool(x)
-    np.testing.assert_allclose(y, expected_y)
-    y_grad = jax.grad(lambda x: pool(x).sum())(x)
-    expected_grad = jnp.array(
-        [
-            [0.0, 0.0, 0.0],
-            [0.0, 1.0, 1.0],
-            [0.0, 1.0, 1.0],
-        ]
-    ).reshape((1, 3, 3, 1))
-    np.testing.assert_allclose(y_grad, expected_grad)
-
-  @parameterized.parameters(
-      {'count_include_pad': True}, {'count_include_pad': False}
-  )
-  def test_avg_pool_padding_same(self, count_include_pad):
-    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
-    pool = lambda x: nn.avg_pool(
-        x, (2, 2), padding='SAME', count_include_pad=count_include_pad
-    )
-    y = pool(x)
-    if count_include_pad:
-      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
-          (1, 2, 2, 1)
-      )
-    else:
-      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
-          (1, 2, 2, 1)
-      )
-    np.testing.assert_allclose(y, expected_y)
-
-  def test_pooling_variable_batch_dims(self):
-    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
-    y = nn.max_pool(x, (2, 2), (2, 2))
-
-    assert y.shape == (1, 8, 16, 16, 3)
-
-  def test_pooling_no_batch_dims(self):
-    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
-    y = nn.max_pool(x, (2, 2), (2, 2))
-
-    assert y.shape == (16, 16, 3)
-
-
 class NormalizationTest(parameterized.TestCase):
   def test_layer_norm_mask(self):
     key = random.key(0)
diff --git a/tests/pooling_test.py b/tests/pooling_test.py
new file mode 100644
index 000000000..3395d841e
--- /dev/null
+++ b/tests/pooling_test.py
@@ -0,0 +1,107 @@
+import unittest
+from flax.pooling import pool, avg_pool, max_pool
+import numpy as np
+import jax.numpy as jnp
+from absl.testing import parameterized
+import jax
+
+jax.config.parse_flags_with_absl()
+
+
+class PoolTest(parameterized.TestCase):
+  def test_pool_custom_reduce(self):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    mul_reduce = lambda x, y: x * y
+    y = pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool(self, count_include_pad):
+    x = jnp.full((1, 3, 3, 1), 2.0)
+    pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.25, 0.5, 0.25],
+            [0.5, 1.0, 0.5],
+            [0.25, 0.5, 0.25],
+        ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_no_batch(self, count_include_pad):
+    x = jnp.full((3, 3, 1), 2.0)
+    pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
+    y = pool(x)
+    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.25, 0.5, 0.25],
+            [0.5, 1.0, 0.5],
+            [0.25, 0.5, 0.25],
+        ]
+    ).reshape((3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  def test_max_pool(self):
+    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
+    pool = lambda x: max_pool(x, (2, 2))
+    expected_y = jnp.array(
+        [
+            [4.0, 5.0],
+            [7.0, 8.0],
+        ]
+    ).reshape((1, 2, 2, 1))
+    y = pool(x)
+    np.testing.assert_allclose(y, expected_y)
+    y_grad = jax.grad(lambda x: pool(x).sum())(x)
+    expected_grad = jnp.array(
+        [
+            [0.0, 0.0, 0.0],
+            [0.0, 1.0, 1.0],
+            [0.0, 1.0, 1.0],
+        ]
+    ).reshape((1, 3, 3, 1))
+    np.testing.assert_allclose(y_grad, expected_grad)
+
+  @parameterized.parameters(
+      {'count_include_pad': True}, {'count_include_pad': False}
+  )
+  def test_avg_pool_padding_same(self, count_include_pad):
+    x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
+    pool = lambda x: avg_pool(
+        x, (2, 2), padding='SAME', count_include_pad=count_include_pad
+    )
+    y = pool(x)
+    if count_include_pad:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
+          (1, 2, 2, 1)
+      )
+    else:
+      expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
+          (1, 2, 2, 1)
+      )
+    np.testing.assert_allclose(y, expected_y)
+
+  def test_pooling_variable_batch_dims(self):
+    x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
+    y = max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (1, 8, 16, 16, 3)
+
+  def test_pooling_no_batch_dims(self):
+    x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
+    y = max_pool(x, (2, 2), (2, 2))
+
+    assert y.shape == (16, 16, 3)
+
+if __name__ == '__main__':
+  unittest.main()
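A minimal usage sketch of the relocated helpers, assuming the change above is applied. The `flax.pooling` import path and the `flax.linen` / `flax.nnx` re-exports are taken from the diff; the input shape and window size below are illustrative only:

    # Sketch only: exercises the new canonical import path `flax.pooling`.
    import jax.numpy as jnp
    from flax import linen as nn
    from flax import nnx
    from flax.pooling import avg_pool, max_pool

    x = jnp.ones((1, 8, 8, 3))  # NHWC input (batch, height, width, channels)
    y_max = max_pool(x, window_shape=(2, 2), strides=(2, 2))  # -> (1, 4, 4, 3)
    y_avg = avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # -> (1, 4, 4, 3)

    # The existing entry points keep working: per the diff, `flax.linen` and
    # `flax.nnx` re-export the same function objects from `flax.pooling`.
    assert nn.max_pool is max_pool
    assert nnx.max_pool is max_pool

Because the move is a pure rename (similarity index 100%), call sites that go through `flax.linen` or `flax.nnx` need no changes; only direct imports of `flax.linen.pooling` would have to switch to `flax.pooling`.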