Skip to content

Commit f75ff7c

Browse files
Ci Zhang authored and rabernat committed
read_header (#59)
* read in column names from header * make csv read backwards compatible * remove typo * fix read_header * check pandas version * conflicts * work backwards * fix_issue * simplification * requested changes * requested changes * requested changes * add test
1 parent 47c6c50 commit f75ff7c

File tree

6 files changed

+84
-41
lines changed

6 files changed

+84
-41
lines changed
-30 KB
Binary file not shown.
2.98 KB
Binary file not shown.
2.82 KB
Binary file not shown.

floater/test/test_generators.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -189,31 +189,34 @@ def test_pickling_with_land(fs_with_land, tmpdir):
189189

190190
def test_npart_to_2D_array():
191191
# floatsets
192-
lon = np.linspace(0, 8, 9, dtype=np.float32)
193-
lat = np.linspace(-4, 4, 9, dtype=np.float32)
194-
land_mask = np.zeros(81, dtype=bool)==False
192+
lon = np.arange(0, 9, dtype=np.float32)
193+
lat = np.arange(-4, 5, dtype=np.float32)
194+
land_mask = np.full(81, True, dtype=bool)
195195
land_mask.shape = (len(lat), len(lon))
196196
land_mask[:,0:2] = False
197197
model_grid = {'lon': lon, 'lat': lat, 'land_mask': land_mask}
198-
fs_none = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), dx=1.0, dy=1.0)
199-
fs_mask = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), dx=1.0, dy=1.0, model_grid=model_grid)
198+
fs_none = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5))
199+
fs_mask = gen.FloatSet(xlim=(0, 9), ylim=(-4, 5), model_grid=model_grid)
200+
fs_mask.get_rectmesh()
200201
# dataarray/dataset
201202
var_list = ['test_01', 'test_02', 'test_03']
202203
values_list_none = []
203204
values_list_mask = []
204205
data_vars_none = {}
205206
data_vars_mask = {}
207+
len_none = 81
208+
len_mask = list(fs_mask.ocean_bools).count(True)
206209
for var in var_list:
207-
values_none = np.random.random(81)
208-
values_none.shape = (1, 1, 81)
209-
values_mask = np.random.random(69)
210-
values_mask.shape = (1, 1, 69)
210+
values_none = np.random.random(len_none)
211+
values_none.shape = (1, 1, len_none)
212+
values_mask = np.random.random(len_mask)
213+
values_mask.shape = (1, 1, len_mask)
211214
values_list_none.append(values_none)
212215
values_list_mask.append(values_mask)
213216
data_vars_none.update({var: (['date', 'loc', 'npart'], values_none)})
214217
data_vars_mask.update({var: (['date', 'loc', 'npart'], values_mask)})
215-
npart_none = np.linspace(1, 81, 81, dtype=np.int32)
216-
npart_mask = np.linspace(1, 69, 69, dtype=np.int32)
218+
npart_none = np.arange(1, len_none+1, dtype=np.int32)
219+
npart_mask = np.arange(1, len_mask+1, dtype=np.int32)
217220
coords_none = {'date': (['date'], np.array([np.datetime64('2000-01-01')])),
218221
'loc': (['loc'], np.array(['New York'])),
219222
'npart': (['npart'], npart_none)}
@@ -229,7 +232,6 @@ def test_npart_to_2D_array():
229232
test_mask = (fs_mask, da1d_mask, ds1d_mask, values_list_mask)
230233
test_list = [test_none, test_mask]
231234
for fs, da1d, ds1d, values_list in test_list:
232-
fs.get_rectmesh()
233235
# method test
234236
da2d = fs.npart_to_2D_array(da1d)
235237
ds2d = fs.npart_to_2D_array(ds1d)

floater/test/test_utils.py

Lines changed: 62 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
-0.08639287948608398, 0.12957383692264557, -0.12062723934650421,
1717
0.0, 0.0, 2.6598372642183676e-06)
1818

19-
_TESTDATA_FILENAME_CSV = 'sample_mitgcm_float_trajectories_csv.tar.gz'
20-
_TMPDIR_SUBDIR_CSV = 'sample_mitgcm_data_csv'
19+
_TESTDATA_FILENAME_CSV_OLD = 'sample_mitgcm_float_trajectories_csv_old.tar.gz'
20+
_TMPDIR_SUBDIR_CSV_OLD = 'sample_mitgcm_data_csv_old'
2121

22+
_TESTDATA_FILENAME_CSV_NEW = 'sample_mitgcm_float_trajectories_csv_new.tar.gz'
23+
_TMPDIR_SUBDIR_CSV_NEW = 'sample_mitgcm_data_csv_new'
2224

