forked from lightaime/deep_gcns_torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
executable file
·67 lines (51 loc) · 2.18 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import __init__
from tqdm import tqdm
import numpy as np
import torch
import torch_geometric.datasets as GeoData
from torch_geometric.data import DenseDataLoader
import torch_geometric.transforms as T
from config import OptInit
from architecture import DenseDeepGCN
from utils.ckpt_util import load_pretrained_models
import logging
def main():
    """Build the S3DIS test split, restore a pretrained DenseDeepGCN and evaluate it.

    Options are obtained from ``OptInit().get_args()``. This function writes
    ``n_classes`` onto the options object (one less than the dataset's class
    count when ``no_clutter`` is set) and stores the ``best_value`` and
    ``epoch`` values returned by ``load_pretrained_models`` on it as well,
    then hands everything to ``test``.
    """
    args = OptInit().get_args()

    logging.info('===> Creating dataloader...')
    loader = DenseDataLoader(
        GeoData.S3DIS(args.data_dir, args.area, train=False,
                      pre_transform=T.NormalizeScale()),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=0,
    )
    # With ``no_clutter`` set, one class is removed from the class count.
    dataset_classes = loader.dataset.num_classes
    args.n_classes = dataset_classes - 1 if args.no_clutter else dataset_classes

    logging.info('===> Loading the network ...')
    net = DenseDeepGCN(args).to(args.device)
    net, args.best_value, args.epoch = load_pretrained_models(
        net, args.pretrained_model, args.phase)

    logging.info('===> Start Evaluation ...')
    test(net, loader, args)
def _intersection_and_union(pred, target, n_classes):
    """Count per-class intersection and union sizes for one batch.

    Args:
        pred: numpy array of predicted class ids.
        target: numpy array of ground-truth class ids, compared elementwise
            with ``pred`` (presumably the same shape — verify against caller).
        n_classes: number of classes scored; ids ``0 .. n_classes - 1``.

    Returns:
        ``(intersection, union)``: two int64 arrays of length ``n_classes``.
        ``intersection[c]`` counts positions where both ``pred`` and ``target``
        equal ``c``; ``union[c]`` counts positions where either equals ``c``.
        Label values outside ``range(n_classes)`` never match a scored class.
    """
    intersection = np.zeros(n_classes, dtype=np.int64)
    union = np.zeros(n_classes, dtype=np.int64)
    for cls in range(n_classes):
        pred_mask = pred == cls
        gt_mask = target == cls
        intersection[cls] = np.count_nonzero(np.logical_and(pred_mask, gt_mask))
        union[cls] = np.count_nonzero(np.logical_or(pred_mask, gt_mask))
    return intersection, union


def test(model, loader, opt):
    """Evaluate ``model`` on ``loader`` and log per-class and mean IoU.

    Intersection and union counts are summed over every batch before
    dividing, so each class IoU is computed over the whole split rather than
    averaged per batch.

    Args:
        model: network called as ``model(inputs)``. Predictions are taken as
            the index of the maximum over dim 1 of its output.
        loader: iterable of batches that expose ``pos``, ``x``, ``y`` and a
            ``.to(device)`` method.
        opt: options object; ``device`` and ``n_classes`` are read.

    Returns:
        ``numpy.ndarray`` of shape ``(opt.n_classes,)`` holding per-class IoU.
        A class that appears in neither predictions nor ground truth scores 1.
    """
    n_classes = opt.n_classes
    # Running int64 totals: exact counts, O(n_classes) memory, and no need
    # for ``len(loader)`` up front.
    total_inter = np.zeros(n_classes, dtype=np.int64)
    total_union = np.zeros(n_classes, dtype=np.int64)

    model.eval()
    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(opt.device)
            # Concatenate coordinates and features along dim 1 after moving
            # the last axis to dim 1 and appending a singleton dim 3.
            # NOTE(review): assumes pos/x are (batch, points, channels) dense
            # tensors — confirm against DenseDataLoader output.
            inputs = torch.cat((data.pos.transpose(2, 1).unsqueeze(3),
                                data.x.transpose(2, 1).unsqueeze(3)), 1)
            out = model(inputs)
            pred = out.max(dim=1)[1]
            inter, union = _intersection_and_union(
                pred.cpu().numpy(), data.y.cpu().numpy(), n_classes)
            total_inter += inter
            total_union += union

    # Classes with an empty union keep the default IoU of 1; ``where`` skips
    # the 0/0 division that would otherwise raise a NumPy RuntimeWarning.
    ious = np.ones(n_classes, dtype=np.float64)
    np.divide(total_inter, total_union, out=ious, where=total_union > 0)

    for cl in range(n_classes):
        logging.info("===> mIOU for class {}: {}".format(cl, ious[cl]))
    logging.info("===> mIOU is {}".format(np.mean(ious)))
    return ious
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()