
Commit 045f35a: Initial commit

20 files changed: +3373 additions, -0 deletions

Readme.md

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
# Introduction

This repository is the official implementation of our paper "[Robust Domain Misinformation Detection via Multi-modal Feature Alignment](https://ieeexplore.ieee.org/abstract/document/10288548/)". The contributions of this work can be summarized as follows:

1. A unified framework that tackles both domain generalization (target-domain data is unavailable) and domain adaptation (target-domain data is available). This is necessary because obtaining sufficient unlabeled target-domain data at an early stage of misinformation dissemination is difficult.
2. Inter-domain and cross-modality alignment modules that reduce the domain shift and the modality gap. These modules aim to learn rich features for misinformation detection. Both modules are plug-and-play and can potentially be applied to other multi-modal tasks.

Additionally, we believe that the multimodal generalization algorithms proposed in our work can be applied to other multimodal tasks. If you have any questions about this paper, please do not hesitate to ask.

# To run our code

1. Download the dataset and pretrained models from OneDrive and unzip them in the project directory.

2. `drive_outmodel.py` is the main file that drives our algorithms. It uses the comet_ml package to manage ML experiments; either remove the Comet-related code or fill in your `api_key` and the other parameters in the following snippet in that file:

```python
from comet_ml import Experiment

experiment = Experiment(
    api_key="",       # your Comet API key
    project_name="",  # your Comet project name
    workspace="",     # your Comet workspace
)
```

3. Multimodal JMMD, devised in our work for multimodal generalization tasks, captures cross-modal correlations among multiple modalities with theoretical guarantees. For a more robust implementation, we advise using the MMD implementation in DomainBed, which fixes the kernel parameters so that only the weight $\lambda_1$ of the JMMD loss needs tuning; see the sketch below. Otherwise, you can use our implementation and set the kernel manually.
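
For reference, here is a minimal sketch of this idea (not the repository's exact code): each modality gets a sum of Gaussian kernels over a fixed bandwidth grid, in the spirit of DomainBed's fixed-kernel MMD, and the per-modality kernel matrices are multiplied elementwise to form the joint kernel of JMMD. All tensor names, shapes, and the bandwidth grid are illustrative assumptions.

```python
import torch

def gaussian_kernel(x, y, gammas=(0.001, 0.01, 0.1, 1, 10, 100, 1000)):
    # Sum of RBF kernels with fixed bandwidths, so no kernel parameter is tuned.
    d = torch.cdist(x, y) ** 2                     # pairwise squared distances, (n, m)
    return sum(torch.exp(-g * d) for g in gammas)  # (n, m) kernel matrix

def jmmd(src_feats, tgt_feats):
    # Joint MMD: multiply per-modality kernel matrices elementwise,
    # then form the usual squared-MMD estimate.
    k_ss = k_tt = k_st = 1.0
    for xs, xt in zip(src_feats, tgt_feats):
        k_ss = k_ss * gaussian_kernel(xs, xs)
        k_tt = k_tt * gaussian_kernel(xt, xt)
        k_st = k_st * gaussian_kernel(xs, xt)
    return k_ss.mean() + k_tt.mean() - 2 * k_st.mean()

# Illustrative: text and image features for source and target batches.
text_s, img_s = torch.randn(32, 256), torch.randn(32, 256)
text_t, img_t = torch.randn(32, 256), torch.randn(32, 256)
lambda_1 = 1.0  # the only hyperparameter left to tune once kernels are fixed
loss = lambda_1 * jmmd([text_s, img_s], [text_t, img_t])
```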

4. Finally, run the code as follows:

```
sh multi_out_model.sh
```

# Citation

If you find this repository helpful, please cite our paper:

```
@ARTICLE{10288548,
  author={Liu, Hui and Wang, Wenya and Sun, Hao and Rocha, Anderson and Li, Haoliang},
  journal={IEEE Transactions on Information Forensics and Security},
  title={Robust Domain Misinformation Detection via Multi-modal Feature Alignment},
  year={2023},
  volume={},
  number={},
  pages={1-1},
  doi={10.1109/TIFS.2023.3326368}}
```

If you are interested in multimodal misinformation detection, another of our papers on this task may also help: https://arxiv.org/abs/2305.05964. Although it was accepted to Findings of ACL, it received three strong accepts :), so it can serve as a good reference.

```
@inproceedings{DBLP:conf/acl/LiuWL23,
  author    = {Hui Liu and
               Wenya Wang and
               Haoliang Li},
  editor    = {Anna Rogers and
               Jordan L. Boyd{-}Graber and
               Naoaki Okazaki},
  title     = {Interpretable Multimodal Misinformation Detection with Logic Reasoning},
  booktitle = {Findings of the Association for Computational Linguistics: {ACL} 2023,
               Toronto, Canada, July 9-14, 2023},
  pages     = {9781--9796},
  publisher = {Association for Computational Linguistics},
  year      = {2023},
  url       = {https://doi.org/10.18653/v1/2023.findings-acl.620},
  doi       = {10.18653/V1/2023.FINDINGS-ACL.620},
  timestamp = {Thu, 10 Aug 2023 12:35:42 +0200},
  biburl    = {https://dblp.org/rec/conf/acl/LiuWL23.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

baseline/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from .download_pretrain import *
from .generate_vocab import *
from .mvae import *
from .textcnn import *
from .da_baseline import *
from .dg_baseline import *

baseline/download_pretrain.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Download roberta-base from the Hugging Face hub.
model_name = "roberta-base"
model_path = "./pretrain_model/roberta"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
config = AutoConfig.from_pretrained(model_name)

# Cache all three artifacts locally so they can be loaded offline later.
tokenizer.save_pretrained(model_path)
model.save_pretrained(model_path)
config.save_pretrained(model_path)
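
After this script runs, the cached copies can be loaded from the local directory instead of the hub; a minimal usage sketch:

```python
from transformers import AutoModel, AutoTokenizer

# Load the locally cached roberta-base saved by the script above.
tokenizer = AutoTokenizer.from_pretrained("./pretrain_model/roberta")
model = AutoModel.from_pretrained("./pretrain_model/roberta")
```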

baseline/generate_vocab.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
import torch
from utils import PhemeSet, TwitterSet
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

"""
Build the vocabulary for both datasets (the Twitter vocabulary here;
the Pheme lines are commented out).
"""


if __name__ == '__main__':
    twitter_set = TwitterSet(json_path="../final_twitter.json", img_path="../twitter/images",
                             type=0, events=["sandy", "boston", "sochi", "malaysia"],
                             visual_type='resnet', stage='train')

    twitter_vocab_path = "../vocab/twitter_vocab.pt"
    pheme_vocab_path = "../vocab/pheme_vocab.pt"

    print(twitter_set[0][0])
    tokenizer = get_tokenizer("spacy")
    # Tokenize every training caption, then build a frequency-filtered vocabulary.
    lines = []
    for i in range(len(twitter_set)):
        lines.append(tokenizer(twitter_set[i][0].strip()))
    vocab = build_vocab_from_iterator(iter(lines), specials=["<unk>", '<pad>'], min_freq=5)
    vocab.set_default_index(vocab['<unk>'])
    print(len(vocab))
    # torch.save(vocab, pheme_vocab_path)
    torch.save(vocab, twitter_vocab_path)
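
Downstream, the saved vocabulary can be loaded and used to numericalize tokenized text; a minimal sketch, assuming a torchtext version where `Vocab` objects are callable (the sentence is illustrative):

```python
import torch
from torchtext.data.utils import get_tokenizer

# Load the saved vocabulary and map tokens to indices; OOV tokens fall back to <unk>.
vocab = torch.load("../vocab/twitter_vocab.pt")
tokenizer = get_tokenizer("spacy")
token_ids = vocab(tokenizer("breaking news from boston"))
```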

baseline/textcnn.py

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
import torch.nn as nn
import torch.nn.functional as F
import os
import torch
from transformers import AutoConfig, AutoModel

"""
TextCNN serves either as a uni-modal classifier or as a textual feature
extractor, depending on the num_classes parameter.
"""


class TextCNN(nn.Module):
    def __init__(self, kernel_sizes, num_filters, num_classes, d_prob, mode='rand', dataset_name="Pheme"):
        """
        :param kernel_sizes: convolution kernel sizes, one conv branch per size
        :param num_filters: number of filters per kernel size
        :param num_classes: output dimension (classes, or feature size when used as an extractor)
        :param d_prob: dropout probability
        :param mode: one of rand, roberta-yes, roberta-non, bert-yes, bert-non
        :param dataset_name: "Pheme" or "Twitter"; selects the vocabulary in rand mode
        """
        super(TextCNN, self).__init__()
        self.kernel_sizes = kernel_sizes
        self.num_filters = num_filters
        self.num_classes = num_classes
        self.d_prob = d_prob
        # One of: rand, roberta-yes, roberta-non, bert-yes, bert-non.
        self.mode = mode
        self.vocab = None
        self.dataset_name = dataset_name
        self.vocab_size = 1000
        self.embedding_dim = 100
        self.embedding = None
        # Only the rand mode needs a padding_idx; the BERT/RoBERTa embeddings do not.
        self.load_embeddings()
        self.conv = nn.ModuleList([nn.Conv1d(in_channels=self.embedding_dim,
                                             out_channels=num_filters,
                                             kernel_size=k, stride=1) for k in kernel_sizes])
        self.dropout = nn.Dropout(d_prob)
        self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)

    def forward(self, x):
        # x: (batch, seq_len) token ids; embed, then move channels first for Conv1d:
        # (batch, seq_len, embed_dim) -> (batch, embed_dim, seq_len)
        x = self.embedding(x).transpose(1, 2)
        x = [F.relu(conv(x)) for conv in self.conv]
        # Max-pool each conv branch over time, then concatenate the branches.
        x = [F.max_pool1d(c, c.size(-1)).squeeze(dim=-1) for c in x]
        x = torch.cat(x, dim=1)
        x = self.fc(self.dropout(x))
        return x.squeeze()

    def load_embeddings(self):
        if self.mode == 'rand':
            if self.dataset_name == "Pheme":
                path_saved = "/data/sunhao/robustfakenews/dataset/vocab/pheme_vocab.pt"
            elif self.dataset_name == "Twitter":
                path_saved = "/data/sunhao/robustfakenews/dataset/vocab/twitter_vocab.pt"
            else:
                raise ValueError('With randomly initialized embeddings, dataset_name must be "Pheme" or "Twitter".')
            vocab = torch.load(path_saved)
            self.vocab_size = len(vocab)
            self.embedding_dim = 100
            self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim, padding_idx=vocab['<pad>'])
            self.embedding.weight.requires_grad = True
            del vocab
            print('Randomly initialized embeddings are used.')
        else:
            mode = self.mode.split("-")
            assert len(mode) == 2
            path_saved = "/data/sunhao/robustfakenews/pretrain_model"
            if mode[0] == 'roberta':
                config = AutoConfig.from_pretrained(os.path.join(path_saved, "roberta"))
                roberta = AutoModel.from_pretrained(os.path.join(path_saved, "roberta"), config=config)
                weight = roberta.get_input_embeddings().weight
                self.vocab_size, self.embedding_dim = weight.shape
                # from_pretrained is a classmethod; whether to freeze is decided below from mode[1].
                self.embedding = nn.Embedding.from_pretrained(weight, freeze=False)
                del roberta, config, weight
            elif mode[0] == 'bert':
                config = AutoConfig.from_pretrained(os.path.join(path_saved, "bert"))
                bert = AutoModel.from_pretrained(os.path.join(path_saved, "bert"), config=config)
                weight = bert.get_input_embeddings().weight
                self.vocab_size, self.embedding_dim = weight.shape
                self.embedding = nn.Embedding.from_pretrained(weight, freeze=False)
                del bert, config, weight
            else:
                raise ValueError('Unexpected value of mode. Choose from rand, roberta-yes, roberta-non, bert-yes, bert-non.')

            if mode[1] == 'non':
                # Set requires_grad on the parameter itself, not on .data, so freezing takes effect.
                self.embedding.weight.requires_grad = False
                print('Loaded pretrained embeddings, weights are not trainable.')
            elif mode[1] == 'yes':
                self.embedding.weight.requires_grad = True
                print('Loaded pretrained embeddings, weights are trainable.')
            else:
                raise ValueError('Unexpected value of mode[1]. Choose "yes" or "non".')
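
For orientation, a hypothetical instantiation of the two roles the module docstring describes (all hyperparameter values are illustrative, and the pretrained paths above must exist):

```python
# As a binary classifier over frozen RoBERTa input embeddings.
clf = TextCNN(kernel_sizes=[3, 4, 5], num_filters=100, num_classes=2,
              d_prob=0.5, mode='roberta-non')

# As a textual feature extractor: num_classes doubles as the output feature dimension.
extractor = TextCNN(kernel_sizes=[3, 4, 5], num_filters=100, num_classes=256,
                    d_prob=0.5, mode='rand', dataset_name='Twitter')
```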
