|
| 1 | +function net = deSpeckNet_Init() |
| 2 | + |
| 3 | +% Create DAGNN object |
| 4 | +net = dagnn.DagNN(); |
| 5 | + |
| 6 | +% conv + relu |
| 7 | +blockNum = 1; |
| 8 | +inVar = 'input'; |
| 9 | +channel= 1; % grayscale image |
| 10 | +dims = [3,3,channel,64]; |
| 11 | +pad = [1,1]; |
| 12 | +dilate = [1,1]; |
| 13 | +stride = [1,1]; |
| 14 | +lr = [1,1]; |
| 15 | +%FCN clean |
| 16 | +[net, inVar1, blockNum] = addConv(net, blockNum, inVar, dims, pad,dilate, stride, lr); |
| 17 | +[net, inVar1, blockNum] = addReLU(net, blockNum, inVar1); |
| 18 | + |
| 19 | +for i = 1:15 |
| 20 | + % conv + bn + relu |
| 21 | + dims0 = [3,3,64,64]; |
| 22 | + [net, inVar1, blockNum] = addConv(net, blockNum, inVar1, dims0, pad,dilate, stride, lr); |
| 23 | + n_ch = dims0(4); |
| 24 | + [net, inVar1, blockNum] = addBnorm(net, blockNum, inVar1, n_ch); |
| 25 | + [net, inVar1, blockNum] = addReLU(net, blockNum, inVar1); |
| 26 | +end |
| 27 | + |
| 28 | +% conv |
| 29 | +dims1 = [3,3,64,channel]; |
| 30 | +[net, inVar5, blockNum] = addConv(net, blockNum, inVar1, dims1, pad,dilate, stride, lr); |
| 31 | + |
| 32 | +%__________________________________________________________________________ |
| 33 | +%FCN noise |
| 34 | + |
| 35 | +[net, inVar8, blockNum] = addConv(net, blockNum, inVar, dims, pad,dilate, stride, lr); |
| 36 | +[net, inVar8, blockNum] = addReLU(net, blockNum, inVar8); |
| 37 | + |
| 38 | +for i = 1:15 |
| 39 | + % conv + bn + relu |
| 40 | + [net, inVar8, blockNum] = addConv(net, blockNum, inVar8, dims0, pad,dilate, stride, lr); |
| 41 | + n_ch = dims0(4); |
| 42 | + [net, inVar8, blockNum] = addBnorm(net, blockNum, inVar8, n_ch); |
| 43 | + [net, inVar8, blockNum] = addReLU(net, blockNum, inVar8); |
| 44 | +end |
| 45 | + |
| 46 | +% conv |
| 47 | +[net, inVar13, blockNum] = addConv(net, blockNum, inVar8, dims1, pad,dilate, stride, lr); |
| 48 | + |
| 49 | + |
| 50 | +% % % Multiply and reconstruct noisy image |
| 51 | +inVarr = {inVar13,inVar5}; |
| 52 | +[net, inVar30, blockNum] = addMultiply(net, blockNum, inVarr); |
| 53 | +% |
| 54 | + |
| 55 | +outputName = 'prediction'; |
| 56 | +net.renameVar(inVar5,outputName) |
| 57 | + |
| 58 | +%__________________________________________________________________________ |
| 59 | + |
| 60 | + |
| 61 | +% loss clean |
| 62 | +net.addLayer('loss', dagnn.Loss('loss','L2'), {'prediction','label'}, {'objective'},{}); |
| 63 | +net.vars(net.getVarIndex('prediction')).precious = 1; |
| 64 | + |
| 65 | + |
| 66 | +outputName1 = 'prediction1'; %Final noisy image reconstruction |
| 67 | +net.renameVar(inVar30,outputName1) |
| 68 | + |
| 69 | +% loss noisy |
| 70 | +net.addLayer('loss1', dagnn.Loss('loss','L2'), {'prediction1','input'}, {'objective1'},{}); |
| 71 | +net.vars(net.getVarIndex('prediction1')).precious = 1; |
| 72 | + |
| 73 | +end |
| 74 | + |
| 75 | + |
| 76 | +% Add a multiply layer |
| 77 | +function [net, inVar, blockNum] = addMultiply(net, blockNum, inVar) |
| 78 | + |
| 79 | +outVar = sprintf('mult%d', blockNum); |
| 80 | +layerCur = sprintf('mult%d', blockNum); |
| 81 | + |
| 82 | +block = dagnn.Multiply(); |
| 83 | +net.addLayer(layerCur, block, inVar, {outVar},{}); |
| 84 | + |
| 85 | +inVar = outVar; |
| 86 | +blockNum = blockNum + 1; |
| 87 | +end |
| 88 | + |
| 89 | + |
| 90 | +% Add a relu layer |
| 91 | +function [net, inVar, blockNum] = addReLU(net, blockNum, inVar) |
| 92 | + |
| 93 | +outVar = sprintf('relu%d', blockNum); |
| 94 | +layerCur = sprintf('relu%d', blockNum); |
| 95 | + |
| 96 | +block = dagnn.ReLU('leak',0); |
| 97 | +net.addLayer(layerCur, block, {inVar}, {outVar},{}); |
| 98 | + |
| 99 | +inVar = outVar; |
| 100 | +blockNum = blockNum + 1; |
| 101 | +end |
| 102 | + |
| 103 | + |
| 104 | +% Add a bnorm layer |
| 105 | +function [net, inVar, blockNum] = addBnorm(net, blockNum, inVar, n_ch) |
| 106 | + |
| 107 | +trainMethod = 'adam'; |
| 108 | +outVar = sprintf('bnorm%d', blockNum); |
| 109 | +layerCur = sprintf('bnorm%d', blockNum); |
| 110 | + |
| 111 | +params={[layerCur '_g'], [layerCur '_b'], [layerCur '_m']}; |
| 112 | +net.addLayer(layerCur, dagnn.BatchNorm('numChannels', n_ch), {inVar}, {outVar},params) ; |
| 113 | + |
| 114 | +pidx = net.getParamIndex({[layerCur '_g'], [layerCur '_b'], [layerCur '_m']}); |
| 115 | +b_min = 0.025; |
| 116 | +net.params(pidx(1)).value = clipping(sqrt(2/(9*n_ch))*randn(n_ch,1,'single'),b_min); |
| 117 | +net.params(pidx(1)).learningRate= 1; |
| 118 | +net.params(pidx(1)).weightDecay = 0; |
| 119 | +net.params(pidx(1)).trainMethod = trainMethod; |
| 120 | + |
| 121 | +net.params(pidx(2)).value = zeros(n_ch, 1, 'single'); |
| 122 | +net.params(pidx(2)).learningRate= 1; |
| 123 | +net.params(pidx(2)).weightDecay = 0; |
| 124 | +net.params(pidx(2)).trainMethod = trainMethod; |
| 125 | + |
| 126 | +net.params(pidx(3)).value = [zeros(n_ch,1,'single'), 0.01*ones(n_ch,1,'single')]; |
| 127 | +net.params(pidx(3)).learningRate= 1; |
| 128 | +net.params(pidx(3)).weightDecay = 0; |
| 129 | +net.params(pidx(3)).trainMethod = 'average'; |
| 130 | + |
| 131 | +inVar = outVar; |
| 132 | +blockNum = blockNum + 1; |
| 133 | +end |
| 134 | + |
| 135 | + |
| 136 | +% add a Conv layer |
| 137 | +function [net, inVar, blockNum] = addConv(net, blockNum, inVar, dims, pad, dilate, stride, lr) |
| 138 | +opts.cudnnWorkspaceLimit = 1024*1024*1024*2; % 2GB |
| 139 | +convOpts = {'CudnnWorkspaceLimit', opts.cudnnWorkspaceLimit} ; |
| 140 | +trainMethod = 'adam'; |
| 141 | + |
| 142 | +outVar = sprintf('conv%d', blockNum); |
| 143 | +layerCur = sprintf('conv%d', blockNum); |
| 144 | + |
| 145 | +convBlock = dagnn.Conv('size', dims, 'pad', pad, 'dilate', dilate, 'stride', stride, ... |
| 146 | + 'hasBias', true, 'opts', convOpts); |
| 147 | + |
| 148 | +net.addLayer(layerCur, convBlock, {inVar}, {outVar},{[layerCur '_f'], [layerCur '_b']}); |
| 149 | + |
| 150 | +f = net.getParamIndex([layerCur '_f']) ; |
| 151 | +sc = sqrt(2/(dims(1)*dims(2)*max(dims(3), dims(4)))) ; %improved Xavier |
| 152 | +net.params(f).value = sc*randn(dims, 'single') ; |
| 153 | +net.params(f).learningRate = lr(1); |
| 154 | +net.params(f).weightDecay = 1; |
| 155 | +net.params(f).trainMethod = trainMethod; |
| 156 | + |
| 157 | +f = net.getParamIndex([layerCur '_b']) ; |
| 158 | +net.params(f).value = zeros(dims(4), 1, 'single'); |
| 159 | +net.params(f).learningRate = lr(2); |
| 160 | +net.params(f).weightDecay = 1; |
| 161 | +net.params(f).trainMethod = trainMethod; |
| 162 | + |
| 163 | +inVar = outVar; |
| 164 | +blockNum = blockNum + 1; |
| 165 | +end |
| 166 | + |
| 167 | + |
| 168 | +function A = clipping(A,b) |
| 169 | +A(A>=0&A<b) = b; |
| 170 | +A(A<0&A>-b) = -b; |
| 171 | +end |
0 commit comments