-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_cnn.m
executable file
·83 lines (72 loc) · 3.22 KB
/
train_cnn.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
%% train_cnn.m
%
% Generates a bash script that trains the CNN model when executed in terminal.
% Prior to generating the script, training data is prepared and saved in:
%
% ./caffe-heatmap/model-trainer/+fusion/training_data
%
% The underlying CNN model training framework is derived from the following works:
%
% T. Pfister, J. Charles, and A. Zisserman, "Flowing ConvNets for human pose estimation in videos,"
% in Proc. IEEE Int. Conf. Comput. Vis., Dec. 2015, pp. 1913-1921.
%
% J. Charles, T. Pfister, D. Magee, D. Hogg, and A. Zisserman, "Personalizing human video pose estimation,"
% in Proc. IEEE Comput. Soc. Conf. Comput. Vis. Pattern Recognit. (CVPR), Jun. 2016, pp. 3063-3072.
%
% Many thanks to these authors.
%
% Usage:
% Run this with the current MATLAB directory set to the top-level folder of the PatientPose package (i.e. the
% location of this script). The script generates a .sh bash script in ./caffe-heatmap/model-trainer/training_scripts/
% that trains the CNN model when executed in a terminal. To change specific options of the neural network, edit
% the file ./caffe-heatmap/model-trainer/+fusion/base_scripts/solver.prototxt before running this script.
%
% Inputs:
% - Folder containing the preprocessed training images (.jpg files)
% - Ground truth pose annotations generated via label_images.m, saved as a .mat file containing a 'detections' variable
%
% Outputs:
% - A bash script saved in ./caffe-heatmap/model-trainer/training_scripts/ to train the patient CNN model
%
% Translational Neuroengineering Laboratory (TNEL) @ UC San Diego
% Website: http://www.tnel.ucsd.edu
% Start from an empty workspace with no open figures.
clear;
close('all');
% Timestamp used to name every artifact generated below (video, training data,
% and the .sh training script).
% NOTE(review): this format contains ':' characters, which end up in file and
% folder names (invalid on Windows, awkward in shells) -- confirm downstream
% tools tolerate them before changing or relying on it.
dateTime = datestr(now, 'mm-dd-yy_HH:MM:SS');
%% Setup & Options
% Project setup/options scripts; they presumably define the `pp` path struct
% used later in this script (TODO confirm).
run('patientpose_setup');
run('patientpose_options');
%% Training Images
% Ask the user for the folder holding the preprocessed training images (.jpg).
disp('Select the folder containing the preprocessed training images');
im.folder = uigetdir('','Folder containing images');
% uigetdir returns 0 (not a path) when the dialog is cancelled.
if isequal(im.folder, 0)
    error('train_cnn:noImageFolder', 'No training image folder was selected.');
end
% Put the image folder on the MATLAB path so images can be located by file
% name alone (presumably relied on by images2video -- confirm).
addpath(im.folder);
im.files = dir(fullfile(im.folder,'*.jpg'));
% Fail early rather than building an empty video later on.
if isempty(im.files)
    error('train_cnn:noImages', 'No .jpg images were found in %s.', im.folder);
end
% Sort the file names in natural counting order (e.g. 2.jpg before 10.jpg).
im.names = natsortfiles({im.files.name});
%% Training Annotations
% Load the ground truth annotations produced by label_images.m. As with the
% former uiopen call, every variable stored in the chosen .mat file is loaded
% into the workspace; the code below needs detections.manual.frameids and
% detections.manual.locs.
disp('Select training ground truth poses generated via label_images.m');
[annFile, annPath] = uigetfile('*.mat', 'Training ground truth poses (.mat)');
% uigetfile returns 0 for the file name when the dialog is cancelled.
if isequal(annFile, 0)
    error('train_cnn:noAnnotations', 'No ground truth annotation file was selected.');
end
annFullPath = fullfile(annPath, annFile);
load(annFullPath);
if ~exist('detections', 'var')
    error('train_cnn:missingDetections', ...
        'The file %s does not contain a ''detections'' variable.', annFullPath);
end
%% Create training files
% Create a video from the training images. images2video is run from
% pp.input2vidloc (presumably where its output is written -- confirm). The
% original working directory is restored even if images2video fails, so a
% failure does not leave MATLAB in the wrong folder.
pdir = pwd;
cd(pp.input2vidloc);
try
    images2video(im, dateTime);
catch err
    cd(pdir);
    rethrow(err);
end
cd(pdir);
% Create the training files for the video named input2_<dateTime>.
videoname = strcat('input2_',dateTime);
dataset = 'youtube';
[opts, folder] = load_system_options(dataset, videoname);
[filename, folder] = setup_filenames(folder, videoname, dataset);
% Generate the fine-tuning training data from the manual annotations. All
% annotated frame ids/locations are passed, followed separately by the first
% annotated frame id and its locations (presumably a reference frame --
% confirm), the image scale, and the fine-tuning crop dimensions.
fusion.setupFinetuningCropped(opts.cnn,dataset,videoname,...
    filename.video,detections.manual.frameids,...
    detections.manual.locs,detections.manual.frameids(1),...
    detections.manual.locs(:,:,1),opts.imscale,...
    opts.cnn.finetune.dims(1),opts.cnn.finetune.dims(2),dateTime);
% Copy the generated training script into <pp.personalizeloc>/training_scripts/,
% creating that folder if needed. fullfile keeps the separators consistent
% regardless of whether pp.personalizeloc ends with a separator.
scriptDir = fullfile(pp.personalizeloc, 'training_scripts');
% Check specifically for a folder; exist() without 'dir' also matches files,
% variables and functions of the same name.
if ~exist(scriptDir, 'dir')
    mkdir(scriptDir);
end
srcScript = fullfile(pp.personalizeloc, '+fusion', 'fusion_training', dataset, ...
    'heatmap_finetuned', ['input2_' dateTime], ['train_' dateTime '.sh']);
dstScript = fullfile(scriptDir, ['train_cnn_' dateTime '.sh']);
% copyfile called without output arguments raises an error if the copy fails.
copyfile(srcScript, dstScript);
% Report success as a message rather than an error, so a successful run is
% not reported as a failure (e.g. non-zero exit status under matlab -batch).
fprintf('Run the generated .sh file in terminal to train the patient CNN model:\n  %s\n', dstScript);