Skip to content

Commit b97875e

Browse files
author
Bruno Alves
committed
add macro that counts number of elements in a dataset
1 parent 704e8ed commit b97875e

File tree

2 files changed

+34
-15
lines changed

2 files changed

+34
-15
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import tensorflow as tf
2+
from src.data.tfrecords import read_spectra_data as read_data
3+
4+
# Count how many elements the sharded spectra TFRecord dataset contains by
# iterating over it once until TensorFlow signals the end of the data.
batch_size = 1
files_path = '/fred/oz012/Bruno/data/spectra/qso_zWarning/'
# reads all shards: spectra2_0.tfrecord, ...
dataset = read_data(files_path + 'spectra2_.tfrecord', 3500)
# repeat(1) -> a single pass over the data, so the iterator is exhausted
# exactly once and the loop below terminates.
dataset = dataset.repeat(1).batch(batch_size)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

report_every = 10000  # print a progress line each time this many elements pass
counter = 0
with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        # Keep the try body down to the only call expected to raise
        # OutOfRangeError, so unrelated errors are not attributed to it.
        try:
            inputs, *_ = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            print('Number of elements inside the dataset: ', counter)
            break

        previous = counter
        # Count the elements in this batch rather than the batch itself, so the
        # total stays correct if batch_size is changed (the final batch may be
        # smaller). Assumes `inputs` carries the batch on its leading axis, as
        # produced by Dataset.batch -- TODO confirm for read_spectra_data.
        counter += len(inputs)
        # Report whenever a multiple of report_every has been crossed; with
        # batch_size == 1 this prints at exactly 10000, 20000, ...
        if counter // report_every != previous // report_every:
            print(counter)
24+
Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,30 @@
11
from astropy.io import fits
22
import tensorflow as tf
33
import numpy as np
4-
from src.data.data import read_spectra_data as read_data
4+
from src.data.tfrecords import read_spectra_data as read_data
55
from src.utilities import PlotGenSamples
66

77
def _plot(data, params_data, name, n=5):
    """Plot the first ``n`` entries of ``data`` on an ``n`` x 1 grid.

    Args:
        data: Sliceable collection of samples; only ``data[:n]`` is plotted.
        params_data: Indexable collection whose first item is sliced to
            ``[:n]`` and passed along with the samples.
        name: Passed through to ``PlotGenSamples.plot_spectra`` unchanged
            (presumably the output name -- confirm against PlotGenSamples).
        n: Number of rows in the grid and number of leading samples used.
    """
    plotter = PlotGenSamples(nrows=n, ncols=1)
    leading_samples = data[:n]
    leading_params = params_data[0][:n]
    plotter.plot_spectra(leading_samples, leading_params, name)
1010

11-
batch_size = 512
12-
dataset_size = 255483
13-
files_path = '/fred/oz012/Bruno/data/spectra/boss/cmass/'
14-
dataset = read_data(files_path+'spectra.tfrecord', 3500)
11+
batch_size = 1
12+
dataset_size = 22791
13+
files_path = '/fred/oz012/Bruno/data/spectra/qso_zWarning/'
14+
dataset = read_data(files_path+'spectra2_.tfrecord', 3500)
1515
dataset = dataset.repeat(1).batch(batch_size)
1616
nbatches = int(np.ceil(dataset_size/batch_size))
1717

1818
iterator = dataset.make_initializable_iterator()
1919
next_element = iterator.get_next()
2020

21+
counter = 0
2122
with tf.Session() as sess:
2223
sess.run(iterator.initializer)
23-
for item in range(5):
24+
for item in range(2):
2425
inputs, *params = sess.run(next_element)
26+
counter += 1
27+
if counter%10000==0:
28+
print(counter)
2529
_plot(inputs, params, name='tfrecord_data_'+str(item))
2630

27-
"""
28-
local_path = '/fred/oz012/Bruno/data/spectra/boss/loz/7415/spec-7415-57097-0197.fits'
29-
with fits.open(local_path) as hdu:
30-
flux_ = hdu[1].data['flux'].astype(np.float32)
31-
lam_ = np.power( 10, hdu[1].data['loglam'] ).astype(np.float32)
32-
33-
print(flux_)
34-
print(min(flux_))
35-
"""

0 commit comments

Comments
 (0)