-
Notifications
You must be signed in to change notification settings - Fork 184
/
Copy pathmem_moco.py
142 lines (120 loc) · 4.38 KB
/
mem_moco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import torch.nn as nn
import torch.nn.functional as F
class BaseMoCo(nn.Module):
    """Base class for MoCo-style memory caches.

    Stores the queue length ``K``, the softmax temperature ``T`` and the
    write pointer ``index``, and provides the helpers subclasses use to
    overwrite queue rows and to build contrastive logits.
    """

    def __init__(self, K=65536, T=0.07):
        """
        Args:
            K: number of rows in the memory queue
            T: temperature the logits are divided by
        """
        super().__init__()
        self.K = K
        self.T = T
        # Next queue row to overwrite. This is a plain Python attribute
        # (not a registered buffer), so it is not stored in ``state_dict``.
        self.index = 0

    def _update_pointer(self, bsz):
        """Advance the write pointer by ``bsz`` rows, wrapping at ``K``."""
        self.index = (self.index + bsz) % self.K

    def _update_memory(self, k, queue):
        """Overwrite queue rows, starting at ``self.index``, with ``k``.

        The write wraps around modulo ``K``. The pointer itself is not
        moved here; callers advance it with ``_update_pointer``.

        Args:
            k: key feature; one queue row is written per row of ``k``
            queue: memory buffer, modified in place
        """
        with torch.no_grad():
            num_neg = k.shape[0]
            # Build the indices on the queue's own device instead of a
            # hard-coded ``.cuda()`` so the cache also works on CPU and on
            # whichever GPU the buffer actually lives on.
            out_ids = torch.arange(num_neg, device=queue.device)
            out_ids = torch.fmod(out_ids + self.index, self.K).long()
            queue.index_copy_(0, out_ids, k)

    def _compute_logit(self, q, k, queue):
        """Build temperature-scaled logits with the positive in column 0.

        Args:
            q: query/anchor feature, 2-D with one row per sample
            k: key feature, one row per sample (paired with ``q`` row-wise)
            queue: memory buffer, 2-D; every row acts as a negative

        Returns:
            Tensor of shape ``(bsz, 1 + queue.shape[0])``: column 0 is the
            q.k dot product, the remaining columns are the dot products of
            ``q`` with each queue row, all divided by ``self.T``.
        """
        bsz = q.shape[0]

        # pos logit: row-wise dot product between q and k
        pos = torch.bmm(q.view(bsz, 1, -1), k.view(bsz, -1, 1))
        pos = pos.view(bsz, 1)

        # neg logit: dot product of every query with every queue row
        neg = torch.mm(queue, q.transpose(1, 0))
        neg = neg.transpose(0, 1)

        out = torch.cat((pos, neg), dim=1)
        out = torch.div(out, self.T)
        # No ``squeeze()`` here: it would drop the batch dimension when
        # bsz == 1 and the logits would no longer line up with the
        # per-sample labels.
        return out.contiguous()
class RGBMoCo(BaseMoCo):
    """Single Modal (e.g., RGB) MoCo-style cache"""

    def __init__(self, n_dim, K=65536, T=0.07):
        """
        Args:
            n_dim: feature dimension of each queue entry
            K: number of rows in the memory queue
            T: temperature the logits are divided by
        """
        super().__init__(K, T)
        # create memory queue: random rows, normalized with F.normalize
        self.register_buffer('memory', torch.randn(K, n_dim))
        self.memory = F.normalize(self.memory)

    def forward(self, q, k, q_jig=None, all_k=None):
        """Compute contrastive logits, then push the new keys into memory.

        Args:
            q: query on current node
            k: key on current node (detached before use)
            q_jig: jigsaw query
            all_k: gather of keys across nodes; otherwise use k

        Returns:
            ``(logits, labels)``, or ``(logits, logits_jig, labels)`` when
            ``q_jig`` is given. ``labels`` is an all-zero long tensor with
            one entry per row of ``q``, on the same device as ``q``.
        """
        bsz = q.size(0)
        k = k.detach()

        # compute logit against a snapshot of the queue, so the update
        # below cannot affect the logits of this step
        queue = self.memory.clone().detach()
        logits = self._compute_logit(q, k, queue)
        logits_jig = None
        if q_jig is not None:
            logits_jig = self._compute_logit(q_jig, k, queue)

        # set label: allocated on q's device rather than a hard-coded
        # ``.cuda()`` so the module also runs on CPU / non-default GPUs
        labels = torch.zeros(bsz, dtype=torch.long, device=q.device)

        # update memory
        all_k = all_k if all_k is not None else k
        self._update_memory(all_k, self.memory)
        self._update_pointer(all_k.size(0))

        if q_jig is not None:
            return logits, logits_jig, labels
        return logits, labels
