Skip to content

Commit 00f1e3c

Browse files
committed
Updates CIFAR-10 example to use CNN instead of MLP.
1 parent 0c7b986 commit 00f1e3c

7 files changed

+105
-58
lines changed

conslearn/buildConstrainedNetwork.m

+7-7
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,8 @@
1111
% The network includes either a featureInputLayer or an imageInputLayer,
1212
% depending on INPUTSIZE:
1313
%
14-
% - If INPUTSIZE is a scalar, then the network has a featureInputLayer.
15-
%
16-
% - If INPUTSIZE is a vector with three elements, then the network has an
14+
% - If INPUTSIZE is a scalar, then the network has a featureInputLayer. -
15+
% If INPUTSIZE is a vector with three elements, then the network has an
1716
% imageInputLayer.
1817
%
1918
% NUMHIDDENUNITS is a vector of integers that corresponds to the sizes
@@ -30,7 +29,8 @@
3029
% ConvexNonDecreasingActivation - Convex, non-decreasing
3130
% ("fully-convex") activation functions.
3231
% ("partially-convex") The options are "softplus" or
33-
% "relu". The default is "softplus".
32+
% "relu".
33+
% The default is "softplus".
3434
% Activation - Network activation function.
3535
% ("partially-convex") The options are "tanh", "relu" or
3636
% "fullsort". The default is "tanh".
@@ -80,9 +80,9 @@
8080
% "fullsort". The default is
8181
% "fullsort".
8282
% UpperBoundLipschitzConstant - Upper bound on the Lipschitz
83-
% constant for the network, as a
84-
% positive real number. The default
85-
% value is 1.
83+
% constant
84+
% for the network, as a positive real
85+
% number. The default value is 1.
8686
% pNorm - p-norm value for measuring
8787
% distance with respect to the
8888
% Lipschitz continuity definition.

conslearn/trainConstrainedNetwork.m

+59-8
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,12 @@
3131
% iteration, specified as: "mse", "mae", or
3232
% "crossentropy".
3333
% The default is "mse".
34+
% L2Regularization - Factor for L2 regularization (weight decay).
35+
% The default is 0.
36+
% ValidationData - Data to use for validation during training,
37+
% specified as a minibatchqueue object.
38+
% ValidationFrequency - Frequency of validation in number of
39+
% iterations. The default is 50.
3440
% TrainingMonitor - Flag to display the training progress monitor
3541
% showing the training data loss.
3642
% The default is true.
@@ -73,6 +79,9 @@
7379
trainingOptions.LossMetric {...
7480
mustBeTextScalar, ...
7581
mustBeMember(trainingOptions.LossMetric,["mse","mae","crossentropy"])} = "mse";
82+
trainingOptions.L2Regularization (1,1) {mustBeNumeric, mustBeNonnegative} = 0
83+
trainingOptions.ValidationData minibatchqueue {mustBeScalarOrEmpty} = minibatchqueue.empty
84+
trainingOptions.ValidationFrequency (1,1) {mustBeNumeric, mustBePositive, mustBeInteger} = 50
7685
trainingOptions.TrainingMonitor (1,1) logical = true;
7786
trainingOptions.TrainingMonitorLogScale (1,1) logical = true;
7887
trainingOptions.ShuffleMinibatches (1,1) logical = false;
@@ -84,26 +93,39 @@
8493
% Set up the training progress monitor
8594
if trainingOptions.TrainingMonitor
8695
monitor = trainingProgressMonitor;
96+
97+
% Track progress information
8798
monitor.Info = ["LearningRate","Epoch","Iteration"];
88-
monitor.Metrics = "TrainingLoss";
99+
100+
% Plot the training and validation metrics on the same plot
101+
monitor.Metrics = ["TrainingLoss", "ValidationLoss"];
102+
groupSubPlot(monitor, "Loss", ["TrainingLoss", "ValidationLoss"]);
103+
89104
% Apply loss log scale
90105
if trainingOptions.TrainingMonitorLogScale
91-
yscale(monitor,"TrainingLoss","log");
106+
yscale(monitor,"Loss","log");
92107
end
108+
93109
% Specify the horizontal axis label for the training plot.
94110
monitor.XLabel = "Iteration";
111+
95112
% Start the monitor
96113
monitor.Status = "Running";
97114
stopButton = @() ~monitor.Stop;
98115
else
116+
% Let training run without a monitor by setting stop to false
99117
stopButton = @() 1;
100118
end
119+
101120
% Prepare the generic hyperparameters
102121
maxEpochs = trainingOptions.MaxEpochs;
103122
initialLearnRate = trainingOptions.InitialLearnRate;
104123
decay = trainingOptions.Decay;
105124
metric = trainingOptions.LossMetric;
106125
shuffleMinibatches = trainingOptions.ShuffleMinibatches;
126+
l2Regularization = trainingOptions.L2Regularization;
127+
validationData = trainingOptions.ValidationData;
128+
validationFrequency = trainingOptions.ValidationFrequency;
107129