2325
#@pytest.fixture()
2426
#def empty_output_dir(tmpdir):
@@ -41,12 +43,24 @@ def mitgcm_float_datadir(tmpdir_factory, request):
4143
return target_dir
4244

4345
@pytest.fixture(scope='module')
44-
def mitgcm_float_datadir_csv(tmpdir_factory, request):
46+
def mitgcm_float_datadir_csv_old(tmpdir_factory, request):
4547
filename = request.module.__file__
46-
datafile = os.path.join(os.path.dirname(filename), _TESTDATA_FILENAME_CSV)
48+
datafile = os.path.join(os.path.dirname(filename), _TESTDATA_FILENAME_CSV_OLD)
4749
if not os.path.exists(datafile):
4850
raise IOError('Could not find data file %s' % datafile)
49-
target_dir = str(tmpdir_factory.mktemp(_TMPDIR_SUBDIR_CSV))
51+
target_dir = str(tmpdir_factory.mktemp(_TMPDIR_SUBDIR_CSV_OLD))
52+
tar = tarfile.open(datafile)
53+
tar.extractall(target_dir)
54+
tar.close()
55+
return target_dir
56+
57+
@pytest.fixture(scope='module')
58+
def mitgcm_float_datadir_csv_new(tmpdir_factory, request):
59+
filename = request.module.__file__
60+
datafile = os.path.join(os.path.dirname(filename), _TESTDATA_FILENAME_CSV_NEW)
61+
if not os.path.exists(datafile):
62+
raise IOError('Could not find data file %s' % datafile)
63+
target_dir = str(tmpdir_factory.mktemp(_TMPDIR_SUBDIR_CSV_NEW))
5064
tar = tarfile.open(datafile)
5165
tar.extractall(target_dir)
5266
tar.close()
@@ -88,47 +102,71 @@ def test_floats_to_bcolz(tmpdir, mitgcm_float_datadir):
88102
for name, val in zip(_NAMES, _TESTVALS_FIRST):
89103
np.testing.assert_almost_equal(bc[0][name], val)
90104

91-
def test_floats_to_netcdf(tmpdir, mitgcm_float_datadir_csv):
105+
def test_floats_to_netcdf(tmpdir,
106+
mitgcm_float_datadir_csv_old,
107+
mitgcm_float_datadir_csv_new):
92108
"""Test that we can convert MITgcm float data into NetCDF format.
93109
"""
94110
import xarray as xr
95111
from floater.generators import FloatSet
96112

97-
input_dir = str(mitgcm_float_datadir_csv)
113+
input_dir_old = str(mitgcm_float_datadir_csv_old)
114+
input_dir_new = str(mitgcm_float_datadir_csv_new)
98115
output_dir = str(tmpdir)
99-
os.chdir(input_dir)
100-
fs = FloatSet(xlim=(-5, 5), ylim=(-2, 2), dx=1.0, dy=1.0)
116+
fs = FloatSet(xlim=(-5, 5), ylim=(-2, 2))
117+
118+
os.chdir(input_dir_old)
101119
fs.to_pickle('./fs.pkl')
120+
# least options
121+
utils.floats_to_netcdf(input_dir='./', output_fname='test_old')
122+
# most options
123+
utils.floats_to_netcdf(input_dir='./', output_fname='test_old',
124+
float_file_prefix='float_trajectories',
125+
ref_time='1993-01-01', pkl_path='./fs.pkl',
126+
output_dir=output_dir, output_prefix='prefix_test')
102127

