forked from DenisDsh/PyTorch-Deep-CORAL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path coral.py
37 lines (26 loc) · 923 Bytes
/
coral.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np
def coral(source, target):
    """Return the CORAL loss between two feature batches.

    The loss is the sum of squared element-wise differences (squared
    Frobenius distance) between the covariance matrices that
    ``compute_covariance`` produces for ``source`` and ``target``, divided by
    ``4 * d * d`` where ``d = source.size(1)``.

    NOTE(review): assumes ``source`` and ``target`` have the same number of
    feature columns -- TODO confirm at call sites; mismatched widths may
    broadcast or raise in the subtraction below.
    """
    dim = source.size(1)  # number of feature columns
    gap = compute_covariance(source) - compute_covariance(target)
    squared_frobenius = (gap * gap).sum()
    return squared_frobenius / (4 * dim * dim)
def compute_covariance(input_data):
    """Compute the unbiased covariance matrix of the input data.

    Args:
        input_data: 2-D tensor of shape ``(n, d)``. Each row is one sample
            (``n`` is the batch size) and each column is one feature.
            ``torch.mm`` below rejects non-2-D input.

    Returns:
        A ``(d, d)`` tensor ``(X - mean)^T (X - mean) / (n - 1)`` with the
        same dtype and on the same device as ``input_data``. At least two
        rows are needed for a finite result; with ``n == 1`` the division
        by zero yields NaN entries.
    """
    n = input_data.size(0)  # batch_size
    # Center every feature column first. This yields the textbook unbiased
    # estimator (equivalent to X^T X - n * mean^T mean) and avoids the
    # cancellation error of the expanded form. Taking the mean directly on
    # input_data also keeps its dtype and device, so no helper tensor has to
    # be created and moved.
    centered = input_data - input_data.mean(dim=0, keepdim=True)
    return torch.mm(centered.t(), centered) / (n - 1)