108130
% Specify ADAM options
109131
avgG = [];
@@ -147,7 +169,7 @@
147169

148170
% Evaluate the model gradients, and loss using dlfeval and the
149171
% modelLoss function and update the network state.
150-
[lossTrain,gradients,state] = dlfeval(@iModelLoss,net,X,T,metric);
172+
[lossTrain,gradients,state] = dlfeval(dlaccelerate(@iModelLoss),net,X,T,metric,l2Regularization);
151173
net.State = state;
152174

153175
% Gradient Update
@@ -162,10 +184,33 @@
162184
LearningRate=learnRate, ...
163185
Epoch=string(epoch) + " of " + string(maxEpochs), ...
164186
Iteration=string(iteration));
187+
165188
recordMetrics(monitor,iteration, ...
166189
TrainingLoss=lossTrain);
190+
167191
monitor.Progress = 100*epoch/maxEpochs;
168192
end
193+
194+
% Record validation loss, if requested
195+
if ~isempty(validationData)
196+
if (iteration == 1) || (mod(iteration, validationFrequency) == 0)
197+
198+
% Reset the validation data
199+
if ~hasdata(validationData)
200+
reset(validationData);
201+
end
202+
203+
% Compute the validation loss
204+
[X, T] = next(validationData);
205+
lossValidation = iModelLoss(net, X, T, metric, l2Regularization);
206+
207+
% Update the training monitor
208+
if trainingOptions.TrainingMonitor
209+
recordMetrics(monitor,iteration, ...
210+
ValidationLoss=lossValidation);
211+
end
212+
end
213+
end
169214
end
170215
end
171216

@@ -181,23 +226,29 @@
181226
end
182227

183228
%% Helpers
184-
function [loss,gradients,state] = iModelLoss(net,X,T,metric)
229+
function [loss,gradients,state] = iModelLoss(net,X,T,metric,l2Regularization)
185230

186231
% Make a forward pass
187-
[Y,state] = forward(net,X);
232+
[Y, state] = forward(net,X);
188233

189234
% Compute the loss
190235
switch metric
191236
case "mse"
192237
loss = mse(Y,T);
193238
case "mae"
194-
loss = mean(abs(Y-T));
239+
loss = mean(abs(Y-T), 'all');
195240
case "crossentropy"
196241
loss = crossentropy(softmax(Y),T);
197242
end
198243

199-
% Compute the gradient of the loss with respect to the learnabless
200-
gradients = dlgradient(loss,net.Learnables);
244+
if nargout > 1
245+
% Compute the gradient of the loss with respect to the learnables
246+
gradients = dlgradient(loss,net.Learnables);
247+
248+
% Apply L2 regularization
249+
idxWeights = net.Learnables.Parameter == "Weights";
250+
gradients(idxWeights,:) = dlupdate(@(g,w) g + l2Regularization*w, gradients(idxWeights, :), net.Learnables(idxWeights, :));
251+
end
201252
end
202253

203254
function proximalOp = iSetupProximalOperator(constraint,trainingOptions)

examples/convex/classificationCIFAR10/TrainICNNOnCIFAR10Example.md

+39-43
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11

22
# <span style="color:rgb(213,80,0)">Train Fully Convex Neural Network for Image Classification</span>
33

4-
This example shows how to create a fully input convex neural network and train it on CIFAR\-10 data. This example uses fully connected based convex networks, rather than the more typical convolutional networks, proven to give higher accuracy on the training and test data set. The aim of this example is to demonstrate the expressive capabilities convex constrained networks have by classifying natural images and demonstrating high accuracies on the training set. Further discussion on the expressive capabilities of convex networks for tasks including image classification can be found in \[1\].
4+
This example shows how to create a fully input convex convolutional neural network and train it on CIFAR\-10 data \[1\].
55

66
# Prepare Data
77

8-
Download the CIFAR\-10 data set \[1\]. The data set contains 60,000 images. Each image is 32\-by\-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.
8+
Download the CIFAR\-10 data set \[2\]. The data set contains 60,000 images. Each image is 32\-by\-32 pixels in size and has three color channels (RGB). The size of the data set is 175 MB. Depending on your internet connection, the download process can take time.
99

1010
```matlab
1111
datadir = ".";
@@ -18,16 +18,6 @@ Load the CIFAR\-10 training and test images as 4\-D arrays. The training set con
1818
[XTrain,TTrain,XTest,TTest] = loadCIFARData(datadir);
1919
```
2020

