Skip to content

Commit a08b78c

Browse files
committed
fix main
1 parent b5db513 commit a08b78c

File tree

4 files changed

+11
-8
lines changed

4 files changed

+11
-8
lines changed

.github/workflows/flax_test.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,20 @@ jobs:
127127
uv pip install -U git+https://github.com/google-deepmind/dm-haiku.git
128128
# temporary: install jax nightly
129129
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
130-
uv run tests/run_all_tests.sh --only-doctest
130+
uv run --no-sync tests/run_all_tests.sh --only-doctest --use-venv
131131
elif [[ "${{ matrix.test-type }}" == "pytest" ]]; then
132132
uv pip install -U tensorflow-datasets
133133
# temporary: install jax nightly
134134
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
135-
uv run tests/run_all_tests.sh --only-pytest
135+
uv run --no-sync tests/run_all_tests.sh --only-pytest --use-venv
136136
elif [[ "${{ matrix.test-type }}" == "pytype" ]]; then
137-
# temporary: install jax nightly
137+
# temporary: install jax nightly
138138
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
139-
uv run tests/run_all_tests.sh --only-pytype
139+
uv run --no-sync tests/run_all_tests.sh --only-pytype --use-venv
140140
elif [[ "${{ matrix.test-type }}" == "mypy" ]]; then
141141
# temporary: install jax nightly
142142
uv pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
143-
uv run tests/run_all_tests.sh --only-mypy
143+
uv run --no-sync tests/run_all_tests.sh --only-mypy --use-venv
144144
else
145145
echo "Unknown test type: ${{ matrix.test-type }}"
146146
exit 1

examples/nnx_toy_examples/hijax_demo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from flax import nnx
2222

2323
# ## Data
24-
# We create a simple dataset of points sampled from a parabola with some noise.
24+
# We create a simple dataset of points sampled from a parabola with some noise.
2525
X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
2626
Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
2727

tests/nnx/spmd_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,11 @@ def __call__(self, x: jax.Array):
197197
def test_eager_sharding_context(self, use_eager_sharding):
198198
rngs = nnx.Rngs(0)
199199
with nnx.use_eager_sharding(use_eager_sharding):
200-
mesh = jax.make_mesh(((2, 2)), ("data", "model"))
200+
mesh = jax.make_mesh(
201+
(2, 2),
202+
('data', 'model'),
203+
axis_types=(jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto),
204+
)
201205
with jax.set_mesh(mesh):
202206
w = nnx.Param(
203207
rngs.lecun_normal()((4, 8)),

tests/run_all_tests.sh

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ echo "jax: $(python -c 'import jax; print(jax.__version__)')"
6565
echo "flax: $(python -c 'import flax; print(flax.__version__)')"
6666
echo "=========================="
6767
echo ""
68-
6968
sh $(dirname "$0")/download_dataset_metadata.sh || exit
7069

7170
# Instead of using set -e, we have a manual error trap that

0 commit comments

Comments
 (0)