128+
os.chdir(input_dir_new)
129+
fs.to_pickle('./fs.pkl')
103130
# least options
104-
utils.floats_to_netcdf(input_dir=input_dir, output_fname='test')
131+
utils.floats_to_netcdf(input_dir='./', output_fname='test_new')
105132
# most options
106-
utils.floats_to_netcdf(input_dir=input_dir, output_fname='test',
133+
utils.floats_to_netcdf(input_dir='./', output_fname='test_new',
107134
float_file_prefix='float_trajectories',
108135
ref_time='1993-01-01', pkl_path='./fs.pkl',
109136
output_dir=output_dir, output_prefix='prefix_test')
110137

111138
# filename prefix test
112-
os.chdir(input_dir)
113-
mfdl = xr.open_mfdataset('test_netcdf/float_trajectories.*.nc')
139+
os.chdir(input_dir_old)
140+
mfdol = xr.open_mfdataset('test_old_netcdf/float_trajectories.*.nc')
141+
os.chdir(input_dir_new)
142+
mfdnl = xr.open_mfdataset('test_new_netcdf/float_trajectories.*.nc')
114143
os.chdir(output_dir)
115-
mfdm = xr.open_mfdataset('test_netcdf/prefix_test.*.nc')
144+
mfdom = xr.open_mfdataset('test_old_netcdf/prefix_test.*.nc')
145+
mfdnm = xr.open_mfdataset('test_new_netcdf/prefix_test.*.nc')
116146

117147
# dimensions test
118148
dims = [{'time': 2, 'npart': 40}, {'time': 2, 'y0': 4, 'x0': 10}]
119-
assert mfdl.dims == dims[0]
120-
assert mfdm.dims == dims[1]
149+
assert mfdol.dims == dims[0]
150+
assert mfdom.dims == dims[1]
151+
assert mfdnl.dims == dims[0]
152+
assert mfdnm.dims == dims[1]
121153

122154
# variables and values test
123-
vars_values = [('x', 0.3237109375000000e+03), ('y', -0.7798437500000000e+02),
124-
('z', -0.4999999999999893e+00), ('u', -0.5346306607990328e-02),
125-
('v', -0.2787361934305595e-02), ('vort', 0.9160626946271506e-10)]
155+
vars_values = [('x', 0.1961093750000000E+03), ('y', -0.7848437500000000E+02),
156+
('z', -0.4999999999999893E+00), ('u', 0.3567512409555351E-04),
157+
('v', 0.1028276712547044E-03), ('vort', 0.0000000000000000E+00)]
158+
for var, value in vars_values:
159+
np.testing.assert_almost_equal(mfdol[var].values[0][0], value, 8)
160+
np.testing.assert_almost_equal(mfdom[var].values[0][0][0], value, 8)
161+
vars_values.append(('lavd', 0.0000000000000000E+00))
126162
for var, value in vars_values:
127-
np.testing.assert_almost_equal(mfdl[var].values[0][0], value, 8)
128-
np.testing.assert_almost_equal(mfdm[var].values[0][0][0], value, 8)
163+
np.testing.assert_almost_equal(mfdnl[var].values[0][0], value, 8)
164+
np.testing.assert_almost_equal(mfdnm[var].values[0][0][0], value, 8)
129165

130166
# times test
131-
times = [(0, 0, np.datetime64('1993-01-01', 'ns')), (1, 86400, np.datetime64('1993-01-02', 'ns'))]
167+
times = [(0, 0, np.datetime64('1993-01-01', 'ns')), (1, 2592000, np.datetime64('1993-01-31', 'ns'))]
132168
for i, sec, time in times:
133-
assert mfdl['time'][i].values == sec
134-
assert mfdm['time'][i].values == time
169+
assert mfdol['time'][i].values == sec
170+
assert mfdom['time'][i].values == time
171+
assert mfdnl['time'][i].values == sec
172+
assert mfdnm['time'][i].values == time

floater/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ def floats_to_netcdf(input_dir, output_fname,
242242
Prefix of the transcoded NetCDF files
243243
"""
244244
import dask.dataframe as dd
245+
import pandas as pd
245246
import xarray as xr
246247
from floater.generators import FloatSet
247248
from glob import glob
@@ -251,14 +252,16 @@ def floats_to_netcdf(input_dir, output_fname,
251252

252253
match_pattern = float_file_prefix + '.*.csv'
253254
float_files = glob(os.path.join(input_dir, match_pattern))
255+
float_header = pd.read_csv(float_files[0], nrows=0).columns
254256
float_timesteps = sorted(list({int(float_file[-22:-12]) for float_file in float_files}))
257+
float_columns = ['npart', 'time', 'x', 'y', 'z', 'u', 'v', 'vort']
255258

256259
for float_timestep in tqdm(float_timesteps):
257260
input_path = os.path.join(input_dir, '%s.%010d.*.csv' % (float_file_prefix, float_timestep))
258-
df = dd.read_csv(input_path)
259-
if df.columns.values[0] != 'npart': # check if old format
260-
columns = ['npart', 'time', 'x', 'y', 'z', 'u', 'v', 'vort']
261-
df = dd.read_csv(input_path, names=columns, header=None)
261+
if float_header[0] != 'npart':
262+
df = dd.read_csv(input_path, names=float_columns, header=None)
263+
else:
264+
df = dd.read_csv(input_path)
262265
dfc = df.compute()
263266
dfcs = dfc.sort_values('npart')
264267
del_time = int(dfcs.time.values[0])
@@ -270,7 +273,7 @@ def floats_to_netcdf(input_dir, output_fname,
270273
time = np.array([np.int32(del_time)])
271274
npart = dfcs.npart.values.astype(np.int32)
272275
var_shape = (1, len(npart))
273-
var_names = dfcs.columns.values[2:]
276+
var_names = dfcs.columns[2:]
274277
data_vars = {var_name: (['time', 'npart'], dfcs[var_name].values.astype(np.float32).reshape(var_shape)) for var_name in var_names}
275278
ds = xr.Dataset(data_vars, coords={'time': time, 'npart': npart})
276279
if pkl_path is not None:

0 commit comments

Comments (0)