
Commit bf8ed5c

Merge pull request #212 from esa/Release
release -> main for 0.4.1

2 parents 87da956 + 925d097, commit bf8ed5c

13 files changed: +88 −68 lines

.github/workflows/deploy_to_pypi.yml (+17 −21)

```diff
@@ -1,29 +1,25 @@
-# This workflows will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-
-name: Upload Python Package to pypi
+name: Upload Python Package to PyPI
 
 on: workflow_dispatch
 
 jobs:
   deploy:
-
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python
-        uses: actions/setup-python@v2
-        with:
-          python-version: '3.8'
-      - name: Install dependencies
-        run: |
-          pip install setuptools wheel twine
-          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-      - name: Build and publish
-        env:
-          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
-          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
-        run: |
-          python setup.py sdist bdist_wheel
-          twine upload dist/*
+      - uses: actions/checkout@v3
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          pip install setuptools wheel twine
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+      - name: Build and publish to PyPI
+        env:
+          TWINE_USERNAME: "__token__"
+          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
+        run: |
+          python setup.py sdist bdist_wheel
+          twine upload dist/*
```
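Both workflows switch from a username/password pair to PyPI's token-based authentication: twine treats the literal username `__token__` plus an API token as token auth, with the token stored in the `PYPI_TOKEN` (or `TEST_PYPI_TOKEN`) repository secret. As a rough illustration only, a minimal Python sketch of what the upload step amounts to, assuming twine is installed and the token is exposed as the `PYPI_TOKEN` environment variable as in the workflow above:

```python
# Hedged sketch of the workflow's upload step; run after
# `python setup.py sdist bdist_wheel` has populated dist/.
import glob
import os
import subprocess

env = dict(
    os.environ,
    TWINE_USERNAME="__token__",               # literal marker for token auth
    TWINE_PASSWORD=os.environ["PYPI_TOKEN"],  # the API token itself
)
# subprocess.run does not go through a shell, so dist/* is expanded manually.
subprocess.run(["twine", "upload", *glob.glob("dist/*")], env=env, check=True)
```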
.github/workflows/deploy_to_test_pypi.yml (+8 −11)

```diff
@@ -1,7 +1,4 @@
-# This workflows will upload a Python Package using Twine when a release is created
-# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
-
-name: Upload Python Package to testpypi
+name: Upload Python Package to Test PyPI
 
 on:
   workflow_dispatch:
@@ -13,19 +10,19 @@ jobs:
     runs-on: ubuntu-latest
 
     steps:
-      - uses: actions/checkout@v2
+      - uses: actions/checkout@v3
       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
-          python-version: 3.8
+          python-version: "3.10"
       - name: Install dependencies
         run: |
          pip install setuptools wheel twine
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
-      - name: Build and publish
+      - name: Build and publish to Test PyPI
        env:
-          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
-          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+          TWINE_USERNAME: "__token__"
+          TWINE_PASSWORD: ${{ secrets.TEST_PYPI_TOKEN }}
        run: |
          python setup.py sdist bdist_wheel
-          twine upload -r testpypi dist/*
+          twine upload -r testpypi dist/*
```

README.md (+1 −1)

````diff
@@ -125,7 +125,7 @@ Note also that installing PyTorch with *pip* may **not** set it up with CUDA sup
 Here are installation instructions for other numerical backends:
 ```sh
 conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
-pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
+pip install "jax[cuda]>=0.4.17" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
 conda install "numpy>=1.19.5" -c conda-forge
 ```
 
````

docs/source/conf.py (+1 −1)

```diff
@@ -22,7 +22,7 @@
 author = "ESA Advanced Concepts Team"
 
 # The full version, including alpha/beta/rc tags
-release = "0.4.0"
+release = "0.4.1"
 
 
 # -- General configuration ---------------------------------------------------
```

docs/source/install.rst (+1 −1)

```diff
@@ -57,7 +57,7 @@ Here are installation instructions for other numerical backends:
 .. code-block:: bash
 
    conda install "tensorflow>=2.6.0=cuda*" -c conda-forge
-   pip install "jax[cuda]>=0.2.22" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
+   pip install "jax[cuda]>=0.4.17" --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # linux only
    conda install "numpy>=1.19.5" -c conda-forge
 
 More installation instructions for numerical backends can be found in
```

environment_all_backends.yml (+3 −3)

```diff
@@ -7,7 +7,7 @@ dependencies:
   - loguru>=0.5.3
   - matplotlib>=3.3.3
   - pytest>=6.2.1
-  - python>=3.8
+  - python==3.12
   - scipy>=1.6.0
   - sphinx>=3.4.3
   - sphinx_rtd_theme>=0.5.1
@@ -16,9 +16,9 @@ dependencies:
   - numpy>=1.19.5
   - cudatoolkit>=11.1
   - pytorch>=1.9 # CPU version
-  - tensorflow>=2.10.0 # CPU version
   # jaxlib with CUDA support is not available for conda
   - pip:
       - --find-links https://storage.googleapis.com/jax-releases/jax_releases.html
-      - jax[cpu]>=0.2.22 # this will only work on linux. for win see e.g. https://github.com/cloudhan/jax-windows-builder
+      - tensorflow>=2.18.0 # CPU version
+      - jax[cpu]>=0.4.17 # this will only work on linux. for win see e.g. https://github.com/cloudhan/jax-windows-builder
         # CPU version
```

setup.py (+1 −1)

```diff
@@ -9,7 +9,7 @@
 
 setup(
     name="torchquad",
-    version="0.4.0",
+    version="0.4.1",
     description="Package providing torch-based numerical integration methods.",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
```

torchquad/integration/base_integrator.py (+6 −1)

```diff
@@ -83,8 +83,13 @@ def evaluate_integrand(fn, points, weights=None, args=None):
            len(result.shape) > 1
        ):  # if the the integrand is multi-dimensional, we need to reshape/repeat weights so they can be broadcast in the *=
            integrand_shape = anp.array(
-                result.shape[1:], like=infer_backend(points)
+                [
+                    dim if isinstance(dim, int) else dim.as_list()
+                    for dim in result.shape[1:]
+                ],
+                like=infer_backend(points),
            )
+
            weights = anp.repeat(
                anp.expand_dims(weights, axis=1), anp.prod(integrand_shape)
            ).reshape((weights.shape[0], *(integrand_shape)))
```
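Why the new comprehension: with the TensorFlow backend, `result.shape` is a `tf.TensorShape` rather than a plain tuple of ints, and passing its slice straight to `anp.array` could fail. A minimal sketch of the normalisation under TensorFlow 2.x (the `.as_list()` branch is the fallback for entries that are not already plain ints):

```python
import tensorflow as tf

# result.shape[1:] is a tf.TensorShape, not a tuple of ints.
result = tf.zeros((10, 3, 4))
dims = [
    dim if isinstance(dim, int) else dim.as_list()
    for dim in result.shape[1:]
]
print(dims)  # [3, 4] -- plain ints that anp.array can consume
```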

torchquad/integration/utils.py (+9 −8)

```diff
@@ -139,20 +139,21 @@ def _setup_integration_domain(dim, integration_domain, backend):
            # Get a globally default backend
            backend = _get_default_backend()
        dtype_arg = _get_precision(backend)
-        if dtype_arg is not None:
-            # For NumPy and Tensorflow there is no global dtype, so set the
-            # configured default dtype here
-            integration_domain = anp.array(
-                integration_domain, like=backend, dtype=dtype_arg
-            )
-        else:
-            integration_domain = anp.array(integration_domain, like=backend)
+        if backend == "tensorflow":
+            import tensorflow as tf
+
+            dtype_arg = dtype_arg or tf.keras.backend.floatx()
+
+        integration_domain = anp.array(
+            integration_domain, like=backend, dtype=dtype_arg
+        )
 
    if integration_domain.shape != (dim, 2):
        raise ValueError(
            "The integration domain has an unexpected shape. "
            f"Expected {(dim, 2)}, got {integration_domain.shape}"
        )
+
    return integration_domain
 
 
```
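The rewrite collapses the old if/else into one `anp.array` call and adds a TensorFlow-specific fallback: when no precision was configured, the domain dtype now comes from Keras' global float type rather than the backend default. A small sketch of just that fallback, assuming TensorFlow 2.x:

```python
import tensorflow as tf

# As if _get_precision("tensorflow") returned no configured dtype:
dtype_arg = None
dtype_arg = dtype_arg or tf.keras.backend.floatx()
print(dtype_arg)  # "float32" unless tf.keras.backend.set_floatx was called
```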
torchquad/tests/integration_test_functions.py (+1 −1)

```diff
@@ -214,7 +214,7 @@ def _poly(self, x):
            # Tensorflow does not automatically cast float32 to complex128,
            # so we do it here explicitly.
            assert self.is_complex
-            exponentials = anp.cast(exponentials, self.coeffs.dtype)
+            exponentials = exponentials.astype(self.coeffs.dtype)
 
        # multiply by coefficients
        exponentials = anp.multiply(exponentials, self.coeffs)
```

torchquad/tests/integrator_types_test.py (+21 −6)

```diff
@@ -56,14 +56,25 @@ def fn_const(x):
            ("jax", "float64", "float32"),
        ]:
            continue
+
        integrator_name = type(integrator).__name__
+
        # VEGAS supports only numpy and torch
        if integrator_name == "VEGAS" and backend in ["jax", "tensorflow"]:
            continue
 
        # Set the global precision
        set_precision(dtype_global, backend=backend)
 
+        # Determine expected dtype
+        if backend == "tensorflow":
+            import tensorflow as tf
+
+            expected_dtype_name = dtype_arg if dtype_arg else tf.keras.backend.floatx()
+        else:
+            expected_dtype_name = dtype_arg if dtype_arg else dtype_global
+
+        # Set integration domain
        integration_domain = [[0.0, 1.0], [-2.0, 0.0]]
        if dtype_arg is not None:
            # Set the integration_domain dtype which should have higher priority
@@ -75,18 +86,18 @@ def fn_const(x):
            )
            assert infer_backend(integration_domain) == backend
            assert get_dtype_name(integration_domain) == dtype_arg
-            expected_dtype_name = dtype_arg
-        else:
-            expected_dtype_name = dtype_global
 
        print(
-            f"\033[2mTesting {integrator_name} with {backend}, argument dtype"
-            f" {dtype_arg}, global/default dtype {dtype_global}\033[m"
+            f"Testing {integrator_name} with {backend}, argument dtype"
+            f" {dtype_arg}, global/default dtype {dtype_global}"
        )
+
+        # Integration
        if integrator_name in ["MonteCarlo", "VEGAS"]:
            extra_kwargs = {"seed": 0}
        else:
            extra_kwargs = {}
+
        result = integrator.integrate(
            fn=fn_const,
            dim=2,
@@ -95,8 +106,12 @@ def fn_const(x):
            backend=backend,
            **extra_kwargs,
        )
+
        assert infer_backend(result) == backend
-        assert get_dtype_name(result) == expected_dtype_name
+        assert (
+            get_dtype_name(result) == expected_dtype_name
+        ), f"Expected dtype {expected_dtype_name}, got {get_dtype_name(result)}"
+
        # VEGAS seems to be bad at integrating constant functions currently
        max_error = 0.03 if integrator_name == "VEGAS" else 1e-5
        assert anp.abs(result - (-4.0)) < max_error
```

torchquad/tests/monte_carlo_test.py (+1 −1)

```diff
@@ -40,7 +40,7 @@ def _run_monte_carlo_tests(backend, _precision):
    assert errors[4] < 32.0
 
    for error in errors[6:10]:
-        assert error < 1e-2
+        assert error < 1.1e-2
 
    for error in errors[10:]:
        assert error < 28.03
```

torchquad/utils/set_precision.py (+18 −12)

```diff
@@ -15,16 +15,17 @@ def _get_precision(backend):
 
 
 def set_precision(data_type="float32", backend="torch"):
-    """This function allows the user to set the default precision for floating point numbers for the given numerical backend.
+    """Set the default precision for floating-point numbers for the given numerical backend.
    Call before declaring your variables.
-    NumPy and Tensorflow don't have global dtypes:
+
+    NumPy doesn't have a global dtype:
    https://github.com/numpy/numpy/issues/6860
-    https://github.com/tensorflow/tensorflow/issues/26033
-    Therefore, torchquad sets the dtype argument for these two when initialising the integration domain.
+
+    Therefore, torchquad sets the dtype argument for it when initialising the integration domain.
 
    Args:
-        data_type (string, optional): Data type to use, either "float32" or "float64". Defaults to "float32".
-        backend (string, optional): Numerical backend for which the data type is changed. Defaults to "torch".
+        data_type (str, optional): Data type to use, either "float32" or "float64". Defaults to "float32".
+        backend (str, optional): Numerical backend for which the data type is changed. Defaults to "torch".
    """
    # Backwards-compatibility: allow "float" and "double", optionally with
    # upper-case letters
@@ -55,14 +56,19 @@ def set_precision(data_type="float32", backend="torch"):
        )
        torch.set_default_tensor_type(tensor_dtype)
    elif backend == "jax":
-        from jax.config import config
+        from jax import config
 
        config.update("jax_enable_x64", data_type == "float64")
        logger.info(f"JAX data type set to {data_type}")
-    elif backend in ["numpy", "tensorflow"]:
-        os.environ[f"TORCHQUAD_DTYPE_{backend.upper()}"] = data_type
-        logger.info(
-            f"Default dtype config for backend {backend} set to {_get_precision(backend)}"
-        )
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        # Set TensorFlow global precision
+        tf.keras.backend.set_floatx(data_type)
+        logger.info(f"TensorFlow default floatx set to {tf.keras.backend.floatx()}")
+    elif backend == "numpy":
+        # NumPy still lacks global dtype support
+        os.environ["TORCHQUAD_DTYPE_NUMPY"] = data_type
+        logger.info(f"NumPy default dtype set to {_get_precision('numpy')}")
    else:
        logger.error(f"Changing the data type is not supported for backend {backend}")
```
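For reference, a hedged usage sketch of the updated dispatch, assuming the top-level `set_precision` export that torchquad provides: every backend now has its own mechanism, and only NumPy still goes through the environment variable that `_get_precision` reads back.

```python
from torchquad import set_precision

# Call before creating tensors or integration domains.
set_precision("float64", backend="torch")       # default torch tensor type
set_precision("float64", backend="jax")         # flips jax_enable_x64
set_precision("float64", backend="tensorflow")  # tf.keras.backend.set_floatx
set_precision("float64", backend="numpy")       # sets TORCHQUAD_DTYPE_NUMPY
```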
