Skip to content

Commit d4a3484

Browse files
committed
initial upload
1 parent dc0a0c6 commit d4a3484

31 files changed

+1158
-0
lines changed

.gitignore

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
env/
12+
build/
13+
develop-eggs/
14+
dist/
15+
downloads/
16+
eggs/
17+
.eggs/
18+
lib/
19+
lib64/
20+
parts/
21+
sdist/
22+
var/
23+
wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
28+
# PyInstaller
29+
# Usually these files are written by a python script from a template
30+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
31+
*.manifest
32+
*.spec
33+
34+
# Installer logs
35+
pip-log.txt
36+
pip-delete-this-directory.txt
37+
38+
# Unit test / coverage reports
39+
htmlcov/
40+
.tox/
41+
.coverage
42+
.coverage.*
43+
.cache
44+
nosetests.xml
45+
coverage.xml
46+
*.cover
47+
.hypothesis/
48+
49+
# Translations
50+
*.mo
51+
*.pot
52+
53+
# Django stuff:
54+
*.log
55+
local_settings.py
56+
57+
# Flask stuff:
58+
instance/
59+
.webassets-cache
60+
61+
# Scrapy stuff:
62+
.scrapy
63+
64+
# Sphinx documentation
65+
docs/_build/
66+
67+
# PyBuilder
68+
target/
69+
70+
# Jupyter Notebook
71+
.ipynb_checkpoints
72+
73+
# pyenv
74+
.python-version
75+
76+
# celery beat schedule file
77+
celerybeat-schedule
78+
79+
# SageMath parsed files
80+
*.sage.py
81+
82+
# dotenv
83+
.env
84+
85+
# virtualenv
86+
.venv
87+
venv/
88+
ENV/
89+
90+
# Spyder project settings
91+
.spyderproject
92+
.spyproject
93+
94+
# Rope project settings
95+
.ropeproject
96+
97+
# mkdocs documentation
98+
/site
99+
100+
# mypy
101+
.mypy_cache/
102+
103+
# Pycharm
104+
.idea/
105+
106+
# Experiments
107+
experiments/*
108+
!experiments/.keep
109+
110+
# Pretrained Weights
111+
pretrained_weights/*
112+
!pretrained_weights/.keep
113+
114+
# Data
115+
data/*
116+
!data/.keep

__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__authors__ = ["Mo'men", "Hager"]

agents/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import os
2+
import sys
3+
4+
path = os.path.dirname(os.path.abspath(__file__))
5+
6+
for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']:
7+
mod = __import__('.'.join([__name__, py]), fromlist=[py])
8+
classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)]
9+
for cls in classes:
10+
setattr(sys.modules[__name__], cls.__name__, cls)

agents/base.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
"""
2+
The Base Agent class, where all other agents inherit from, that contains definitions for all the necessary functions
3+
"""
4+
import logging
5+
6+
7+
class BaseAgent:
8+
"""
9+
This base class will contain the base functions to be overloaded by any agent you will implement.
10+
"""
11+
12+
def __init__(self, config):
13+
self.config = config
14+
self.logger = logging.getLogger("Agent")
15+
16+
def load_checkpoint(self, file_name):
17+
"""
18+
Latest checkpoint loader
19+
:param file_name: name of the checkpoint file
20+
:return:
21+
"""
22+
raise NotImplementedError
23+
24+
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
25+
"""
26+
Checkpoint saver
27+
:param file_name: name of the checkpoint file
28+
:param is_best: boolean flag to indicate whether current checkpoint's metric is the best so far
29+
:return:
30+
"""
31+
raise NotImplementedError
32+
33+
def run(self):
34+
"""
35+
The main operator
36+
:return:
37+
"""
38+
raise NotImplementedError
39+
40+
def train(self):
41+
"""
42+
Main training loop
43+
:return:
44+
"""
45+
raise NotImplementedError
46+
47+
def train_one_epoch(self):
48+
"""
49+
One epoch of training
50+
:return:
51+
"""
52+
raise NotImplementedError
53+
54+
def validate(self):
55+
"""
56+
One cycle of model validation
57+
:return:
58+
"""
59+
raise NotImplementedError
60+
61+
def finalize(self):
62+
"""
63+
Finalizes all the operations of the 2 Main classes of the process, the operator and the data loader
64+
:return:
65+
"""
66+
raise NotImplementedError

agents/bert.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
from tqdm import tqdm
3+
4+
import random
5+
6+
import torch
7+
from torch import nn
8+
from torch.backends import cudnn
9+
from torch.autograd import Variable
10+
import torchvision.utils as vutils
11+
12+
from agent.base import BaseAgent
13+
14+
cudnn.benchmark = True
15+
16+
class BERTAgent(BaseAgent):
17+
18+
def __init__(self, config):
19+
super().__init__(config)
20+
21+
def load_checkpoint(self, file_name):
22+
"""
23+
Latest checkpoint loader
24+
:param file_name: name of the checkpoint file
25+
:return:
26+
"""
27+
raise NotImplementedError
28+
29+
def save_checkpoint(self, file_name="checkpoint.pth.tar", is_best=0):
30+
"""
31+
Checkpoint saver
32+
:param file_name: name of the checkpoint file
33+
:param is_best: boolean flag to indicate whether current checkpoint's metric is the best so far
34+
:return:
35+
"""
36+
raise NotImplementedError
37+
38+
def run(self):
39+
"""
40+
The main operator
41+
:return:
42+
"""
43+
raise NotImplementedError
44+
45+
def train(self):
46+
"""
47+
Main training loop
48+
:return:
49+
"""
50+
raise NotImplementedError
51+
52+
def train_one_epoch(self):
53+
"""
54+
One epoch of training
55+
:return:
56+
"""
57+
raise NotImplementedError
58+
59+
def validate(self):
60+
"""
61+
One cycle of model validation
62+
:return:
63+
"""
64+
raise NotImplementedError
65+
66+
def finalize(self):
67+
"""
68+
Finalizes all the operations of the 2 Main classes of the process, the operator and the data loader
69+
:return:
70+
"""
71+
raise NotImplementedError

data/.keep

Whitespace-only changes.

datasets/SentencePair.py

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import numpy as numpy
2+
import sentencepiece as spm
3+
4+
import torch
5+
6+
from torch.utils.data import DataLoader, TensorDataset, Dataset
7+
8+
class SentencePairDataLoader:
9+
def __init__(self, config):
10+
self.config = config
11+
12+
if config.data_mode == "corpus":
13+
self.f = open(self.config.data_folder, "r", encoding="utf-8", errors='ignore')
14+
self.tokenizer = self.config.tokenizer
15+
if self.tokenizer is 'bpe':
16+
prefix = bpe_model
17+
cmd = '--input={} --vocab_size={} --model_prefix={}'
18+
cmd = cmd.format(self.config.data_folder, vocab_size, prefix)
19+
try:
20+
spm.SentencePieceTrainer.Train(cmd)
21+
self.sp = spm.SenetencePieceProcesseor()
22+
sp.Load('{}.model'.format(prefix))
23+
except Exception:
24+
raise
25+
# bpe = self.sp.EncodeAsPieces(line) # to list
26+
else:
27+
raise NotImplementedError
28+
self.max_len = max_len
29+
self.batch_size = self.config.batch_size
30+
else:
31+
raise NotImplementedError
32+
33+
34+
35+

experiments/.keep

Whitespace-only changes.

graphs/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import os
2+
import sys
3+
4+
path = os.path.dirname(os.path.abspath(__file__))
5+
6+
for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']:
7+
mod = __import__('.'.join([__name__, py]), fromlist=[py])
8+
classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)]
9+
for cls in classes:
10+
setattr(sys.modules[__name__], cls.__name__, cls)

graphs/losses/__init__.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import os
2+
import sys
3+
4+
path = os.path.dirname(os.path.abspath(__file__))
5+
6+
for py in [f[:-3] for f in os.listdir(path) if f.endswith('.py') and f != '__init__.py']:
7+
mod = __import__('.'.join([__name__, py]), fromlist=[py])
8+
classes = [getattr(mod, x) for x in dir(mod) if isinstance(getattr(mod, x), type)]
9+
for cls in classes:
10+
setattr(sys.modules[__name__], cls.__name__, cls)

graphs/losses/bce.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
Binary Cross Entropy for DCGAN
3+
"""
4+
5+
import torch
6+
import torch.nn as nn
7+
8+
9+
class BinaryCrossEntropy(nn.Module):
10+
def __init__(self):
11+
super().__init__()
12+
self.loss = nn.BCELoss()
13+
14+
def forward(self, logits, labels):
15+
loss = self.loss(logits, labels)
16+
return loss

graphs/losses/cross_entropy.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""
2+
Cross Entropy 2D for CondenseNet
3+
"""
4+
5+
import torch
6+
import torch.nn.functional as F
7+
import torch.nn as nn
8+
9+
import numpy as np
10+
11+
12+
class CrossEntropyLoss(nn.Module):
13+
def __init__(self, config=None):
14+
super(CrossEntropyLoss, self).__init__()
15+
if config == None:
16+
self.loss = nn.CrossEntropyLoss()
17+
else:
18+
class_weights = np.load(config.class_weights)
19+
self.loss = nn.CrossEntropyLoss(ignore_index=config.ignore_index,
20+
weight=torch.from_numpy(class_weights.astype(np.float32)),
21+
size_average=True, reduce=True)
22+
23+
def forward(self, inputs, targets):
24+
return self.loss(inputs, targets)

0 commit comments

Comments
 (0)