21-
For illustration in this example, subsample this data set evenly in each class. You can increase the number of samples by moving the slider to smaller values.
22-
23-
```matlab
24-
subSampleFrequency = 10;
25-
XTrain = XTrain(:,:,:,1:subSampleFrequency:end);
26-
XTest = XTest(:,:,:,1:subSampleFrequency:end);
27-
TTrain = TTrain(1:subSampleFrequency:end);
28-
TTest = TTest(1:subSampleFrequency:end);
29-
```
30-
3121
You can display a random sample of the training images using the following code.
3222

3323
<pre>
@@ -37,42 +27,45 @@ im = imtile(XTrain(:,:,:,idx),ThumbnailSize=[96,96]);
3727
imshow(im)
3828
</pre>
3929

40-
# Define FICNN Network Architecture
30+
# Define FICCNN Network Architecture
4131

42-
Use the <samp>buildConstrainedNetwork</samp> function to create a fully input convex neural network suitable for this data set.
32+
Use the <samp>buildConvexCNN</samp> function to create a fully input convex convolutional neural network suitable for this data set.
4333

44-
- The CIFAR\-10 images are 32\-by\-32 pixels. Therefore, create a fully convex network specifying the <samp>inputSize=[32 32 3]</samp>.
45-
- Specify a vector a hidden unit sizes of decreasing value in <samp>numHiddenUnits</samp>. The final number of outputs of the network must be equal to the number of classes, which in this example is 10.
34+
- The CIFAR\-10 images are 32\-by\-32 pixels, and belong to one of ten classes. Therefore, create a fully convex network specifying the <samp>inputSize=[32 32 3]</samp> and the <samp>numClasses=10</samp>.
35+
- For each convolutional layer, specify the filter size in <samp>filterSize</samp>, the number of filters in <samp>numFilters</samp>, and the stride size in <samp>stride</samp>.
4636
```matlab
4737
inputSize = [32 32 3];
48-
numHiddenUnits = [512 128 32 10];
38+
numClasses = 10;
39+
filterSize = [3; 3; 3; 3; 3; 1; 1];
40+
numFilters = [96; 96; 192; 192; 192; 192; 10];
41+
stride = [1; 2; 1; 2; 1; 1; 1];
4942
```
5043

51-
Seed the network initialization for reproducibility.
44+
Seed the network initialization for reproducibility.
5245

5346
```matlab
5447
rng(0);
55-
ficnnet = buildConstrainedNetwork("fully-convex",inputSize,numHiddenUnits)
48+
ficnnet = buildConvexCNN(inputSize, numClasses, filterSize, numFilters, Stride=stride)
5649
```
5750

5851
```matlabTextOutput
5952
ficnnet =
6053
dlnetwork with properties:
6154
62-
Layers: [15x1 nnet.cnn.layer.Layer]
63-
Connections: [17x2 table]
64-
Learnables: [14x3 table]
65-
State: [0x3 table]
66-
InputNames: {'image_input'}
67-
OutputNames: {'add_4'}
55+
Layers: [24x1 nnet.cnn.layer.Layer]
56+
Connections: [23x2 table]
57+
Learnables: [30x3 table]
58+
State: [14x3 table]
59+
InputNames: {'input'}
60+
OutputNames: {'fc_+_end'}
6861
Initialized: 1
6962
7063
View summary with summary.
7164
7265
```
7366

7467
```matlab
75-
plot(ficnnet)
68+
plot(ficnnet);
7669
```
7770

7871
<figure>
@@ -83,14 +76,15 @@ plot(ficnnet)
8376

8477
# Specify Training Options
8578

86-
Train for a specified number of epochs with a mini\-batch size of 256. To attain high training accuracy, you may need to train for a larger number of epochs, for example <samp>numEpochs=8000</samp>, which could take several hours.
79+
Train for a specified number of epochs with a mini\-batch size of 256. To attain high training accuracy, you may need to train for a larger number of epochs, for example <samp>numEpochs=400</samp>, which could take several hours.
8780

8881
```matlab
89-
numEpochs = 8000;
82+
numEpochs = 400;
9083
miniBatchSize = 256;
91-
initialLearnRate = 0.1;
92-
decay = 0.005;
84+
initialLearnRate = 0.0025;
85+
decay = eps;
9386
lossMetric = "crossentropy";
87+
l2Regularization = 1e-4;
9488
```
9589

