# Copyright 2019 Image Analysis Lab, German Center for Neurodegenerative Diseases (DZNE), Bonn
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# IMPORTS
import os
import time
from typing import Optional

import numpy as np
import torch
import yacs.config
from numpy import typing as npt
from pandas import DataFrame
from torch.utils.data import DataLoader
from torchvision import transforms

from FastSurferCNN.data_loader.augmentation import ToTensorTest
from FastSurferCNN.data_loader.data_utils import map_prediction_sagittal2full
from FastSurferCNN.data_loader.dataset import MultiScaleOrigDataThickSlices
from FastSurferCNN.models.networks import build_model
from FastSurferCNN.utils import logging

logger = logging.getLogger(__name__)

class Inference:
    """Model evaluation class to run inference using FastSurferCNN.

    Attributes
    ----------
    permute_order : dict[str, tuple[int, int, int, int]]
        Permutation order for axial, coronal, and sagittal.
    device : Optional[torch.device]
        Device specification for distributed computation usage.
    default_device : torch.device
        Default device specification for distributed computation usage.
    cfg : yacs.config.CfgNode
        Configuration node.
    model_parallel : bool
        Option for parallel run.
    model : torch.nn.Module
        Neural network model.
    model_name : str
        Name of the model.
    alpha : dict[str, float]
        Alpha values for the different planes.
    post_prediction_mapping_hook
        Hook for post-prediction mapping.

    Methods
    -------
    setup_model
        Set up the initial model.
    set_cfg
        Set the configuration node.
    to
        Move and/or cast the parameters and buffers.
    load_checkpoint
        Load the checkpoint.
    eval
        Evaluate predictions.
    run
        Run the loaded model.
    """

    permute_order: dict[str, tuple[int, int, int, int]]
    device: torch.device | None
    default_device: torch.device

    def __init__(
        self,
        cfg: yacs.config.CfgNode,
        device: torch.device,
        ckpt: str = "",
        lut: None | str | np.ndarray | DataFrame = None,
    ):
        """
        Construct Inference object.

        Parameters
        ----------
        cfg : yacs.config.CfgNode
            Configuration node.
        device : torch.device
            Device specification for distributed computation usage.
        ckpt : str
            String or os.PathLike object containing the path to the checkpoint file (Default value = "").
        lut : str, np.ndarray, DataFrame, optional
            Lookup table for mapping (Default value = None).
        """
        # Set random seed from configs.
        np.random.seed(cfg.RNG_SEED)
        torch.manual_seed(cfg.RNG_SEED)
        self.cfg = cfg

        # Switch on denormal flushing for faster CPU processing
        # seems to have less of an effect on VINN than old CNN
        torch.set_flush_denormal(True)

        self.default_device = device
        # Options for parallel run
        self.model_parallel = (
            torch.cuda.device_count() > 1
            and self.default_device.type == "cuda"
            and self.default_device.index is None
        )

        # Initial model setup
        self.model = None
        self._model_not_init = None
        self.setup_model(cfg, device=self.default_device)
        self.model_name = self.cfg.MODEL.MODEL_NAME

        self.alpha = {"sagittal": 0.2}
        self.permute_order = {
            "axial": (3, 0, 2, 1),
            "coronal": (2, 3, 0, 1),
            "sagittal": (0, 3, 2, 1),
        }
        self.lut = lut

        # Initial checkpoint loading
        if ckpt:
            # this also moves the model to the device (see load_checkpoint)
            self.load_checkpoint(ckpt)
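
    # Note on `permute_order` (illustrative comment, not part of the original file):
    # the model outputs logits of shape (N, C, H, W) per batch of thick slices.
    # The permutation above rearranges each prediction so the batch/slice axis
    # lands on the volume axis of the current plane and the class channel comes
    # last, matching the (x, y, z, class) layout of the aggregation tensor that
    # `eval` indexes with `start_index:end_index`.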

    def setup_model(self, cfg=None, device: torch.device = None):
        """
        Set up the model.

        Parameters
        ----------
        cfg : yacs.config.CfgNode
            Configuration node (Default value = None).
        device : torch.device
            Device specification for distributed computation usage (Default value = None).
        """
        if cfg is not None:
            self.cfg = cfg
        if device is None:
            device = self.default_device

        # Set up model
        self._model_not_init = build_model(
            self.cfg
        )  # ~ model = FastSurferCNN(params_network)
        self._model_not_init.to(device)
        self.device = None

    def set_cfg(self, cfg: yacs.config.CfgNode):
        """
        Set the configuration node.

        Parameters
        ----------
        cfg : yacs.config.CfgNode
            Configuration node.
        """
        self.cfg = cfg

    def to(self, device: torch.device | None = None):
        """
        Move and/or cast the parameters and buffers.

        Parameters
        ----------
        device : Optional[torch.device]
            The desired device of the parameters and buffers in this module (Default value = None).
        """
        if self.model_parallel:
            raise RuntimeError(
                "Moving the model to other devices is not supported for multi-device models."
            )
        _device = self.default_device if device is None else device
        self.device = _device
        self.model.to(device=_device)

    def load_checkpoint(self, ckpt: str | os.PathLike):
        """
        Load the checkpoint and set device and model.

        Parameters
        ----------
        ckpt : Union[str, os.PathLike]
            String or os.PathLike object containing the path to the checkpoint file.
        """
        logger.info(f"Loading checkpoint {ckpt}")
        self.model = self._model_not_init

        # If device is None, the model has never been loaded (still in random initial configuration)
        if self.device is None:
            self.device = self.default_device

        # workaround for mps (directly loading to map_location=mps results in zeros)
        device = self.device
        if self.device.type == "mps":
            self.model.to("cpu")
            device = "cpu"
        else:
            # make sure the model is where it is supposed to be
            self.model.to(self.device)

        # WARNING: weights_only=False can cause unsafe code execution, but here the
        # checkpoint can be considered to be from a safe source
        model_state = torch.load(ckpt, map_location=device, weights_only=False)
        self.model.load_state_dict(model_state["model_state"])

        # workaround for mps (move the model back to mps)
        if self.device.type == "mps":
            self.model.to(self.device)

        if self.model_parallel:
            self.model = torch.nn.DataParallel(self.model)

    def get_modelname(self) -> str:
        """
        Return the model name.

        Returns
        -------
        str
            The name of the model.
        """
        return self.model_name

    def get_cfg(self) -> yacs.config.CfgNode:
        """
        Return the configurations.

        Returns
        -------
        yacs.config.CfgNode
            Configuration node.
        """
        return self.cfg

    def get_num_classes(self) -> int:
        """
        Return the number of classes.

        Returns
        -------
        int
            The number of classes.
        """
        return self.cfg.MODEL.NUM_CLASSES

    def get_plane(self) -> str:
        """
        Return the plane.

        Returns
        -------
        str
            The plane used in the model.
        """
        return self.cfg.DATA.PLANE

    def get_model_height(self) -> int:
        """
        Return the model height.

        Returns
        -------
        int
            The height of the model.
        """
        return self.cfg.MODEL.HEIGHT

    def get_model_width(self) -> int:
        """
        Return the model width.

        Returns
        -------
        int
            The width of the model.
        """
        return self.cfg.MODEL.WIDTH

    def get_max_size(self) -> int | tuple[int, int]:
        """
        Return the max size.

        Returns
        -------
        int | tuple[int, int]
            The maximum size, either a single value or a tuple (width, height).
        """
        if self.cfg.MODEL.OUT_TENSOR_WIDTH == self.cfg.MODEL.OUT_TENSOR_HEIGHT:
            return self.cfg.MODEL.OUT_TENSOR_WIDTH
        else:
            return self.cfg.MODEL.OUT_TENSOR_WIDTH, self.cfg.MODEL.OUT_TENSOR_HEIGHT

    def get_device(self) -> torch.device:
        """
        Return the device.

        Returns
        -------
        torch.device
            The device used for computation.
        """
        return self.device

    @torch.no_grad()
    def eval(
        self,
        init_pred: torch.Tensor,
        val_loader: DataLoader,
        *,
        out_scale: Optional = None,
        out: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Perform prediction and in-place aggregate the view into the prediction tensor.

        Parameters
        ----------
        init_pred : torch.Tensor
            Initial prediction.
        val_loader : DataLoader
            Validation loader.
        out_scale : Optional
            Output scale (Default value = None).
        out : Optional[torch.Tensor]
            Previous prediction tensor (Default value = None).

        Returns
        -------
        torch.Tensor
            Prediction probability tensor.
        """
        self.model.eval()
        # we should check here whether the DataLoader uses a Random or a SequentialSampler, but we cannot easily.
        if not isinstance(val_loader.sampler, torch.utils.data.SequentialSampler):
            logger.warning(
                "The validation loader does not seem to use the SequentialSampler. This might interfere with "
                "the assumed sorting of batches."
            )

        start_index = 0
        plane = self.cfg.DATA.PLANE
        index_of_current_plane = self.permute_order[plane].index(0)
        target_shape = init_pred.shape
        ii = [slice(None) for _ in range(4)]
        pred_ii = tuple(slice(i) for i in target_shape[:3])

        from tqdm import tqdm
        from tqdm.contrib.logging import logging_redirect_tqdm

        if out is None:
            out = init_pred.detach().clone()

        log_batch_idx = None
        with logging_redirect_tqdm():
            try:
                for batch_idx, batch in tqdm(
                    enumerate(val_loader), total=len(val_loader), unit="batch"
                ):
                    log_batch_idx = batch_idx
                    # move data to the model device
                    images, scale_factors = batch["image"].to(self.device), batch[
                        "scale_factor"
                    ].to(self.device)

                    # predict the current batch, outputs logits
                    pred = self.model(images, scale_factors, out_scale)
                    batch_size = pred.shape[0]
                    end_index = start_index + batch_size

                    # check if we need a special mapping (e.g. as for sagittal)
                    if self.get_plane() == "sagittal":
                        pred = map_prediction_sagittal2full(
                            pred, num_classes=self.get_num_classes(), lut=self.lut
                        )

                    # permute the prediction into the out slice order
                    pred = pred.permute(*self.permute_order[plane]).to(
                        out.device
                    )  # the to-operation is implicit
                    # cut prediction to the image size
                    pred = pred[pred_ii]

                    # add prediction logits into the output (same as multiplying probabilities)
                    ii[index_of_current_plane] = slice(start_index, end_index)
                    out[tuple(ii)].add_(pred, alpha=self.alpha.get(plane, 0.4))
                    start_index = end_index
            except:
                logger.exception(
                    f"Exception in batch {log_batch_idx} of {plane} inference."
                )
                raise
            else:
                logger.info(
                    f"Inference on {batch_idx + 1} batches for {plane} successful"
                )
        return out
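
    # Aggregation note (illustrative, not part of the original file): `eval`
    # accumulates raw logits with `out.add_(pred, alpha=...)`, using alpha=0.2
    # for the sagittal view and the default 0.4 for axial and coronal. After
    # running all three planes into the same buffer, each voxel/class entry is
    # effectively
    #
    #     out = 0.4 * axial_logits + 0.4 * coronal_logits + 0.2 * sagittal_logits
    #
    # i.e. a weighted view aggregation whose weights sum to 1.0.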

    @torch.no_grad()
    def run(
        self,
        init_pred: torch.Tensor,
        img_filename: str,
        orig_data: npt.NDArray,
        orig_zoom: npt.NDArray,
        out: torch.Tensor | None = None,
        out_res: int | None = None,
        batch_size: int | None = None,
    ) -> torch.Tensor:
        """
        Run the loaded model on the data (T1) from orig_data and
        img_filename (for messages only) with scale factors orig_zoom.

        Parameters
        ----------
        init_pred : torch.Tensor
            Initial prediction.
        img_filename : str
            Original image filename.
        orig_data : npt.NDArray
            Original image data.
        orig_zoom : npt.NDArray
            Original zoom.
        out : Optional[torch.Tensor]
            Updated output tensor (Default value = None).
        out_res : Optional[int]
            Output resolution (Default value = None).
        batch_size : Optional[int]
            Batch size (Default value = None).

        Returns
        -------
        torch.Tensor
            Prediction probability tensor.
        """
        # Set up DataLoader
        test_dataset = MultiScaleOrigDataThickSlices(
            orig_data,
            orig_zoom,
            self.cfg,
            transforms=transforms.Compose([ToTensorTest()]),
        )
        test_data_loader = DataLoader(
            dataset=test_dataset,
            shuffle=False,
            batch_size=self.cfg.TEST.BATCH_SIZE if batch_size is None else batch_size,
        )

        # Run evaluation
        start = time.time()
        out = self.eval(init_pred, test_data_loader, out=out, out_scale=out_res)
        time_delta = time.time() - start
        logger.info(
            f"{self.cfg.DATA.PLANE.capitalize()} inference on {img_filename} finished in "
            f"{time_delta:0.4f} seconds"
        )
        return out
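

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration; not part of the original module).
    # The config path, checkpoint name, and image file below are assumptions and must
    # be adapted; the config has to define the MODEL/DATA/TEST keys used above.
    import nibabel as nib

    cfg = yacs.config.CfgNode.load_cfg(open("config/FastSurferVINN_coronal.yaml"))  # assumed path
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inference = Inference(cfg, device=device, ckpt="checkpoints/coronal_best.pkl")  # assumed checkpoint

    # Load a conformed T1 volume with nibabel.
    img = nib.load("orig.mgz")  # assumed input file
    orig_data = np.asarray(img.dataobj)
    orig_zoom = np.asarray(img.header.get_zooms()[:3])

    # Aggregation buffer: the spatial grid of the input plus one channel per class.
    pred = torch.zeros(*orig_data.shape, inference.get_num_classes(), dtype=torch.float32)
    pred = inference.run(pred, "orig.mgz", orig_data, orig_zoom)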