
Commit f04f2b9

Model development (#18)
* add model blocks and tests
* add model and tests
* add torch to setup
* add torch to requirements
* update workflow
* add data loader and tests
* add opencv to requirements
* add opencv to requirements and setup
* update opencv version
* implement attention mechanism in unet
* add network visualisations
* update readme
* add visualisations
* update requirements
* devcontainer tweaks
* tweak devcontainer
* minor changes to setup.sh - add safe directory
* add utils function to convert images to .npy
1 parent 887545d commit f04f2b9

File tree: 13 files changed (+430, -28 lines)


.devcontainer/devcontainer.json

Lines changed: 20 additions & 3 deletions

```diff
@@ -8,7 +8,24 @@
     // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
     "dockerfile": "../Dockerfile"
   },
-  "features": {},
+  "features": {
+    "ghcr.io/devcontainers/features/common-utils:2": {
+      "installzsh": true,
+      "configurezshasdefaultshell": true,
+      "installohmyzsh": true,
+      "upgradePackages": false
+    },
+    "ghcr.io/devcontainers/features/docker-outside-of-docker:1": {
+      "moby": true,
+      "installdockerbuildx": true,
+      "version": "20.10",
+      "dockerdashcomposeversion": "v2"
+    },
+    "ghcr.io/devcontainers/features/github-cli:1": {
+      "installDirectlyFromGitHubRelease": true,
+      "version": "latest"
+    }
+  },
   "postCreateCommand": {
     "post_create": ".devcontainer/setup.sh"
   },
@@ -28,8 +45,8 @@
     }
   },
   "runArgs": [
-    // "--runtime=nvidia",
-    "--gpus=all"
+    //"--runtime=nvidia",
+    "--gpus=all"
   ]
   // Features to add to the dev container. More info: https://containers.dev/features.
   // "features": {},
```

.devcontainer/setup.sh (file mode changed: 100644 → 100755)

Lines changed: 3 additions & 1 deletion

```diff
@@ -1,5 +1,7 @@
 #!/bin/bash
 
+git config --global --add safe.directory /workspaces/UnderWaterU-Net
+
 pip install -e .[dev]
 pip install pytest-cov
-pre-commit install
+pre-commit install
```

(The final pair only adds a trailing newline to the file.)

.gitignore

Lines changed: 4 additions & 0 deletions

```diff
@@ -158,3 +158,7 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 .idea/
+
+# Datasets
+data/
+*.npy
```

.pre-commit-config.yaml (new file)

Lines changed: 46 additions & 0 deletions

```yaml
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.4.0  # Use the version you prefer
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files

  - repo: https://github.com/psf/black
    rev: 21.9b0  # Use the version you prefer
    hooks:
      - id: black
        args: ['--safe']

  - repo: https://github.com/pycqa/flake8
    rev: 3.9.2  # Use the version you prefer
    hooks:
      - id: flake8

  - repo: https://github.com/pre-commit/mirrors-autopep8  # Auto formatting
    rev: v2.0.2
    hooks:
      - id: autopep8

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v1.2.3
    hooks:
      - id: flake8  # Checking PEP8 that was not corrected by autopep8
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files

  - repo: https://github.com/kynan/nbstripout
    rev: 0.6.1
    hooks:
      - id: nbstripout  # Remove outputs from notebooks

  - repo: https://github.com/nbQA-dev/nbQA  # Same as above but for notebooks content
    rev: 1.7.0
    hooks:
      - id: nbqa-autopep8
      - id: nbqa-flake8
        args: [--ignore=F401]  # Ignore unused imports as they are not fixed automatically
      - id: nbqa-isort
```

README.md

Lines changed: 60 additions & 1 deletion

````diff
@@ -1,4 +1,3 @@
-
 # UnderWaterU-Net 🌊
 
 ![UnderWaterU-Net Logo](path_to_my_logo.png)
@@ -12,3 +11,63 @@ Welcome to UnderWaterU-Net, a deep learning repository specially optimized for u
 - **Expandable with Submodules**: Modular design allows for easy expansion and incorporation of additional functionalities.
 - **Streamlined Workflow**: From raw underwater images to precise segmentations, UnderWaterU-Net makes the process seamless.
 
+
+## 🚀 Getting Started
+
+### Prerequisites
+
+- List any prerequisites or dependencies here.
+
+### Installation
+
+1. **Direct Installation**:
+   ```bash
+   git clone git@github.com:ioannispol/UnderWaterU-Net.git
+   ```
+
+2. **Advanced Setup (With Submodules)**:
+   ```bash
+   git clone --recurse-submodules git@github.com:ioannispol/UnderWaterU-Net.git
+   ```
+
+## 📖 Documentation
+
+Detailed documentation can be found [here](link_to_your_documentation).
+<!-- Replace with a link to your documentation if you have it. -->
+
+## 🤝 Contributing
+
+We welcome contributions! Please see our [CONTRIBUTING.md](link_to_contributing_guide) for details.
+<!-- Replace with a link to your contributing guide if you have it. -->
+
+## 📜 License
+
+This project is licensed under the XYZ License - see the [LICENSE.md](link_to_license) for details.
+<!-- Replace with a link to your license file and mention the type of license you're using. -->
+
+## 📬 Contact
+
+For any queries, feel free to reach out to [ioannispol](mailto:your_email@example.com).
+<!-- Replace with your email or contact details. -->
+
+## Attention Mechanisms in U-Net
+
+The U-Net architecture has been extended to include attention gates, which allow the model to focus on specific regions of the input, enhancing its capability to segment relevant regions more accurately.
+
+### AttentionGate Module
+
+The AttentionGate module takes two inputs, \( g \) and \( x \), and computes the attention coefficients. These coefficients are used to weight the features in \( x \) to produce the attended features. The process can be summarized as follows:
+
+1. Two 1x1 convolutions transform \( g \) and \( x \) into a compatible space.
+2. A non-linearity (ReLU) is applied after summing the transformed versions of \( g \) and \( x \).
+3. Another 1x1 convolution followed by a sigmoid activation produces the attention coefficients in the range [0, 1].
+4. The original \( x \) is multiplied by the attention coefficients to obtain the attended features.
+
+This mechanism is particularly useful in tasks like image segmentation, enabling the network to emphasize more informative regions during training and prediction.
+
+### Reference
+
+The attention mechanism is inspired by the following paper:
+- Oktay, O., Schlemper, J., Folgoc, L. L., Lee, M., Heinrich, M., Misawa, K., ... & Glocker, B. (2018). Attention U-Net: Learning where to look for the pancreas. arXiv preprint arXiv:1804.03999.
+
````
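
The four steps the README lists map directly onto a small PyTorch module. Below is a minimal sketch of such an additive attention gate — not the repository's actual `underwater_unet.model` implementation; the class name, argument names, and channel sizes are assumptions:

```python
import torch.nn as nn


class AttentionGate(nn.Module):
    """Additive attention gate in the style of Oktay et al. (2018)."""

    def __init__(self, g_channels, x_channels, inter_channels):
        super().__init__()
        # Step 1: 1x1 convolutions project g and x into a shared space
        self.w_g = nn.Conv2d(g_channels, inter_channels, kernel_size=1)
        self.w_x = nn.Conv2d(x_channels, inter_channels, kernel_size=1)
        # Step 3: 1x1 convolution + sigmoid -> coefficients in [0, 1]
        self.psi = nn.Conv2d(inter_channels, 1, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, g, x):
        # Step 2: ReLU over the summed projections (g and x must match in H, W)
        a = self.relu(self.w_g(g) + self.w_x(x))
        # Step 3: attention coefficients
        alpha = self.sigmoid(self.psi(a))
        # Step 4: weight x by the coefficients to get the attended features
        return x * alpha
```

In a U-Net decoder, `g` would typically be the coarser gating signal and `x` the skip-connection features, with one of them resampled first so their spatial sizes match before the sum.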

notebooks/test_unet.ipynb (new file)

Lines changed: 130 additions & 0 deletions

The new notebook (Python 3 kernel, nbformat 4) contains four code cells, shown here as a script with `# %%` cell markers:

```python
# %%
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from underwater_unet.model import UNet

%matplotlib inline

# %%
# Insert the UNet and AttentionUNet code here
model = UNet(n_channels=1, n_classes=2)  # Example for a grayscale image to be classified into 2 classes

# %%
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resizing to fit the U-Net architecture
    transforms.ToTensor(),
])

test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=True)

# %%
def display_image_from_npy(npy_path, image_index, method="opencv"):
    """
    Load and display an image from a .npy file.

    Parameters:
    - npy_path (str): Path to the .npy file containing the images.
    - image_index (int): 0-based index of the image to display from the .npy file.
    - method (str): Method to use for displaying the image. Options are "opencv" or "matplotlib".
    """

    # Load the dataset from the .npy file
    dataset = np.load(npy_path)

    # Check if the image_index is valid
    if image_index < 0 or image_index >= len(dataset):
        print(f"Invalid image index. Please provide an index between 0 and {len(dataset) - 1}.")
        return

    # Get the desired image
    image = dataset[image_index]

    if method == "opencv":
        # Display the image using OpenCV
        cv2.imshow(f'Image {image_index}', image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    elif method == "matplotlib":
        # Display the image using Matplotlib
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        plt.title(f'Image {image_index}')
        plt.axis('off')
        plt.show()
    else:
        print("Invalid method. Choose 'opencv' or 'matplotlib'.")

# %%
dataset_path = '/workspaces/UnderWaterU-Net/dataset.npy'
display_image_from_npy(dataset_path, 20, method="matplotlib")
```
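
The notebook reads `/workspaces/UnderWaterU-Net/dataset.npy`, which the commit message says comes from a new utility that converts images to .npy. That utility is not part of this diff; a minimal sketch of what such a helper could look like — the function name, signature, and default size are assumptions:

```python
import os

import cv2
import numpy as np


def images_to_npy(images_dir, out_path, size=(256, 256)):
    """Hypothetical helper: stack all images in a directory into one (N, H, W, 3) .npy array."""
    arrays = []
    for name in sorted(os.listdir(images_dir)):
        img = cv2.imread(os.path.join(images_dir, name))  # BGR, uint8; None for non-images
        if img is None:
            continue
        arrays.append(cv2.resize(img, size))
    np.save(out_path, np.stack(arrays))


# e.g. images_to_npy('data/images', 'dataset.npy')
```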

requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -6,3 +6,4 @@ wandb
 jupyterlab
 torch >= 2.0
 opencv-python <=4.8.0.74
+dowhy
```

setup.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -9,6 +9,7 @@
     'setuptools',
     'numpy',
     'scipy',
+    'dowhy',
     'matplotlib',
     'pandas',
     'torch ~= 2.0',
```

train.py (new file)

Lines changed: 52 additions & 0 deletions

```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from underwater_unet.model import UNet
from utils.data_load import UnderwaterDataset


# Hyperparameters and setup
num_epochs = 10
learning_rate = 0.001
batch_size = 16
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load dataset and dataloader
train_dataset = UnderwaterDataset(images_dir='data/images', mask_dir='data/masks', resize_to=None)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss and optimizer
model = UNet(n_channels=3, n_classes=1).to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Make sure the checkpoint directory exists
os.makedirs('experiment', exist_ok=True)

# Training loop
for epoch in range(num_epochs):
    model.train()
    for batch in train_loader:
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Save the model
    torch.save(model.state_dict(), f"experiment/model_epoch_{epoch + 1}.pth")

print("Training completed.")
```
