Skip to content

Commit 6ff2fbd

Browse files
authored
Merge pull request #125 from bwheelz36/bw_dev
some bug fixes in filter by time functionality
2 parents c586979 + 034a89f commit 6ff2fbd

File tree

2 files changed

+39
-44
lines changed

2 files changed

+39
-44
lines changed

ParticlePhaseSpace/DataLoaders.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,7 @@ def _import_data(self):
181181
raise Exception(f'failed to calculate momentums from topas data. Possible solution is to increase'
182182
f'the value of momentum_precision_factor, currently set to {momentum_precision_factor: 1.2e}'
183183
f'and failed data has value {relative_difference: 1.2e}')
184-
warnings.warn(f'{n_negative_locations: d} entries returned invalid pz values and were set to zero.'
185-
f'\nWe will now check that momentum and energy are consistent to within '
186-
f'{self._energy_consistency_check_cutoff: 1.4f} {self._units.energy.label}')
184+
warnings.warn(f'{n_negative_locations: d} entries returned invalid pz values and were set to zero.')
187185

188186
ParticleDir = [-1 if elem else 1 for elem in ParticleDir]
189187
self.data[self._columns['pz']] = np.multiply(np.sqrt(temp), ParticleDir)

ParticlePhaseSpace/_ParticlePhaseSpace.py

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -304,27 +304,32 @@ def particle_positions_hist_2D(self, beam_direction: str = 'z', quantity: str =
304304
ind = self._PS._ps_data['particle type [pdg_code]'] == particle
305305
ps_data = self._PS._ps_data.loc[ind]
306306
if beam_direction == 'x':
307-
loop_data = zip(ps_data[self._PS.columns['z']], ps_data[self._PS.columns['y']],
308-
ps_data[self._PS.columns['Ek']],
309-
ps_data['weight'])
310-
_xlabel = self._PS.columns['z']
311-
_ylabel = self._PS.columns['y']
312-
if beam_direction == 'y':
313-
loop_data = zip(ps_data[self._PS.columns['x']], ps_data[self._PS.columns['z']],
314-
ps_data[self._PS.columns['Ek']],
315-
ps_data['weight'])
316-
_xlabel = self._PS.columns['x']
317-
_ylabel = self._PS.columns['z']
318-
if beam_direction == 'z':
319-
loop_data = zip(ps_data[self._PS.columns['x']], ps_data[self._PS.columns['y']],
320-
ps_data[self._PS.columns['Ek']],
321-
ps_data['weight'])
322-
_xlabel = self._PS.columns['x']
323-
_ylabel = self._PS.columns['y']
307+
x_data = ps_data[self._PS.columns['y']]
308+
y_data = ps_data[self._PS.columns['z']]
309+
_x_label = self._PS.columns['y']
310+
_y_label = self._PS.columns['z']
311+
elif beam_direction == 'y':
312+
x_data = ps_data[self._PS.columns['x']]
313+
y_data = ps_data[self._PS.columns['z']]
314+
_x_label = self._PS.columns['x']
315+
_y_label = self._PS.columns['z']
316+
elif beam_direction == 'z':
317+
x_data = ps_data[self._PS.columns['x']]
318+
y_data = ps_data[self._PS.columns['y']]
319+
_x_label = self._PS.columns['x']
320+
_y_label = self._PS.columns['y']
321+
else:
322+
raise NotImplementedError('beam direction must be one of "x", "y" or "z"')
324323
if xlim is None:
325-
xlim = [ps_data[self._PS.columns['x']].min(), ps_data[self._PS.columns['x']].max()]
324+
if beam_direction == 'x':
325+
xlim = [ps_data[self._PS.columns['y']].min(), ps_data[self._PS.columns['y']].max()]
326+
elif beam_direction == 'y' or beam_direction == 'z':
327+
xlim = [ps_data[self._PS.columns['x']].min(), ps_data[self._PS.columns['x']].max()]
326328
if ylim is None:
327-
ylim = [ps_data[self._PS.columns['y']].min(), ps_data[self._PS.columns['y']].max()]
329+
if beam_direction == 'z':
330+
ylim = [ps_data[self._PS.columns['y']].min(), ps_data[self._PS.columns['y']].max()]
331+
elif beam_direction == 'x' or beam_direction == 'y':
332+
ylim = [ps_data[self._PS.columns['z']].min(), ps_data[self._PS.columns['z']].max()]
328333
if quantity == 'intensity':
329334
_title = f"n_particles intensity;\n{particle_cfg.particle_properties[particle]['name']}"
330335
_weight = ps_data['weight']
@@ -333,24 +338,23 @@ def particle_positions_hist_2D(self, beam_direction: str = 'z', quantity: str =
333338
_weight = np.multiply(ps_data['weight'], ps_data[self._PS.columns['Ek']])
334339
X = np.linspace(xlim[0], xlim[1], bins)
335340
Y = np.linspace(ylim[0], ylim[1], bins)
336-
h, xedges, yedges = np.histogram2d(ps_data[self._PS.columns['x']],
337-
ps_data[self._PS.columns['y']],
338-
bins=[X, Y], weights=_weight, )
341+
h, xedges, yedges = np.histogram2d(x_data,
342+
y_data,
343+
bins=[X, Y], weights=_weight)
344+
339345
if normalize:
340-
h = h * 100 / h.max()
341-
# _im1 = axs[0, n_axs].hist2d(ps_data[self._PS.columns['x']], ps_data[self._PS.columns['y']],
342-
# bins=[X,Y],
343-
# weights=_weight, norm=LogNorm(vmin=1, vmax=100),
344-
# cmap='inferno',
345-
# vmin=vmin, vmax=vmax)[3]
346+
try:
347+
h = h * 100 / h.max()
348+
except ValueError:
349+
pass
346350
_im1 = axs[0, n_axs].pcolormesh(xedges, yedges, h.T, cmap='inferno',
347351
norm=_scale, rasterized=False, vmin=vmin, vmax=vmax)
348352

349353
fig.colorbar(_im1, ax=axs[0, n_axs])
350354

351355
axs[0, n_axs].set_title(_title)
352-
axs[0, n_axs].set_xlabel(_xlabel, fontsize=_FigureSpecs.LabelFontSize)
353-
axs[0, n_axs].set_ylabel(_ylabel, fontsize=_FigureSpecs.LabelFontSize)
356+
axs[0, n_axs].set_xlabel(_x_label, fontsize=_FigureSpecs.LabelFontSize)
357+
axs[0, n_axs].set_ylabel(_y_label, fontsize=_FigureSpecs.LabelFontSize)
354358
axs[0, n_axs].set_aspect('equal')
355359
if grid:
356360
axs[0, n_axs].grid()
@@ -1376,7 +1380,7 @@ def assess_density_versus_r(self, Rvals=None, verbose: bool = True, beam_directi
13761380
print(density_data)
13771381
return density_data
13781382

1379-
def filter_by_time(self, t_start, t_finish):
1383+
def filter_by_time(self, t_start, t_finish, in_place: bool=False):
13801384
"""
13811385
Generates a new PhaseSpace which only contains particles inside t_start and t_finish (inclusive).
13821386
t_start and t_finish should be specfied in ps.
@@ -1389,16 +1393,9 @@ def filter_by_time(self, t_start, t_finish):
13891393
"""
13901394
ind = np.logical_and(self._ps_data[self.columns['time']] >= t_start,
13911395
self._ps_data[self.columns['time']] <= t_finish)
1392-
ps_data = self._ps_data[ind]
1393-
for col_name in ps_data.columns:
1394-
if not col_name in ps_cfg.get_required_column_names(self._units):
1395-
ps_data.drop(columns=col_name, inplace=True)
1396-
# create a new instance of _DataImportersBase based on particle_data
1397-
ps_data_loader = DataLoaders.Load_PandasData(ps_data, units=self._units)
1398-
new_instance = PhaseSpace(ps_data_loader)
1399-
print(f'Original data contains {len(self): d} particles')
1400-
print(f'Filtered data contains {len(new_instance): d} particles')
1401-
return new_instance
1396+
new_PS = self.filter_by_boolean_index(ind, in_place=in_place)
1397+
return new_PS
1398+
14021399

14031400
def filter_by_boolean_index(self, boolean_index, in_place: bool = False, split: bool = False):
14041401
"""

0 commit comments

Comments
 (0)