9690
Create a <samp>minibatchqueue</samp> object that processes and manages mini\-batches of images during training. For each mini\-batch:
@@ -103,6 +97,7 @@ Create a <samp>minibatchqueue</samp> object that processes and manages mini\-bat
10397
xds = arrayDatastore(XTrain,IterationDimension=4);
10498
tds = arrayDatastore(TTrain,IterationDimension=1);
10599
cds = combine(xds,tds);
100+
106101
mbqTrain = minibatchqueue(cds,...
107102
MiniBatchSize=miniBatchSize,...
108103
MiniBatchFcn=@preprocessMiniBatch,...
@@ -160,7 +155,7 @@ disp("Training accuracy: " + (1-trainError)*100 + "%")
160155
```
161156

162157
```matlabTextOutput
163-
Training accuracy: 90.4848%
158+
Training accuracy: 70.2123%
164159
```
165160

166161
Compute the accuracy on the test set.
@@ -173,7 +168,7 @@ disp("Test accuracy: " + (1-testError)*100 + "%")
173168
```
174169

175170
```matlabTextOutput
176-
Test accuracy: 27.4554%
171+
Test accuracy: 66.266%
177172
```
178173

179174
The networks output has been constrained to be convex in every pixel in every colour. Even with this level of restriction, the network is able to fit reasonably well to the training data. You can see poor accuracy on the test data set but, as discussed at the start of the example, it is not anticipated that such a fully input convex network comprising of fully connected operations should generalize well to natural image classification.
@@ -197,14 +192,14 @@ cm.RowSummary = "row-normalized";
197192

198193
To summarise, the fully input convex network is able to fit to the training data set, which is labelled natural images. The training can take a considerable amount of time owing to the weight projection to the constrained set after each gradient update, which slows down training convergence. Nevertheless, this example illustrates the flexibility and expressivity convex neural networks have to correctly classifying natural images.
199194

200-
# Supporting Functions
201-
## Mini Batch Preprocessing Function
195+
# Supporting Functions
196+
## Mini\-Batch Preprocessing Function
202197

203-
The <samp>preprocessMiniBatch</samp> function preprocesses a mini\-batch of predictors and labels using the following steps:
198+
The <samp>preprocessMiniBatch</samp> function preprocesses a mini\-batch of predictions and labels using the following steps:
204199

205200
1. Preprocess the images using the <samp>preprocessMiniBatchPredictors</samp> function.
206201
2. Extract the label data from the incoming cell array and concatenate into a categorical array along the second dimension.
207-
3. One\-hot encode the categorical labels into numeric arrays. Encoding into the first dimension produces an encoded array that matches the shape of the network output.
202+
3. One\-hot encode the categorical labels into numeric arrays. Encoding in the first dimension produces an encoded array that matches the shape of the network output.
208203
```matlab
209204
function [X,T] = preprocessMiniBatch(dataX,dataT)
210205
@@ -219,19 +214,20 @@ T = onehotencode(T,1);
219214
220215
end
221216
```
222-
## Mini\-Batch Predictors Preprocessing Function
217+
## Mini\-Batch Predictors Preprocessing Function
223218

224-
The <samp>preprocessMiniBatchPredictors</samp> function preprocesses a mini\-batch of predictors by extracting the image data from the input cell array and concatenate into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. You divide by 255 to normalize the pixels to <samp>[0,1]</samp> range.
219+
The <samp>preprocessMiniBatchPredictors</samp> function preprocesses a mini\-batch of predictors by extracting the image data from the input cell array and concatenating it into a numeric array. For grayscale input, concatenating over the fourth dimension adds a third dimension to each image, to use as a singleton channel dimension. You divide by 255 to normalize the pixels to <samp>[0,1]</samp> range.
225220

226221
```matlab
227222
function X = preprocessMiniBatchPredictors(dataX)
228-
X = single(cat(4,dataX{1:end}))/255;
223+
X = (single(cat(4,dataX{1:end}))/255); % Normalizes to [0, 1]
224+
X = 2*X - 1; % Normalizes to [-1, 1].
229225
end
230226
```
231-
# References
232-
233-
\[1\] Amos, Brandon, et al. Input Convex Neural Networks. arXiv:1609.07152, arXiv, 14 June 2017. arXiv.org, https://doi.org/10.48550/arXiv.1609.07152.
234227

228+
# References
229+
\[1\] Amos, Brandon, et al. "Input Convex Neural Networks." (2017). https://doi.org/10.48550/arXiv.1609.07152.
235230

236-
*Copyright 2024 The MathWorks, Inc.*
231+
\[2\] Krizhevsky, Alex. "Learning multiple layers of features from tiny images." (2009). https://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf
237232

233+
*Copyright 2024-2025 The MathWorks, Inc.*
Binary file not shown.
Loading
Loading
Loading

0 commit comments

Comments
 (0)