-
Notifications
You must be signed in to change notification settings - Fork 3
/
Data.py
123 lines (89 loc) · 4.24 KB
/
Data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
############################################################################################
#
# Project: Breast Cancer AI Research Project
# Repository: Tensorflow Quantum IDC Classifier 2020
# Project: Leveraging Quantum MNIST to detect Invasive Ductal Carcinoma
#
# Author: Adam Milton-Barker (AdamMiltonBarker.com)
# Contributors:
# Title: Data Helper Class
# Description: Data functions for the Leveraging Quantum MNIST to detect Invasive Ductal
# Carcinoma QNN (Quantum Neural Network).
# License: MIT License
# Last Modified: 2020-04-16
#
############################################################################################
import collections, os, random
import numpy as np
import tensorflow as tf
from random import seed as rseed
from numpy.random import seed as nseed
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.utils import shuffle as skshuffle
import matplotlib.pyplot as plt
from Classes.Helpers import Helpers
class Data():
    """ Data Helper Class

    Data functions for the Leveraging Quantum MNIST to detect Invasive Ductal Carcinoma
    QNN (Quantum Neural Network): collects image paths and labels from the training
    directory, loads and downsamples the images, shuffles them and splits them into
    training and validation sets.
    """

    def __init__(self):
        """ Initializes the class.

        Reads the image dimension, training directory and random seed from the
        "qnn" -> "data" section of the Helpers configuration and seeds both the
        NumPy and the Python `random` generators with that seed.
        """
        self.Helpers = Helpers("Data", False)

        self.dim = self.Helpers.confs["qnn"]["data"]["dim"]
        self.dir_train = self.Helpers.confs["qnn"]["data"]["dir_train"]
        self.seed = self.Helpers.confs["qnn"]["data"]["seed"]

        nseed(self.seed)
        rseed(self.seed)

        # Populated by get_paths_n_labels() (paths) and process_data() (data, labels).
        self.data = []
        self.labels = []
        self.paths = []

        self.Helpers.logger.info("Data Helper Class initialization complete.")

    def get_paths_n_labels(self):
        """ Stores data paths and labels as a list of tuples.

        Every sub-directory of ``self.dir_train`` is treated as a class; its name is
        used as the label of each file inside it whose extension appears in the
        "allowed" configuration entry (compared case-insensitively). Appends
        ``(file_path, directory_name)`` tuples to ``self.paths``; non-directory
        entries directly under ``self.dir_train`` are ignored.

        Directories and files are visited in sorted order because ``os.listdir``
        returns entries in arbitrary, filesystem-dependent order, which would make
        the later seeded shuffle/split non-reproducible across machines.
        """
        # Built once; lower-cased so it matches the lower-cased filenames below.
        allowed = tuple(ext.lower() for ext in self.Helpers.confs["qnn"]["data"]["allowed"])

        for ddir in sorted(os.listdir(self.dir_train)):
            tpath = os.path.join(self.dir_train, ddir)
            if not os.path.isdir(tpath):
                continue
            for filename in sorted(os.listdir(tpath)):
                if filename.lower().endswith(allowed):
                    self.paths.append((os.path.join(tpath, filename), ddir))

        self.Helpers.logger.info("Data Paths: " + str(len(self.paths)))

    def process_data(self):
        """ Processes the data.

        Reads every image listed in ``self.paths``, decodes it with 3 channels,
        casts it to float32, resizes it to 4x4 and converts it to grayscale. A sample
        is labelled True when its directory name is "1", otherwise False. The result
        is then shuffled, converted, encoded and split (see the respective methods).

        ``self.data`` and ``self.labels`` are rebuilt from scratch on each call, so
        running this method again does not attempt to append to the NumPy arrays
        left behind by a previous run.
        """
        data, labels = [], []
        for path, label in self.paths:
            image_string = tf.io.read_file(path)
            image_decoded = tf.image.decode_png(image_string, channels=3)
            image = tf.cast(image_decoded, tf.float32)
            # NOTE(review): the target size is hard-coded to 4x4 even though
            # self.dim is read from the configuration - confirm which is intended.
            image = tf.image.resize(image, [4, 4])
            image = tf.image.rgb_to_grayscale(image)
            data.append(image)
            labels.append(label == "1")

        self.data = data
        self.labels = labels

        self.shuffle()
        self.convert_data()
        self.encode_labels()
        self.get_split()

    def shuffle(self):
        """ Shuffles the data and labels together, seeded with ``self.seed``. """
        self.data, self.labels = skshuffle(self.data, self.labels, random_state = self.seed)
        self.Helpers.logger.info("Data shuffled")

    def convert_data(self):
        """ Converts the training data to a numpy array. """
        self.data = np.array(self.data)
        self.Helpers.logger.info("Converted data shape: " + str(self.data.shape))

    def encode_labels(self):
        """ Converts the labels to a numpy array.

        No one-hot encoding is applied: the labels produced by process_data() are
        booleans and are kept as such.
        """
        self.labels = np.array(self.labels)
        self.Helpers.logger.info("Encoded labels shape: " + str(self.labels.shape))

    def get_split(self):
        """ Splits the data and labels creating training and validation datasets.

        Holds out 25.5% of the samples for validation (seeded with ``self.seed``),
        appends a trailing axis to both feature sets and scales them by 1/255.
        """
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(
            self.data, self.labels, test_size=0.255, random_state = self.seed)

        # NOTE(review): images produced by process_data() already end in a
        # channel axis of size 1 (rgb_to_grayscale), so this adds a second
        # trailing axis; kept as-is because downstream code may rely on it.
        self.X_train, self.X_test = self.X_train[..., np.newaxis]/255.0, self.X_test[..., np.newaxis]/255.0

        self.Helpers.logger.info("Training data: " + str(self.X_train.shape))
        self.Helpers.logger.info("Training labels: " + str(self.y_train.shape))
        self.Helpers.logger.info("Validation data: " + str(self.X_test.shape))
        self.Helpers.logger.info("Validation labels: " + str(self.y_test.shape))