1 change: 1 addition & 0 deletions docs_nnx/api_reference/flax.nnx/nn/index.rst
@@ -14,6 +14,7 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for
linear
lora
normalization
+ pooling
recurrent
stochastic

10 changes: 10 additions & 0 deletions docs_nnx/api_reference/flax.nnx/nn/pooling.rst
@@ -0,0 +1,10 @@
Pooling
------------------------

.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autofunction:: avg_pool
.. autofunction:: max_pool
.. autofunction:: min_pool
.. autofunction:: pool
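
A quick usage sketch of the newly documented functions (the input name and shapes here are illustrative; the NHWC layout matches the tests below):

import jax.numpy as jnp
from flax import nnx

x = jnp.arange(16.0).reshape((1, 4, 4, 1))  # (batch, height, width, channels)
y_max = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))  # -> shape (1, 2, 2, 1)
y_avg = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))  # -> shape (1, 2, 2, 1)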
2 changes: 1 addition & 1 deletion flax/core/nn/__init__.py
@@ -35,7 +35,7 @@
swish as swish,
tanh as tanh,
)
- from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
+ from flax.pooling import avg_pool as avg_pool, max_pool as max_pool
from .attention import (
dot_product_attention as dot_product_attention,
multi_head_dot_product_attention as multi_head_dot_product_attention,
2 changes: 1 addition & 1 deletion flax/linen/__init__.py
@@ -123,7 +123,7 @@
SpectralNorm as SpectralNorm,
WeightNorm as WeightNorm,
)
- from .pooling import (avg_pool as avg_pool, max_pool as max_pool, pool as pool)
+ from ..pooling import avg_pool as avg_pool, max_pool as max_pool, pool as pool
from .recurrent import (
Bidirectional as Bidirectional,
ConvLSTMCell as ConvLSTMCell,
8 changes: 4 additions & 4 deletions flax/nnx/__init__.py
@@ -13,10 +13,10 @@
# limitations under the License.

from flax.core.spmd import logical_axis_rules as logical_axis_rules
- from flax.linen.pooling import avg_pool as avg_pool
- from flax.linen.pooling import max_pool as max_pool
- from flax.linen.pooling import min_pool as min_pool
- from flax.linen.pooling import pool as pool
+ from flax.pooling import avg_pool as avg_pool
+ from flax.pooling import max_pool as max_pool
+ from flax.pooling import min_pool as min_pool
+ from flax.pooling import pool as pool
from flax.typing import Initializer as Initializer

from .bridge import wrappers as wrappers
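
Since the implementation now lives in flax.pooling and the entries above are aliased re-exports, every public import path should resolve to the same function objects; a minimal sketch of that assumption:

from flax.pooling import max_pool
from flax.linen import max_pool as linen_max_pool
from flax.nnx import max_pool as nnx_max_pool

# All three names should point at the single implementation in flax/pooling.py.
assert max_pool is linen_max_pool is nnx_max_pool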
File renamed without changes.
96 changes: 0 additions & 96 deletions tests/linen/linen_test.py
@@ -40,102 +40,6 @@ def check_eq(xs, ys):
)


class PoolTest(parameterized.TestCase):
def test_pool_custom_reduce(self):
x = jnp.full((1, 3, 3, 1), 2.0)
mul_reduce = lambda x, y: x * y
y = nn.pooling.pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))

@parameterized.parameters(
{'count_include_pad': True}, {'count_include_pad': False}
)
def test_avg_pool(self, count_include_pad):
x = jnp.full((1, 3, 3, 1), 2.0)
pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
y = pool(x)
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
y_grad = jax.grad(lambda x: pool(x).sum())(x)
expected_grad = jnp.array(
[
[0.25, 0.5, 0.25],
[0.5, 1.0, 0.5],
[0.25, 0.5, 0.25],
]
).reshape((1, 3, 3, 1))
np.testing.assert_allclose(y_grad, expected_grad)

@parameterized.parameters(
{'count_include_pad': True}, {'count_include_pad': False}
)
def test_avg_pool_no_batch(self, count_include_pad):
x = jnp.full((3, 3, 1), 2.0)
pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
y = pool(x)
np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
y_grad = jax.grad(lambda x: pool(x).sum())(x)
expected_grad = jnp.array(
[
[0.25, 0.5, 0.25],
[0.5, 1.0, 0.5],
[0.25, 0.5, 0.25],
]
).reshape((3, 3, 1))
np.testing.assert_allclose(y_grad, expected_grad)

def test_max_pool(self):
x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
pool = lambda x: nn.max_pool(x, (2, 2))
expected_y = jnp.array(
[
[4.0, 5.0],
[7.0, 8.0],
]
).reshape((1, 2, 2, 1))
y = pool(x)
np.testing.assert_allclose(y, expected_y)
y_grad = jax.grad(lambda x: pool(x).sum())(x)
expected_grad = jnp.array(
[
[0.0, 0.0, 0.0],
[0.0, 1.0, 1.0],
[0.0, 1.0, 1.0],
]
).reshape((1, 3, 3, 1))
np.testing.assert_allclose(y_grad, expected_grad)

@parameterized.parameters(
{'count_include_pad': True}, {'count_include_pad': False}
)
def test_avg_pool_padding_same(self, count_include_pad):
x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
pool = lambda x: nn.avg_pool(
x, (2, 2), padding='SAME', count_include_pad=count_include_pad
)
y = pool(x)
if count_include_pad:
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
(1, 2, 2, 1)
)
else:
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
(1, 2, 2, 1)
)
np.testing.assert_allclose(y, expected_y)

def test_pooling_variable_batch_dims(self):
x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
y = nn.max_pool(x, (2, 2), (2, 2))

assert y.shape == (1, 8, 16, 16, 3)

def test_pooling_no_batch_dims(self):
x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
y = nn.max_pool(x, (2, 2), (2, 2))

assert y.shape == (16, 16, 3)


class NormalizationTest(parameterized.TestCase):
def test_layer_norm_mask(self):
key = random.key(0)
107 changes: 107 additions & 0 deletions tests/pooling_test.py
@@ -0,0 +1,107 @@
import unittest

import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized

from flax.pooling import avg_pool, max_pool, pool

jax.config.parse_flags_with_absl()


class PoolTest(parameterized.TestCase):
def test_pool_custom_reduce(self):
x = jnp.full((1, 3, 3, 1), 2.0)
mul_reduce = lambda x, y: x * y
y = pool(x, 1.0, mul_reduce, (2, 2), (1, 1), 'VALID')
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0**4))

@parameterized.parameters(
{'count_include_pad': True}, {'count_include_pad': False}
)
def test_avg_pool(self, count_include_pad):
x = jnp.full((1, 3, 3, 1), 2.0)
pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
y = pool(x)
np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.0))
y_grad = jax.grad(lambda x: pool(x).sum())(x)
expected_grad = jnp.array(
[
[0.25, 0.5, 0.25],
[0.5, 1.0, 0.5],
[0.25, 0.5, 0.25],
]
).reshape((1, 3, 3, 1))
np.testing.assert_allclose(y_grad, expected_grad)

@parameterized.parameters(
{'count_include_pad': True}, {'count_include_pad': False}
)
def test_avg_pool_no_batch(self, count_include_pad):
x = jnp.full((3, 3, 1), 2.0)
pool = lambda x: avg_pool(x, (2, 2), count_include_pad=count_include_pad)
y = pool(x)
np.testing.assert_allclose(y, np.full((2, 2, 1), 2.0))
y_grad = jax.grad(lambda x: pool(x).sum())(x)
expected_grad = jnp.array(
[
[0.25, 0.5, 0.25],
[0.5, 1.0, 0.5],
[0.25, 0.5, 0.25],
]
).reshape((3, 3, 1))
np.testing.assert_allclose(y_grad, expected_grad)

def test_max_pool(self):
x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
pool = lambda x: max_pool(x, (2, 2))
expected_y = jnp.array(
[
[4.0, 5.0],
[7.0, 8.0],
]
).reshape((1, 2, 2, 1))
y = pool(x)
np.testing.assert_allclose(y, expected_y)
y_grad = jax.grad(lambda x: pool(x).sum())(x)
expected_grad = jnp.array(
[
[0.0, 0.0, 0.0],
[0.0, 1.0, 1.0],
[0.0, 1.0, 1.0],
]
).reshape((1, 3, 3, 1))
np.testing.assert_allclose(y_grad, expected_grad)

@parameterized.parameters(
{'count_include_pad': True}, {'count_include_pad': False}
)
def test_avg_pool_padding_same(self, count_include_pad):
x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
pool = lambda x: avg_pool(
x, (2, 2), padding='SAME', count_include_pad=count_include_pad
)
y = pool(x)
if count_include_pad:
expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape(
(1, 2, 2, 1)
)
else:
expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape(
(1, 2, 2, 1)
)
np.testing.assert_allclose(y, expected_y)

def test_pooling_variable_batch_dims(self):
x = jnp.zeros((1, 8, 32, 32, 3), dtype=jnp.float32)
y = max_pool(x, (2, 2), (2, 2))

assert y.shape == (1, 8, 16, 16, 3)

def test_pooling_no_batch_dims(self):
x = jnp.zeros((32, 32, 3), dtype=jnp.float32)
y = max_pool(x, (2, 2), (2, 2))

assert y.shape == (16, 16, 3)

if __name__ == '__main__':
unittest.main()