class CMCMoCo(BaseMoCo):
    """MoCo-style memory for two modalities, e.g. in CMC"""

    def __init__(self, n_dim, K=65536, T=0.07):
        """
        Args:
            n_dim: feature dimension of each queue entry
            K: number of rows in each memory queue
            T: temperature the logits are divided by
        """
        super().__init__(K, T)
        # create memory queue: one per modality, random rows normalized
        # with F.normalize
        self.register_buffer('memory_1', torch.randn(K, n_dim))
        self.register_buffer('memory_2', torch.randn(K, n_dim))
        self.memory_1 = F.normalize(self.memory_1)
        self.memory_2 = F.normalize(self.memory_2)

    def forward(self, q1, k1, q2, k2,
                q1_jig=None, q2_jig=None,
                all_k1=None, all_k2=None):
        """Compute cross-modal logits, then push the new keys into memory.

        Queries of one modality are contrasted against the keys and the
        queue of the other modality.

        Args:
            q1: q of modal 1
            k1: k of modal 1 (detached before use)
            q2: q of modal 2
            k2: k of modal 2 (detached before use)
            q1_jig: q jig of modal 1
            q2_jig: q jig of modal 2
            all_k1: gather of k1 across nodes; otherwise use k1
            all_k2: gather of k2 across nodes; otherwise use k2

        Returns:
            ``(logits1, logits2, labels)``, or
            ``(logits1, logits2, logits1_jig, logits2_jig, labels)`` when
            BOTH ``q1_jig`` and ``q2_jig`` are given (a single jigsaw query
            on its own is ignored). ``labels`` is an all-zero long tensor
            with one entry per row of ``q1``, on the same device as ``q1``.
        """
        bsz = q1.size(0)
        k1 = k1.detach()
        k2 = k2.detach()
        use_jig = (q1_jig is not None) and (q2_jig is not None)

        # compute logit against snapshots of the queues, so the updates
        # below cannot affect the logits of this step
        queue1 = self.memory_1.clone().detach()
        queue2 = self.memory_2.clone().detach()
        logits1 = self._compute_logit(q1, k2, queue2)
        logits2 = self._compute_logit(q2, k1, queue1)
        logits1_jig = logits2_jig = None
        if use_jig:
            logits1_jig = self._compute_logit(q1_jig, k2, queue2)
            logits2_jig = self._compute_logit(q2_jig, k1, queue1)

        # set label: allocated on q1's device rather than a hard-coded
        # ``.cuda()`` so the module also runs on CPU / non-default GPUs
        labels = torch.zeros(bsz, dtype=torch.long, device=q1.device)

        # update memory; both queues share one write pointer, so they must
        # receive the same number of rows
        all_k1 = all_k1 if all_k1 is not None else k1
        all_k2 = all_k2 if all_k2 is not None else k2
        assert all_k1.size(0) == all_k2.size(0), \
            'all_k1 and all_k2 must contain the same number of keys'
        self._update_memory(all_k1, self.memory_1)
        self._update_memory(all_k2, self.memory_2)
        self._update_pointer(all_k1.size(0))

        if use_jig:
            return logits1, logits2, logits1_jig, logits2_jig, labels
        return logits1, logits2, labels