Skip to content

Commit 0bbdd3e

Browse files
committed
Merge branch 'master' into neuralsde
2 parents 880e855 + c708c4c commit 0bbdd3e

File tree

15 files changed

+172
-63
lines changed

15 files changed

+172
-63
lines changed

.github/workflows/codestyle.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ jobs:
1616
runs-on: ubuntu-latest
1717

1818
steps:
19-
- uses: actions/checkout@v1
19+
- uses: actions/checkout@v3
2020
- name: Set up Python 3.8
21-
uses: actions/setup-python@v1
21+
uses: actions/setup-python@v4
2222
with:
2323
python-version: 3.8
2424
- name: Install dependencies

.github/workflows/os-coverage.yml

+59-25
Original file line numberDiff line numberDiff line change
@@ -7,38 +7,72 @@ jobs:
77
strategy:
88
fail-fast: true
99
max-parallel: 15
10-
1110
matrix:
12-
os: [ubuntu-18.04, ubuntu-20.04, ubuntu-22.04, macos-latest]
13-
python-version: [3.7, 3.8, 3.9]
11+
os: [ubuntu-latest, macos-latest, windows-latest]
12+
python-version: ["3.8", "3.9", "3.10", "3.11"]
13+
torch-version: ["1.8.1", "1.9.1", "1.10.0", "1.11.0", "1.12.0", "1.13.1", "2.0.0"]
1414
exclude:
15-
# temporary exclusion until py3.8 deps work on Windows
16-
- python-version: 3.7
17-
os: windows-2019
18-
- python-version: 3.8
19-
os: windows-2019
15+
# python >= 3.10 does not support pytorch < 1.11.0
16+
- torch-version: "1.8.1"
17+
python-version: "3.10"
18+
- torch-version: "1.9.1"
19+
python-version: "3.10"
20+
- torch-version: "1.10.0"
21+
python-version: "3.10"
22+
# python >= 3.11 does not support pytorch < 1.13.0
23+
- torch-version: "1.8.1"
24+
python-version: "3.11"
25+
- torch-version: "1.9.1"
26+
python-version: "3.11"
27+
- torch-version: "1.10.0"
28+
python-version: "3.11"
29+
- torch-version: "1.11.0"
30+
python-version: "3.11"
31+
- torch-version: "1.12.0"
32+
python-version: "3.11"
33+
- torch-version: "1.13.1"
34+
python-version: "3.11"
2035

36+
defaults:
37+
run:
38+
shell: bash
2139
steps:
22-
- uses: actions/checkout@v2
40+
- name: Check out repository
41+
uses: actions/checkout@v3
42+
2343
- name: Set up Python ${{ matrix.python-version }}
24-
uses: actions/setup-python@v2
44+
uses: actions/setup-python@v4
2545
with:
2646
python-version: ${{ matrix.python-version }}
27-
28-
- name: Install dependencies
47+
48+
- name: Install Poetry
49+
uses: snok/install-poetry@v1
50+
with:
51+
virtualenvs-create: true
52+
virtualenvs-in-project: true
53+
54+
- name: Load cached venv
55+
id: cached-pip-wheels
56+
uses: actions/cache@v3
57+
with:
58+
path: ~/.cache
59+
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.torch-version }}-${{ hashFiles('**/poetry.lock') }}
60+
61+
- name: Install dependencies # hack for 🐛: don't let poetry try installing Torch https://github.com/pytorch/pytorch/issues/88049
2962
run: |
30-
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -
31-
source $HOME/.poetry/env
32-
python -m pip install --upgrade --user pip
33-
pip install pytest pytest-cov
34-
poetry lock
35-
poetry build
36-
poetry install
37-
poetry run pip install 'setuptools==59.5.0'
38-
# pinning setuptools is a temporary fix to: pytorch-lightning 1.5.10 requires setuptools==59.5.0
39-
# supposedly poetry allows pinning ver. of setuptools in pyproject.ml files but it is not working atm https://github.com/python-poetry/poetry/issues/4511
40-
41-
- name: Run tests
63+
pip install pytest pytest-cov papermill poethepoet>=0.10.0
64+
pip install torch==${{ matrix.torch-version }} pytorch-lightning scikit-learn torchsde torchcde>=0.2.3 scipy matplotlib ipykernel ipywidgets
65+
poetry install --only-root
66+
poetry run pip install setuptools
67+
68+
- name: List dependencies
69+
run: |
70+
pip list
71+
72+
- name: Run pytest checks
4273
run: |
43-
source $HOME/.poetry/env
74+
source $VENV
4475
poetry run coverage run --source=torchdyn -m pytest
76+
77+
- name: Report coverage
78+
uses: codecov/[email protected]

