Skip to content

Commit 4a9882f

Browse files
authored
Add model blocks and tests (#7)
* add model blocks and tests * add model and tests * add torch to setup * add torch to requirements * update workflow
1 parent 754b7be commit 4a9882f

File tree

14 files changed

+438
-16
lines changed

14 files changed

+438
-16
lines changed

.devcontainer/devcontainer.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
"dockerfile": "../Dockerfile"
1010
},
1111
"features": {},
12-
"postCreateCommand": "echo 'Development container for UWUNet is ready!'",
12+
"postCreateCommand": {
13+
"post_create": ".devcontainer/setup.sh"
14+
},
1315

1416
"containerEnv": {
1517
"PYTHONPATH": "${containerWorkspaceFolder}"

.devcontainer/setup.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/bin/bash
2+
3+
pip install -e .[dev]
4+
pip install pytest-cov
5+
pre-commit install

.github/workflows/python-package.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ jobs:
1414

1515
steps:
1616
- uses: actions/checkout@v2
17+
- name: Set PYTHONPATH
18+
run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV
1719
- name: Set up Python 3.8
1820
uses: actions/setup-python@v2
1921
with:

.gitignore

Lines changed: 159 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,160 @@
1+
Byte-compiled / optimized / DLL files
12
__pycache__/
2-
*.pyc
3-
*.pyo
4-
*.pyd
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
share/python-wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
MANIFEST
28+
29+
# PyInstaller
30+
# Usually these files are written by a python script from a template
31+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
32+
*.manifest
33+
*.spec
34+
35+
# Installer logs
36+
pip-log.txt
37+
pip-delete-this-directory.txt
38+
39+
# Unit test / coverage reports
40+
htmlcov/
41+
.tox/
42+
.nox/
43+
.coverage
44+
.coverage.*
45+
.cache
46+
nosetests.xml
47+
coverage.xml
48+
*.cover
49+
*.py,cover
50+
.hypothesis/
51+
.pytest_cache/
52+
cover/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
.pybuilder/
76+
target/
77+
78+
# Jupyter Notebook
79+
.ipynb_checkpoints
80+
81+
# IPython
82+
profile_default/
83+
ipython_config.py
84+
85+
# pyenv
86+
# For a library or package, you might want to ignore these files since the code is
87+
# intended to run in multiple environments; otherwise, check them in:
88+
# .python-version
89+
90+
# pipenv
91+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
93+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
94+
# install all needed dependencies.
95+
#Pipfile.lock
96+
97+
# poetry
98+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99+
# This is especially recommended for binary packages to ensure reproducibility, and is more
100+
# commonly ignored for libraries.
101+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102+
#poetry.lock
103+
104+
# pdm
105+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106+
#pdm.lock
107+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108+
# in version control.
109+
# https://pdm.fming.dev/#use-with-ide
110+
.pdm.toml
111+
112+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113+
__pypackages__/
114+
115+
# Celery stuff
116+
celerybeat-schedule
117+
celerybeat.pid
118+
119+
# SageMath parsed files
120+
*.sage.py
121+
122+
# Environments
123+
.env
124+
.venv
125+
env/
126+
venv/
127+
ENV/
128+
env.bak/
129+
venv.bak/
130+
131+
# Spyder project settings
132+
.spyderproject
133+
.spyproject
134+
135+
# Rope project settings
136+
.ropeproject
137+
138+
# mkdocs documentation
139+
/site
140+
141+
# mypy
142+
.mypy_cache/
143+
.dmypy.json
144+
dmypy.json
145+
146+
# Pyre type checker
147+
.pyre/
148+
149+
# pytype static type analyzer
150+
.pytype/
151+
152+
# Cython debug symbols
153+
cython_debug/
154+
155+
# PyCharm
156+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158+
# and can be added to the global gitignore or merged into this file. For a more nuclear
159+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160+
.idea/

UnderWaterU-Net/__init__.py

Whitespace-only changes.

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ numpy
33
Pillow
44
tqdm
55
wandb
6-
jupyterlab
6+
jupyterlab
7+
torch >= 2.0

setup.cfg

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[flake8]
2+
max_line_length = 120
3+
per-file-ignores = __init__.py:F401
4+
5+
[autflake8]
6+
in-place = true
7+
remove-all-unused-imports = true
8+
9+
# [tool:pytest]
10+
# addopts = --cov
11+
12+
[coverage:run]
13+
branch = False
14+
15+
[coverage:report]
16+
show_missing = True
17+
skip_covered = True

setup.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,28 @@
1-
21
from setuptools import setup, find_packages
32

43
setup(
5-
name='UnderWaterU-Net',
6-
version='0.1',
7-
packages=find_packages(),
4+
name='underwater_unet',
5+
version='0.0.1',
6+
install_requires=[
7+
'importlib-metadata; python_version >= "3.9"',
8+
'jupyterlab',
9+
'setuptools',
10+
'numpy',
11+
'scipy',
12+
'matplotlib',
13+
'pandas',
14+
'torch ~= 2.0',
15+
],
16+
extras_require={
17+
'dev': [
18+
'pytest',
19+
'pre-commit',
20+
'pytest-cov',
21+
'nbmake'
22+
]
23+
},
24+
packages=find_packages(
25+
include=['underwater_unet', 'underwater_unet*'],
26+
exclude=['tests', 'tests.*', 'notebooks']
27+
),
828
)

tests/test_dummy.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

tests/test_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import torch
2+
3+
from underwater_unet.model import UNet
4+
5+
6+
def test_UNet():
7+
# Create a UNet instance
8+
n_channels = 3
9+
n_classes = 2
10+
model = UNet(n_channels=n_channels, n_classes=n_classes)
11+
12+
# Generate a random tensor of size [batch_size, channels, height, width]
13+
x = torch.randn(1, n_channels, 128, 128, requires_grad=True)
14+
15+
# Forward pass
16+
y = model(x)
17+
18+
# Check if the tensor has been transformed correctly
19+
assert y.shape == (1, n_classes, 128, 128), f"Expected shape (1, {n_classes}, 128, 128) but got {y.shape}"
20+
21+
# Ensure the output is of the correct type
22+
assert isinstance(y, torch.Tensor), f"Expected output type torch.Tensor but got {type(y)}"
23+
24+
# Ensure the values are within a reasonable range (sanity check, not strictly required)
25+
assert y.max() <= 10 and y.min() >= -10, "Output values are out of a reasonable range"

0 commit comments

Comments
 (0)