.github/workflows/publish.yaml

+4-6
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,24 @@ on:
55
types:
66
- created
77
- edited
8-
8+
99
jobs:
1010
build:
1111
runs-on: ubuntu-20.04
1212
steps:
13-
- uses: actions/checkout@v2
14-
- uses: actions/setup-python@v2
13+
- uses: actions/checkout@v3
14+
- uses: actions/setup-python@v3
1515
with:
1616
python-version: 3.8
1717

1818
- name: Build
1919
run: |
20-
curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -
21-
source $HOME/.poetry/env
20+
curl -sSL https://install.python-poetry.org | python3 -
2221
poetry lock
2322
poetry build
2423
2524
- name: Publish distribution 📦 to PyPI
2625
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
2726
run: |
28-
source $HOME/.poetry/env
2927
poetry config pypi-token.pypi ${{ secrets.pypi_token }}
3028
poetry publish

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Torchdyn is a PyTorch library dedicated to **numerical deep learning**: differen
1313
[![Slack](https://img.shields.io/badge/slack-chat-blue.svg?logo=slack)](https://join.slack.com/t/diffeqml/shared_invite/zt-trwgahq8-zgDqFmwS2gHYX6hsRvwDvg)
1414
[![codecov](https://codecov.io/gh/DiffEqML/torchdyn/branch/master/graph/badge.svg)](https://codecov.io/gh/DiffEqML/torchdyn)
1515
[![Docs](https://img.shields.io/badge/docs-passing-green.svg?)](https://torchdyn.readthedocs.io/)
16-
[![python_sup](https://img.shields.io/badge/python-3.7+-black.svg?)](https://www.python.org/downloads/release/python-370/)
16+
[![python_sup](https://img.shields.io/badge/python-3.8+-black.svg?)](https://www.python.org/downloads/release/python-370/)
1717

1818
</div>
1919

pyproject.toml

+13-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "torchdyn"
3-
version = "1.0.3"
3+
version = "1.0.6"
44
license = "Apache License, Version 2.0"
55
description = "A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods."
66
authors = ["Michael Poli", "Stefano Massaroli", "DiffEqML"]
@@ -9,11 +9,11 @@ packages = [
99
]
1010

1111
[tool.poetry.dependencies]
12-
python = "^3.7"
13-
torch = "^1.8.1"
12+
python = "^3.8"
13+
torch = ">=1.8.1"
1414
torchsde="*"
1515
torchcde="^0.2.3"
16-
sklearn = "*"
16+
scikit-learn = "*"
1717
pytorch-lightning = "*"
1818
torchvision = "*"
1919
scipy = "*"
@@ -38,6 +38,15 @@ build-backend = "poetry.masonry.api"
3838
requires = ["poetry", "wheel", "setuptools-cpp"]
3939

4040
[tool.pytest.ini_options]
41+
log_cli = true
42+
log_cli_level = "CRITICAL"
43+
log_cli_format = "%(message)s"
44+
45+
log_file = "pytest.log"
46+
log_file_level = "DEBUG"
47+
log_file_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
48+
log_file_date_format = "%Y-%m-%d %H:%M:%S"
49+
4150
filterwarnings = [
4251
"ignore:Call to deprecated create function FieldDescriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9
4352
"ignore:Call to deprecated create function Descriptor", # pytorch lightning needs tensorboard which has a conflict with python 3.9

setup.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616
setup(
1717
name="torchdyn",
18-
version="1.0.3",
18+
version="1.0.6",
1919
author="Michael Poli and Stefano Massaroli",
2020
description="PyTorch package for all things neural differential equations.",
2121
url="https://github.com/DiffEqML/torchdyn",
2222
install_requires=[
23-
"torch>=1.6.0",
23+
"torch>=1.8.1",
2424
"pytorch-lightning>=0.8.4",
2525
"matplotlib",
2626
"scikit-learn",
@@ -31,4 +31,5 @@
3131
"Programming Language :: Python :: 3",
3232
"License :: OSI Approved :: Apache Software License",
3333
],
34+
packages=["torchdyn"],
3435
)

test/models/test_ode.py

+59-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13+
from packaging.version import parse
1314
import pytest
1415
import torch
1516
import torch.nn as nn
@@ -19,7 +20,7 @@
1920
from torchdyn.datasets import ToyDataset
2021
from torchdyn.core import NeuralODE
2122
from torchdyn.nn import GalLinear, GalConv2d, DepthCat, Augmenter, DataControl
22-
from torchdyn.numerics import odeint, Euler
23+
from torchdyn.numerics import odeint, odeint_mshooting, Lorenz, Euler
2324

2425
from functools import partial
2526
import copy
@@ -258,4 +259,60 @@ def forward(self, t, x, u, v, z, args={}):
258259
t_eval, sol2 = odeprob(x0, t_span=torch.linspace(0, 5, 10))
259260

260261
assert (sol1==sol2).all()
261-
grad(sol2.sum(), x0)
262+
grad(sol2.sum(), x0)
263+
264+
265+
@pytest.mark.skipif(parse(torch.__version__) < parse("1.11.0"),
266+
reason="adjoint support added in torch 1.11.0")
267+
def test_complex_ode():
268+
"""Test odeint for complex numbers with a simple complex-valued ODE, corresponding
269+
to Rabi oscillations of quantum two-level system."""
270+
class Rabi(nn.Module):
271+
def __init__(self, omega):
272+
super().__init__()
273+
self.sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
274+
self.omega = omega
275+
return
276+
def forward(self, t, x):
277+
dx = -1.0j * self.omega * self.sx @ x
278+
dx += dx.adjoint()
279+
return dx
280+
281+
# Odeint parameters
282+
omega = torch.randn(1)
283+
rabi = Rabi(omega)
284+
tspan = torch.linspace(0., 2., 10)
285+
286+
# Random initial state
287+
x0 = torch.rand(2, 2, dtype=torch.complex128)
288+
x0 = 0.5 * (x0 + x0.adjoint()) / torch.real(x0.trace())
289+
# Solve the ODE problem
290+
t_eval, sol = odeint(f=rabi, x=x0, t_span=tspan, solver="dopri5", atol=1e-8, rtol=1e-6)
291+
292+
# Expected solution
293+
sx = torch.tensor([[0, 1], [1, 0]], dtype=torch.complex128)
294+
si = torch.tensor([[1, 0], [0, 1]], dtype=torch.complex128)
295+
U_t = torch.cos(omega * t_eval)[:, None, None] * si
296+
U_t += -1j * torch.sin(omega * t_eval)[:, None, None] * sx
297+
sol_exp = U_t @ x0 @ U_t.adjoint()
298+
299+
# Check result
300+
assert torch.allclose(sol, sol_exp, rtol=1e-5, atol=1e-5)
301+
302+
303+
@pytest.mark.parametrize('solver', ['mszero'])
304+
def test_odeint_mshooting(solver):
305+
x0 = torch.randn(8, 3) + 15
306+
t_span = torch.linspace(0, 3, 10)
307+
sys = Lorenz()
308+
309+
odeint_mshooting(sys, x0, t_span, solver=solver, fine_steps=2, maxiter=4)
310+
311+
312+
@pytest.mark.parametrize('solver', ['euler', 'rk4', 'dopri5'])
313+
def test_odeint(solver):
314+
x0 = torch.randn(8, 3) + 15
315+
t_span = torch.linspace(0., 2., 10)
316+
sys = Lorenz()
317+
318+
odeint(sys, x0, t_span, solver=solver)

test/test_sensitivity.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13+
# Test adjoint and perform a rough benchmarking of wall-clock time
14+
1315
import time
1416
from copy import deepcopy
17+
import logging
1518

1619
import pytest
1720
import torch
@@ -27,9 +30,10 @@
2730
batch_size = 128
2831
torch.manual_seed(1415112413244349)
2932

30-
3133
t_span = torch.linspace(0, 1, 100)
3234

35+
logger = logging.getLogger("out")
36+
3337

3438
# TODO(numerics): log wall-clock times and other torch.grad tests
3539
# TODO(bug): `tsit5` + `adjoint` peak error
@@ -38,17 +42,21 @@
3842
@pytest.mark.parametrize('stiffness', [0.1, 0.5])
3943
@pytest.mark.parametrize('interpolator', [None])
4044
def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):
45+
4146
f = VanDerPol(stiffness)
4247
x = torch.randn(1024, 2, requires_grad=True)
43-
t0 = time.time()
48+
4449
prob = ODEProblem(f, sensitivity=sensitivity, interpolator=interpolator, solver=solver, atol=1e-4, rtol=1e-4, atol_adjoint=1e-4, rtol_adjoint=1e-4)
50+
t0 = time.time()
4551
t_eval, sol_torchdyn = prob.odeint(x, t_span)
4652
t_end1 = time.time() - t0
4753

4854
t0 = time.time()
4955
sol_torchdiffeq = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-4, rtol=1e-4)
5056
t_end2 = time.time() - t0
5157

58+
logger.info(f"Fwd times: {t_end1:.3f}, {t_end2:.3f}")
59+
5260
true_sol = torchdiffeq.odeint_adjoint(f, x, t_span, method='dopri5', atol=1e-9, rtol=1e-9)
5361

5462
t0 = time.time()
@@ -59,6 +67,8 @@ def test_odeint_adjoint(sensitivity, solver, interpolator, stiffness):
5967
grad2 = torch.autograd.grad(sol_torchdiffeq[-1].sum(), x)[0]
6068
t_end2 = time.time() - t0
6169

70+
logger.info(f"Bwd times: {t_end1:.3f}, {t_end2:.3f}")
71+
6272
grad_true = torch.autograd.grad(true_sol[-1].sum(), x)[0]
6373

6474
err1 = (grad1-grad_true).abs().sum(1)

torchdyn/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,10 @@
1010
# See the License for the specific language governing permissions and
1111
# limitations under the License.
1212

13-
__version__ = '1.0'
13+
__version__ = '1.0.6'
1414
__author__ = 'Michael Poli, Stefano Massaroli et al.'
1515

1616
from torch import Tensor
1717
from typing import Tuple
1818

1919
TTuple = Tuple[Tensor, Tensor]
20-

torchdyn/numerics/odeint.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Functional API of ODE integration routines, with specialized functions for different options
1515
`odeint` and `odeint_mshooting` prepare and redirect to more specialized routines, detected automatically.
1616
"""
17-
from inspect import getargspec
1817
from typing import List, Tuple, Union, Callable, Dict, Iterable
1918
from warnings import warn
2019

@@ -65,11 +64,6 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
6564
x, t_span = solver.sync_device_dtype(x, t_span)
6665
stepping_class = solver.stepping_class
6766

68-
# instantiate save_at tensor
69-
if len(save_at) == 0: save_at = t_span
70-
if not isinstance(save_at, torch.Tensor):
71-
save_at = torch.tensor(save_at)
72-
7367
# instantiate the interpolator similar to the solver steps above
7468
if isinstance(solver, Tsitouras45):
7569
if verbose: warn("Running interpolation not yet implemented for `tsit5`")
@@ -87,6 +81,7 @@ def odeint(f:Callable, x:Tensor, t_span:Union[List, Tensor], solver:Union[str, n
8781
if stepping_class == 'fixed':
8882
if atol != odeint.__defaults__[0] or rtol != odeint.__defaults__[1]:
8983
warn("Setting tolerances has no effect on fixed-step methods")
84+
# instantiate save_at tensor
9085
return _fixed_odeint(f_, x, t_span, solver, save_at=save_at, args=args)
9186
elif stepping_class == 'adaptive':
9287
t = t_span[0]
@@ -415,6 +410,10 @@ def _adaptive_odeint(f, k1, x, dt, t_span, solver, atol=1e-4, rtol=1e-4, args=No
415410

416411
def _fixed_odeint(f, x, t_span, solver, save_at=(), args={}):
417412
"""Solves IVPs with same `t_span`, using fixed-step methods"""
413+
if len(save_at) == 0: save_at = t_span
414+
if not isinstance(save_at, torch.Tensor):
415+
save_at = torch.tensor(save_at)
416+
418417
assert all(torch.isclose(t, save_at).sum() == 1 for t in save_at),\
419418
"each element of save_at [torch.Tensor] must be contained in t_span [torch.Tensor] once and only once"
420419

0 commit comments

Comments
 (0)